diff --git a/libs/milvus/langchain_milvus/vectorstores/milvus.py b/libs/milvus/langchain_milvus/vectorstores/milvus.py index 636eb23..61fa109 100644 --- a/libs/milvus/langchain_milvus/vectorstores/milvus.py +++ b/libs/milvus/langchain_milvus/vectorstores/milvus.py @@ -19,7 +19,19 @@ from langchain_core.documents import Document from langchain_core.embeddings import Embeddings from langchain_core.vectorstores import VectorStore -from pymilvus import MilvusClient, RRFRanker, WeightedRanker +from pymilvus import ( + Collection, + CollectionSchema, + DataType, + FieldSchema, + MilvusClient, + MilvusException, + RRFRanker, + WeightedRanker, + utility, +) +from pymilvus.client.types import LoadState # type: ignore +from pymilvus.orm.types import infer_dtype_bydata # type: ignore from langchain_milvus import MilvusCollectionHybridSearchRetriever from langchain_milvus.utils.sparse import BaseSparseEmbedding @@ -280,14 +292,6 @@ def __init__( metadata_schema: Optional[dict[str, Any]] = None, ): """Initialize the Milvus vector store.""" - try: - from pymilvus import Collection, MilvusClient, utility - except ImportError: - raise ValueError( - "Could not import pymilvus python package. " - "Please install it with `pip install pymilvus`." - ) - # Default search params when one is not provided. self.default_search_params = { "FLAT": {"metric_type": "L2", "params": {}}, @@ -451,15 +455,6 @@ def _init( def _create_collection( self, embeddings: List[list], metadatas: Optional[list[dict]] = None ) -> None: - from pymilvus import ( - Collection, - CollectionSchema, - DataType, - FieldSchema, - MilvusException, - ) - from pymilvus.orm.types import infer_dtype_bydata # type: ignore - fields = [] vector_fields: List[str] = self._as_list(self._vector_field) # If enable_dynamic_field, we don't need to create fields, and just pass it. @@ -632,8 +627,6 @@ def _create_collection( raise e def _get_field_schema_from_dict(self, field_name: str, schema_dict: dict): # type: ignore[no-untyped-def] - from pymilvus import FieldSchema - assert "dtype" in schema_dict, ( f"Please provide `dtype` in the schema dict. " f"Existing keys are: {schema_dict.keys()}" @@ -645,8 +638,6 @@ def _get_field_schema_from_dict(self, field_name: str, schema_dict: dict): # ty def _extract_fields(self) -> None: """Grab the existing fields from the Collection""" - from pymilvus import Collection - if isinstance(self.col, Collection): schema = self.col.schema for x in schema.fields: @@ -654,8 +645,6 @@ def _extract_fields(self) -> None: def _get_index(self, field_name: Optional[str] = None) -> Optional[dict[str, Any]]: """Return the vector index information if it exists""" - from pymilvus import Collection - if not self._is_multi_vector: field_name: str = field_name or self._vector_field # type: ignore @@ -667,8 +656,6 @@ def _get_index(self, field_name: Optional[str] = None) -> Optional[dict[str, Any def _create_index(self) -> None: """Create an index on the collection""" - from pymilvus import Collection, MilvusException - if isinstance(self.col, Collection) and self._get_index() is None: embeddings_functions: List[EmbeddingType] = self._as_list( self.embedding_func @@ -740,8 +727,6 @@ def _create_search_params(self) -> None: """Generate search params based on the current index type""" import copy - from pymilvus import Collection - if isinstance(self.col, Collection) and self.search_params is None: vector_fields: List[str] = self._as_list(self._vector_field) search_params_list: List[dict] = [] @@ -768,9 +753,6 @@ def _load( timeout: Optional[float] = None, ) -> None: """Load the collection if available.""" - from pymilvus import Collection, utility - from pymilvus.client.types import LoadState # type: ignore - timeout = self.timeout or timeout if ( isinstance(self.col, Collection) @@ -934,7 +916,6 @@ def add_embeddings( Returns: List[str]: The resulting keys for each inserted element. """ - from pymilvus import Collection, MilvusException if not self._is_multi_vector: embeddings = [[embedding] for embedding in embeddings] # type: ignore @@ -1585,8 +1566,6 @@ def get_pks(self, expr: str, **kwargs: Any) -> List[int] | None: List[int]: List of IDs (Primary Keys) """ - from pymilvus import MilvusException - if self.col is None: logger.debug("No existing collection to get pk.") return None @@ -1617,8 +1596,6 @@ def upsert( # type: ignore List[str]: IDs of the added texts. """ - from pymilvus import MilvusException - if documents is None or len(documents) == 0: logger.debug("No documents to upsert.") return None diff --git a/libs/milvus/langchain_milvus/vectorstores/zilliz.py b/libs/milvus/langchain_milvus/vectorstores/zilliz.py index 3122131..f68d5d3 100644 --- a/libs/milvus/langchain_milvus/vectorstores/zilliz.py +++ b/libs/milvus/langchain_milvus/vectorstores/zilliz.py @@ -3,6 +3,8 @@ import logging from typing import List, Optional, Union, cast +from pymilvus import Collection, MilvusException + from langchain_milvus.vectorstores.milvus import EmbeddingType, Milvus logger = logging.getLogger(__name__) @@ -73,7 +75,6 @@ class Zilliz(Milvus): def _create_index(self) -> None: """Create an index on the collection""" - from pymilvus import Collection, MilvusException self.index_params = cast(Optional[Union[dict, List[dict]]], self.index_params) # type: ignore