diff --git a/python/.coveragerc b/python/.coveragerc index dc37f315b86e..b51952a0c8e8 100644 --- a/python/.coveragerc +++ b/python/.coveragerc @@ -7,7 +7,7 @@ omit = semantic_kernel/connectors/memory/azure_cosmosdb_no_sql/* semantic_kernel/connectors/memory/chroma/* semantic_kernel/connectors/memory/milvus/* - semantic_kernel/connectors/memory/mongodb_atlas/* + semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_memory_store.py semantic_kernel/connectors/memory/pinecone/* semantic_kernel/connectors/memory/postgres/* semantic_kernel/connectors/memory/qdrant/qdrant_memory_store.py diff --git a/python/semantic_kernel/connectors/ai/mistral_ai/prompt_execution_settings/mistral_ai_prompt_execution_settings.py b/python/semantic_kernel/connectors/ai/mistral_ai/prompt_execution_settings/mistral_ai_prompt_execution_settings.py index ea709172950f..8153db93577f 100644 --- a/python/semantic_kernel/connectors/ai/mistral_ai/prompt_execution_settings/mistral_ai_prompt_execution_settings.py +++ b/python/semantic_kernel/connectors/ai/mistral_ai/prompt_execution_settings/mistral_ai_prompt_execution_settings.py @@ -1,17 +1,10 @@ # Copyright (c) Microsoft. All rights reserved. import logging -import sys from typing import Annotated, Any, Literal from mistralai import utils - -if sys.version_info >= (3, 11): - pass # pragma: no cover -else: - pass # pragma: no cover - -from pydantic import Field, field_validator +from pydantic import Field from semantic_kernel.connectors.ai.prompt_execution_settings import PromptExecutionSettings @@ -29,7 +22,14 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings): response_format: dict[Literal["type"], Literal["text", "json_object"]] | None = None messages: list[dict[str, Any]] | None = None - safe_mode: Annotated[bool, Field(exclude=True)] = False + safe_mode: Annotated[ + bool, + Field( + exclude=True, + deprecated="The 'safe_mode' setting is no longer supported and is being ignored, " + "it will be removed in the Future.", + ), + ] = False safe_prompt: bool = False max_tokens: Annotated[int | None, Field(gt=0)] = None seed: int | None = None @@ -56,12 +56,3 @@ class MistralAIChatPromptExecutionSettings(MistralAIPromptExecutionSettings): "on the function choice configuration.", ), ] = None - - @field_validator("safe_mode") - @classmethod - def check_safe_mode(cls, v: bool) -> bool: - """The safe_mode setting is no longer supported.""" - logger.warning( - "The 'safe_mode' setting is no longer supported and is being ignored, it will be removed in the Future." - ) - return v diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py new file mode 100644 index 000000000000..398bdc41ef72 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/const.py @@ -0,0 +1,14 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +from typing import Final + +from semantic_kernel.data.const import DistanceFunction + +DISTANCE_FUNCTION_MAPPING: Final[dict[DistanceFunction, str]] = { + DistanceFunction.EUCLIDEAN_DISTANCE: "euclidean", + DistanceFunction.COSINE_SIMILARITY: "cosine", + DistanceFunction.DOT_PROD: "dotProduct", +} +MONGODB_ID_FIELD: Final[str] = "_id" +DEFAULT_DB_NAME = "default" +DEFAULT_SEARCH_INDEX_NAME = "default" diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py new file mode 100644 index 000000000000..bbd524019510 --- /dev/null +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_collection.py @@ -0,0 +1,276 @@ +# Copyright (c) Microsoft. All rights reserved. + +import logging +import sys +from collections.abc import Sequence +from importlib import metadata +from typing import Any, ClassVar, Generic, TypeVar + +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover + +from pydantic import ValidationError +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.driver_info import DriverInfo + +from semantic_kernel.connectors.memory.mongodb_atlas.const import ( + DEFAULT_DB_NAME, + DEFAULT_SEARCH_INDEX_NAME, + MONGODB_ID_FIELD, +) +from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition +from semantic_kernel.data.filter_clauses import AnyTagsEqualTo, EqualTo +from semantic_kernel.data.kernel_search_results import KernelSearchResults +from semantic_kernel.data.record_definition import VectorStoreRecordDefinition +from semantic_kernel.data.vector_search import ( + VectorSearchFilter, + VectorSearchOptions, +) +from semantic_kernel.data.vector_search.vector_search import VectorSearchBase +from semantic_kernel.data.vector_search.vector_search_result import VectorSearchResult +from semantic_kernel.data.vector_search.vectorized_search import VectorizedSearchMixin +from semantic_kernel.exceptions import ( + VectorSearchExecutionException, + VectorStoreInitializationException, + VectorStoreOperationException, +) +from semantic_kernel.utils.experimental_decorator import experimental_class + +logger: logging.Logger = logging.getLogger(__name__) + +TModel = TypeVar("TModel") + + +@experimental_class +class MongoDBAtlasCollection( + VectorSearchBase[str, TModel], + VectorizedSearchMixin[TModel], + Generic[TModel], +): + """MongoDB Atlas collection implementation.""" + + mongo_client: AsyncMongoClient + database_name: str + index_name: str + supported_key_types: ClassVar[list[str] | None] = ["str"] + supported_vector_types: ClassVar[list[str] | None] = ["float", "int"] + + def __init__( + self, + collection_name: str, + data_model_type: type[TModel], + data_model_definition: VectorStoreRecordDefinition | None = None, + index_name: str | None = None, + mongo_client: AsyncMongoClient | None = None, + **kwargs: Any, + ) -> None: + """Initializes a new instance of the MongoDBAtlasCollection class. + + Args: + data_model_type: The type of the data model. + data_model_definition: The model definition, optional. + collection_name: The name of the collection, optional. + mongo_client: The MongoDB client for interacting with MongoDB Atlas, + used for creating and deleting collections. 
+ index_name: The name of the index to use for searching, when not passed, will use _idx. + **kwargs: Additional keyword arguments, including: + The same keyword arguments used for MongoDBAtlasStore: + database_name: The name of the database, will be filled from the env when this is not set. + connection_string: str | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None + + """ + managed_client = not mongo_client + if mongo_client: + super().__init__( + data_model_type=data_model_type, + data_model_definition=data_model_definition, + mongo_client=mongo_client, + collection_name=collection_name, + database_name=kwargs.get("database_name", DEFAULT_DB_NAME), + index_name=index_name or DEFAULT_SEARCH_INDEX_NAME, + managed_client=managed_client, + ) + return + + from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import MongoDBAtlasSettings + + try: + mongodb_atlas_settings = MongoDBAtlasSettings.create( + env_file_path=kwargs.get("env_file_path"), + env_file_encoding=kwargs.get("env_file_encoding"), + connection_string=kwargs.get("connection_string"), + database_name=kwargs.get("database_name"), + index_name=index_name, + ) + except ValidationError as exc: + raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc + if not mongo_client: + mongo_client = AsyncMongoClient( + mongodb_atlas_settings.connection_string.get_secret_value(), + driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")), + ) + + super().__init__( + data_model_type=data_model_type, + data_model_definition=data_model_definition, + collection_name=collection_name, + mongo_client=mongo_client, + managed_client=managed_client, + database_name=mongodb_atlas_settings.database_name, + index_name=mongodb_atlas_settings.index_name, + ) + + def _get_database(self) -> AsyncDatabase: + """Get the database. + + If you need control over things like read preference, you can override this method. + """ + return self.mongo_client.get_database(self.database_name) + + def _get_collection(self) -> AsyncCollection: + """Get the collection. + + If you need control over things like read preference, you can override this method. 
+ """ + return self.mongo_client.get_database(self.database_name).get_collection(self.collection_name) + + @override + async def _inner_upsert( + self, + records: Sequence[Any], + **kwargs: Any, + ) -> Sequence[str]: + result = await self._get_collection().update_many(update=records, upsert=True, **kwargs) + return [str(ids) for ids in result.upserted_id] + + @override + async def _inner_get(self, keys: Sequence[str], **kwargs: Any) -> Sequence[dict[str, Any]]: + result = self._get_collection().find({MONGODB_ID_FIELD: {"$in": keys}}) + return await result.to_list(length=len(keys)) + + @override + async def _inner_delete(self, keys: Sequence[str], **kwargs: Any) -> None: + collection = self._get_collection() + await collection.delete_many({MONGODB_ID_FIELD: {"$in": keys}}) + + def _replace_key_field(self, record: dict[str, Any]) -> dict[str, Any]: + if self._key_field_name == MONGODB_ID_FIELD: + return record + return { + MONGODB_ID_FIELD: record.pop(self._key_field_name, None), + **record, + } + + def _reset_key_field(self, record: dict[str, Any]) -> dict[str, Any]: + if self._key_field_name == MONGODB_ID_FIELD: + return record + return { + self._key_field_name: record.pop(MONGODB_ID_FIELD, None), + **record, + } + + @override + def _serialize_dicts_to_store_models(self, records: Sequence[dict[str, Any]], **kwargs: Any) -> Sequence[Any]: + return [self._replace_key_field(record) for record in records] + + @override + def _deserialize_store_models_to_dicts(self, records: Sequence[Any], **kwargs: Any) -> Sequence[dict[str, Any]]: + return [self._reset_key_field(record) for record in records] + + @override + async def create_collection(self, **kwargs) -> None: + """Create a new collection in MongoDB Atlas. + + This first creates a collection, with the kwargs. + Then creates a search index based on the data model definition. + + Args: + **kwargs: Additional keyword arguments. 
+ """ + collection = await self._get_database().create_collection(self.collection_name, **kwargs) + await collection.create_search_index(create_index_definition(self.data_model_definition, self.index_name)) + + @override + async def does_collection_exist(self, **kwargs) -> bool: + return bool(await self._get_database().list_collection_names(filter={"name": self.collection_name})) + + @override + async def delete_collection(self, **kwargs) -> None: + await self._get_database().drop_collection(self.collection_name, **kwargs) + + @override + async def _inner_search( + self, + options: VectorSearchOptions, + search_text: str | None = None, + vectorizable_text: str | None = None, + vector: list[float | int] | None = None, + **kwargs: Any, + ) -> KernelSearchResults[VectorSearchResult[TModel]]: + collection = self._get_collection() + vector_search_query: dict[str, Any] = { + "limit": options.top + options.skip, + "index": self.index_name, + } + if options.filter.filters: + vector_search_query["filter"] = self._build_filter_dict(options.filter) + if vector is not None: + vector_search_query["queryVector"] = vector + vector_search_query["path"] = options.vector_field_name + if "queryVector" not in vector_search_query: + raise VectorStoreOperationException("Vector is required for search.") + + projection_query: dict[str, int | dict] = { + field: 1 + for field in self.data_model_definition.get_field_names( + include_vector_fields=options.include_vectors, + include_key_field=False, # _id is always included + ) + } + projection_query["score"] = {"$meta": "vectorSearchScore"} + try: + raw_results = await collection.aggregate([ + {"$vectorSearch": vector_search_query}, + {"$project": projection_query}, + ]) + except Exception as exc: + raise VectorSearchExecutionException("Failed to search the collection.") from exc + return KernelSearchResults( + results=self._get_vector_search_results_from_results(raw_results, options), + total_count=None, # no way to get a count before looping through the result cursor + ) + + def _build_filter_dict(self, search_filter: VectorSearchFilter) -> dict[str, Any]: + """Create the filter dictionary based on the filters.""" + filter_dict = {} + for filter in search_filter.filters: + if isinstance(filter, EqualTo): + filter_dict[filter.field_name] = filter.value + elif isinstance(filter, AnyTagsEqualTo): + filter_dict[filter.field_name] = {"$in": filter.value} + return filter_dict + + @override + def _get_record_from_result(self, result: dict[str, Any]) -> dict[str, Any]: + return result + + @override + def _get_score_from_result(self, result: dict[str, Any]) -> float | None: + return result.get("score") + + @override + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Exit the context manager.""" + if self.managed_client: + await self.mongo_client.close() + + async def __aenter__(self) -> "MongoDBAtlasCollection": + """Enter the context manager.""" + await self.mongo_client.aconnect() + return self diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py index 0eec5591d15f..11a21183fcf2 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_settings.py @@ -4,7 +4,7 @@ from pydantic import SecretStr -from semantic_kernel.connectors.memory.mongodb_atlas.utils import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME +from 
semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME from semantic_kernel.kernel_pydantic import KernelBaseSettings from semantic_kernel.utils.experimental_decorator import experimental_class @@ -16,6 +16,10 @@ class MongoDBAtlasSettings(KernelBaseSettings): Args: - connection_string: str - MongoDB Atlas connection string (Env var MONGODB_ATLAS_CONNECTION_STRING) + - database_name: str - MongoDB Atlas database name, defaults to 'default' + (Env var MONGODB_ATLAS_DATABASE_NAME) + - index_name: str - MongoDB Atlas search index name, defaults to 'default' + (Env var MONGODB_ATLAS_INDEX_NAME) """ env_prefix: ClassVar[str] = "MONGODB_ATLAS_" diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py new file mode 100644 index 000000000000..b2072b9fba3c --- /dev/null +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/mongodb_atlas_store.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft. All rights reserved. + +import logging +import sys +from importlib import metadata +from typing import TYPE_CHECKING, Any, TypeVar + +if sys.version_info >= (3, 12): + from typing import override # pragma: no cover +else: + from typing_extensions import override # pragma: no cover + +from pydantic import ValidationError +from pymongo import AsyncMongoClient +from pymongo.asynchronous.database import AsyncDatabase +from pymongo.driver_info import DriverInfo + +from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import ( + MongoDBAtlasCollection, +) +from semantic_kernel.data.record_definition import VectorStoreRecordDefinition +from semantic_kernel.data.vector_storage import VectorStore +from semantic_kernel.exceptions import VectorStoreInitializationException +from semantic_kernel.utils.experimental_decorator import experimental_class + +if TYPE_CHECKING: + from semantic_kernel.data import VectorStoreRecordCollection + + +logger: logging.Logger = logging.getLogger(__name__) + +TModel = TypeVar("TModel") + + +@experimental_class +class MongoDBAtlasStore(VectorStore): + """MongoDB Atlas store implementation.""" + + mongo_client: AsyncMongoClient + database_name: str + + def __init__( + self, + connection_string: str | None = None, + database_name: str | None = None, + mongo_client: AsyncMongoClient | None = None, + env_file_path: str | None = None, + env_file_encoding: str | None = None, + ) -> None: + """Initializes a new instance of the MongoDBAtlasStore client. + + Args: + connection_string (str): The connection string for MongoDB Atlas, optional. + Can be read from environment variables. + database_name (str): The name of the database, optional. Can be read from environment variables. + mongo_client (MongoClient): The MongoDB client, optional. + env_file_path (str): Use the environment settings file as a fallback + to environment variables. + env_file_encoding (str): The encoding of the environment settings file. 
+ + """ + from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_settings import ( + MongoDBAtlasSettings, + ) + + if mongo_client and database_name: + super().__init__( + mongo_client=mongo_client, + managed_client=False, + database_name=database_name, + ) + managed_client: bool = False + try: + mongodb_atlas_settings = MongoDBAtlasSettings.create( + env_file_path=env_file_path, + connection_string=connection_string, + database_name=database_name, + env_file_encoding=env_file_encoding, + ) + except ValidationError as exc: + raise VectorStoreInitializationException("Failed to create MongoDB Atlas settings.") from exc + if not mongo_client: + mongo_client = AsyncMongoClient( + mongodb_atlas_settings.connection_string.get_secret_value(), + driver=DriverInfo("Microsoft Semantic Kernel", metadata.version("semantic-kernel")), + ) + managed_client = True + + super().__init__( + mongo_client=mongo_client, + managed_client=managed_client, + database_name=mongodb_atlas_settings.database_name, + ) + + @override + def get_collection( + self, + collection_name: str, + data_model_type: type[TModel], + data_model_definition: VectorStoreRecordDefinition | None = None, + **kwargs: Any, + ) -> "VectorStoreRecordCollection": + """Get a MongoDBAtlasCollection tied to a collection. + + Args: + collection_name (str): The name of the collection. + data_model_type (type[TModel]): The type of the data model. + data_model_definition (VectorStoreRecordDefinition | None): The model fields, optional. + **kwargs: Additional keyword arguments, passed to the collection constructor. + """ + if collection_name not in self.vector_record_collections: + self.vector_record_collections[collection_name] = MongoDBAtlasCollection( + data_model_type=data_model_type, + data_model_definition=data_model_definition, + mongo_client=self.mongo_client, + collection_name=collection_name, + database_name=self.database_name, + **kwargs, + ) + return self.vector_record_collections[collection_name] + + @override + async def list_collection_names(self, **kwargs: Any) -> list[str]: + database: AsyncDatabase = self.mongo_client.get_database(self.database_name) + return await database.list_collection_names() + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + """Exit the context manager.""" + if self.managed_client: + await self.mongo_client.close() + + async def __aenter__(self) -> "MongoDBAtlasStore": + """Enter the context manager.""" + await self.mongo_client.aconnect() + return self diff --git a/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py b/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py index cb415f45377c..f05b94b45782 100644 --- a/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py +++ b/python/semantic_kernel/connectors/memory/mongodb_atlas/utils.py @@ -1,11 +1,17 @@ # Copyright (c) Microsoft. All rights reserved. 
from numpy import array +from pymongo.operations import SearchIndexModel +from semantic_kernel.connectors.memory.mongodb_atlas.const import DISTANCE_FUNCTION_MAPPING +from semantic_kernel.data.record_definition.vector_store_model_definition import VectorStoreRecordDefinition +from semantic_kernel.data.record_definition.vector_store_record_fields import ( + VectorStoreRecordDataField, + VectorStoreRecordVectorField, +) +from semantic_kernel.exceptions.service_exceptions import ServiceInitializationError from semantic_kernel.memory.memory_record import MemoryRecord -DEFAULT_DB_NAME = "default" -DEFAULT_SEARCH_INDEX_NAME = "default" NUM_CANDIDATES_SCALAR = 10 MONGODB_FIELD_ID = "_id" @@ -66,3 +72,44 @@ def memory_record_to_mongo_document(record: MemoryRecord) -> dict: MONGODB_FIELD_EMBEDDING: record._embedding.tolist(), MONGODB_FIELD_TIMESTAMP: record._timestamp, } + + +def create_vector_field(field: VectorStoreRecordVectorField) -> dict: + """Create a vector field. + + Args: + field (VectorStoreRecordVectorField): The vector field. + + Returns: + dict: The vector field. + """ + if field.distance_function not in DISTANCE_FUNCTION_MAPPING: + raise ServiceInitializationError(f"Invalid distance function: {field.distance_function}") + return { + "type": "vector", + "numDimensions": field.dimensions, + "path": field.name, + "similarity": DISTANCE_FUNCTION_MAPPING[field.distance_function], + } + + +def create_index_definition(record_definition: VectorStoreRecordDefinition, index_name: str) -> SearchIndexModel: + """Create an index definition. + + Args: + record_definition (VectorStoreRecordDefinition): The record definition. + index_name (str): The index name. + + Returns: + SearchIndexModel: The index definition. + """ + vector_fields = [create_vector_field(field) for field in record_definition.vector_fields] + data_fields = [ + {"path": field.name, "type": "filter"} + for field in record_definition.fields + if isinstance(field, VectorStoreRecordDataField) and (field.is_filterable or field.is_full_text_searchable) + ] + key_field = [{"path": record_definition.key_field.name, "type": "filter"}] + return SearchIndexModel( + type="vectorSearch", name=index_name, definition={"fields": vector_fields + data_fields + key_field} + ) diff --git a/python/tests/conftest.py b/python/tests/conftest.py index 3be0430a8ad4..35d00f624a73 100644 --- a/python/tests/conftest.py +++ b/python/tests/conftest.py @@ -367,6 +367,28 @@ def azure_ai_search_unit_test_env(monkeypatch, exclude_list, override_env_param_ return env_vars +@fixture() +def mongodb_atlas_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): + """Fixture to set environment variables for MongoDB Atlas Unit Tests.""" + if exclude_list is None: + exclude_list = [] + + if override_env_param_dict is None: + override_env_param_dict = {} + + env_vars = {"MONGODB_ATLAS_CONNECTION_STRING": "mongodb://test", "MONGODB_ATLAS_DATABASE_NAME": "test-database"} + + env_vars.update(override_env_param_dict) + + for key, value in env_vars.items(): + if key not in exclude_list: + monkeypatch.setenv(key, value) + else: + monkeypatch.delenv(key, raising=False) + + return env_vars + + @fixture() def bing_unit_test_env(monkeypatch, exclude_list, override_env_param_dict): """Fixture to set environment variables for BingConnector.""" diff --git a/python/tests/unit/connectors/memory/mongodb_atlas/conftest.py b/python/tests/unit/connectors/memory/mongodb_atlas/conftest.py new file mode 100644 index 000000000000..23f637104710 --- /dev/null +++ 
b/python/tests/unit/connectors/memory/mongodb_atlas/conftest.py @@ -0,0 +1,37 @@ +# Copyright (c) Microsoft. All rights reserved. + + +from unittest.mock import patch + +import pytest +from pymongo import AsyncMongoClient +from pymongo.asynchronous.collection import AsyncCollection +from pymongo.asynchronous.database import AsyncDatabase + +BASE_PATH = "pymongo.asynchronous.mongo_client.AsyncMongoClient" +DATABASE_PATH = "pymongo.asynchronous.database.AsyncDatabase" +COLLECTION_PATH = "pymongo.asynchronous.collection.AsyncCollection" + + +@pytest.fixture(autouse=True) +def mock_mongo_client(): + with patch(BASE_PATH, spec=AsyncMongoClient) as mock: + yield mock + + +@pytest.fixture(autouse=True) +def mock_get_database(mock_mongo_client): + with ( + patch(DATABASE_PATH, spec=AsyncDatabase) as mock_db, + patch.object(mock_mongo_client, "get_database", new_callable=lambda: mock_db) as mock, + ): + yield mock + + +@pytest.fixture(autouse=True) +def mock_get_collection(mock_get_database): + with ( + patch(COLLECTION_PATH, spec=AsyncCollection) as mock_collection, + patch.object(mock_get_database, "get_collection", new_callable=lambda: mock_collection) as mock, + ): + yield mock diff --git a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py new file mode 100644 index 000000000000..00afe491e2a3 --- /dev/null +++ b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_collection.py @@ -0,0 +1,96 @@ +# Copyright (c) Microsoft. All rights reserved. + +from unittest.mock import AsyncMock, patch + +from pymongo import AsyncMongoClient +from pymongo.asynchronous.cursor import AsyncCursor +from pymongo.results import UpdateResult +from pytest import mark, raises + +from semantic_kernel.connectors.memory.mongodb_atlas.const import DEFAULT_DB_NAME, DEFAULT_SEARCH_INDEX_NAME +from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection +from semantic_kernel.exceptions.vector_store_exceptions import VectorStoreInitializationException + + +def test_mongodb_atlas_collection_initialization(mongodb_atlas_unit_test_env, data_model_definition, mock_mongo_client): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + mongo_client=mock_mongo_client, + ) + assert collection.mongo_client is not None + assert isinstance(collection.mongo_client, AsyncMongoClient) + + +@mark.parametrize("exclude_list", [["MONGODB_ATLAS_CONNECTION_STRING"]], indirect=True) +def test_mongodb_atlas_collection_initialization_fail(mongodb_atlas_unit_test_env, data_model_definition): + with raises(VectorStoreInitializationException): + MongoDBAtlasCollection( + collection_name="test_collection", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + + +@mark.parametrize("exclude_list", [["MONGODB_ATLAS_DATABASE_NAME", "MONGODB_ATLAS_INDEX_NAME"]], indirect=True) +def test_mongodb_atlas_collection_initialization_defaults(mongodb_atlas_unit_test_env, data_model_definition): + collection = MongoDBAtlasCollection( + collection_name="test_collection", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + assert collection.database_name == DEFAULT_DB_NAME + assert collection.index_name == DEFAULT_SEARCH_INDEX_NAME + + +async def test_mongodb_atlas_collection_upsert(mongodb_atlas_unit_test_env, data_model_definition, 
mock_get_collection): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + ) + with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get: + result_mock = AsyncMock(spec=UpdateResult) + result_mock.upserted_id = ["test_id"] + mock_get.return_value.update_many.return_value = result_mock + result = await collection._inner_upsert([{"_id": "test_id", "data": "test_data"}]) + assert result == ["test_id"] + + +async def test_mongodb_atlas_collection_get(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + ) + with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get: + result_mock = AsyncMock(spec=AsyncCursor) + result_mock.to_list.return_value = [{"_id": "test_id", "data": "test_data"}] + mock_get.return_value.find.return_value = result_mock + result = await collection._inner_get(["test_id"]) + assert result == [{"_id": "test_id", "data": "test_data"}] + + +async def test_mongodb_atlas_collection_delete(mongodb_atlas_unit_test_env, data_model_definition, mock_get_collection): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + ) + with patch.object(collection, "_get_collection", new=mock_get_collection) as mock_get: + await collection._inner_delete(["test_id"]) + mock_get.return_value.delete_many.assert_called_with({"_id": {"$in": ["test_id"]}}) + + +async def test_mongodb_atlas_collection_collection_exists( + mongodb_atlas_unit_test_env, data_model_definition, mock_get_database +): + collection = MongoDBAtlasCollection( + data_model_type=dict, + data_model_definition=data_model_definition, + collection_name="test_collection", + ) + with patch.object(collection, "_get_database", new=mock_get_database) as mock_get: + mock_get.return_value.list_collection_names.return_value = ["test_collection"] + assert await collection.does_collection_exist() diff --git a/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_store.py b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_store.py new file mode 100644 index 000000000000..a06e68a99699 --- /dev/null +++ b/python/tests/unit/connectors/memory/mongodb_atlas/test_mongodb_atlas_store.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ + +from pymongo import AsyncMongoClient + +from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_collection import MongoDBAtlasCollection +from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_store import MongoDBAtlasStore + + +def test_mongodb_atlas_store_initialization(mongodb_atlas_unit_test_env): + store = MongoDBAtlasStore() + assert store.mongo_client is not None + assert isinstance(store.mongo_client, AsyncMongoClient) + + +def test_mongodb_atlas_store_get_collection(mongodb_atlas_unit_test_env, data_model_definition): + store = MongoDBAtlasStore() + collection = store.get_collection( + collection_name="test_collection", + data_model_type=dict, + data_model_definition=data_model_definition, + ) + assert collection is not None + assert isinstance(collection, MongoDBAtlasCollection) + + +async def test_mongodb_atlas_store_list_collection_names(mongodb_atlas_unit_test_env, mock_mongo_client): + store = MongoDBAtlasStore(mongo_client=mock_mongo_client, database_name="test_db") + store.mongo_client.get_database().list_collection_names.return_value = ["test_collection"] + result = await store.list_collection_names() + assert result == ["test_collection"]
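
Usage sketch (illustration only, not part of this change): how the new MongoDBAtlasStore and MongoDBAtlasCollection introduced above might be wired together. The store, collection, settings env vars (MONGODB_ATLAS_*), and the import paths for the record-definition classes are taken from the diff; the key-field class is assumed to live alongside the other field classes, and the "hotels" collection, field names, and dimensions are placeholders.

import asyncio

from semantic_kernel.connectors.memory.mongodb_atlas.mongodb_atlas_store import MongoDBAtlasStore
from semantic_kernel.data.const import DistanceFunction
from semantic_kernel.data.record_definition import VectorStoreRecordDefinition
from semantic_kernel.data.record_definition.vector_store_record_fields import (
    VectorStoreRecordDataField,
    VectorStoreRecordKeyField,  # assumed to be exported next to the data/vector field classes
    VectorStoreRecordVectorField,
)

# A minimal record definition; field names and dimensions are illustrative.
definition = VectorStoreRecordDefinition(
    fields={
        "id": VectorStoreRecordKeyField(),
        "description": VectorStoreRecordDataField(is_filterable=True),
        "vector": VectorStoreRecordVectorField(
            dimensions=3,
            distance_function=DistanceFunction.COSINE_SIMILARITY,
        ),
    }
)


async def main() -> None:
    # Connection string and database name come from MONGODB_ATLAS_CONNECTION_STRING and
    # MONGODB_ATLAS_DATABASE_NAME when not passed explicitly; because no AsyncMongoClient
    # is supplied, the store manages (connects and closes) its own client.
    async with MongoDBAtlasStore() as store:
        collection = store.get_collection(
            collection_name="hotels",
            data_model_type=dict,
            data_model_definition=definition,
        )
        if not await collection.does_collection_exist():
            # create_collection also creates the vector search index derived from the
            # record definition (see create_index_definition in utils.py).
            await collection.create_collection()
        # The key field "id" is remapped to MongoDB's "_id" by the collection on write.
        await collection.upsert({"id": "1", "description": "Harbor view", "vector": [0.1, 0.2, 0.3]})


if __name__ == "__main__":
    asyncio.run(main())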
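
A similarly hedged sketch for querying: vectorized_search is provided by the VectorizedSearchMixin the collection inherits from, and the VectorSearchOptions fields match what _inner_search reads (top, skip, include_vectors, vector_field_name, filter); the result-iteration shape follows the KernelSearchResults/VectorSearchResult API and is an assumption, as is the query vector.

from semantic_kernel.data.vector_search import VectorSearchOptions


async def search(collection) -> None:
    """Run a vector search against the collection from the sketch above."""
    # Any EqualTo / AnyTagsEqualTo clauses placed on options.filter are translated by
    # _build_filter_dict into the "filter" clause of the $vectorSearch stage.
    options = VectorSearchOptions(
        vector_field_name="vector",
        top=3,
        include_vectors=False,
    )
    results = await collection.vectorized_search(vector=[0.1, 0.2, 0.3], options=options)
    async for result in results.results:
        # result.score is populated from the $meta: "vectorSearchScore" projection.
        print(result.record, result.score)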
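
For reference, the search index that create_index_definition in utils.py builds for the record definition in the first sketch; the payload shape is read directly from the new code, and the comment shows the approximate definition handed to SearchIndexModel.

from semantic_kernel.connectors.memory.mongodb_atlas.utils import create_index_definition

# Yields a SearchIndexModel of type "vectorSearch" named "default" whose definition is
# roughly:
# {
#     "fields": [
#         {"type": "vector", "numDimensions": 3, "path": "vector", "similarity": "cosine"},
#         {"path": "description", "type": "filter"},
#         {"path": "id", "type": "filter"},
#     ]
# }
index_model = create_index_definition(definition, "default")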