Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine functions #20

Merged
merged 1 commit into from
Nov 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions libs/milvus/langchain_milvus/vectorstores/milvus.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from pymilvus import RRFRanker, WeightedRanker
from pymilvus import MilvusClient, RRFRanker, WeightedRanker

from langchain_milvus import MilvusCollectionHybridSearchRetriever
from langchain_milvus.utils.sparse import BaseSparseEmbedding
Expand Down Expand Up @@ -409,7 +409,7 @@ def embeddings(self) -> Union[EmbeddingType, List[EmbeddingType]]: # type: igno
return self.embedding_func

@property
def client(self) -> Any:
def client(self) -> MilvusClient:
"""Get client."""
return self._milvus_client

Expand All @@ -419,15 +419,11 @@ def _is_multi_vector(self) -> bool:

@property
def _is_sparse(self) -> bool:
if self.index_params is None:
return False
indexes_params = self._as_list(self.index_params)
if len(indexes_params) > 1:
return False
index_type = indexes_params[0]["index_type"]
if "SPARSE" in index_type:
embedding_func: List[EmbeddingType] = self._as_list(self.embedding_func)
if self._is_sparse_embedding(embedding_func[0]):
return True
return False
else:
return False

@staticmethod
def _is_sparse_embedding(embeddings_function: EmbeddingType) -> bool:
Expand Down Expand Up @@ -1396,12 +1392,20 @@ def _select_relevance_score_fn(self) -> Callable[[float], float]:
- etc.

"""
if self.index_params is None:
raise ValueError("No index params provided.")
if not self.col or not self.col.indexes:
raise ValueError(
"No index params provided. Could not determine relevance function."
)
if self._is_multi_vector:
raise ValueError("No supported normalization function for multi vectors.")
raise ValueError(
"No supported normalization function for multi vectors. "
"Could not determine relevance function."
)
if self._is_sparse:
raise ValueError("No supported normalization function for sparse indexes.")
raise ValueError(
"No supported normalization function for sparse indexes. "
"Could not determine relevance function."
)

def _map_l2_to_similarity(l2_distance: float) -> float:
"""Return a similarity score on a scale [0, 1].
Expand All @@ -1423,6 +1427,12 @@ def _map_ip_to_similarity(ip_score: float) -> float:
"""
return (ip_score + 1) / 2.0

if self.index_params is None:
logger.warning(
"No index params provided. Could not determine relevance function. "
"Use L2 distance as default."
)
return _map_l2_to_similarity
indexes_params = self._as_list(self.index_params)
metric_type = indexes_params[0]["metric_type"]
if metric_type == "L2":
Expand Down
Loading