From e843d7479aee27b871072e6712d0b7e326468d15 Mon Sep 17 00:00:00 2001 From: Stan Girard Date: Tue, 10 Sep 2024 19:08:45 +0200 Subject: [PATCH] tests(embeddings): added tests for embeddings --- backend/api/tests/settings/test_settings.py | 38 +++++++++++++++++++++ 1 file changed, 38 insertions(+) create mode 100644 backend/api/tests/settings/test_settings.py diff --git a/backend/api/tests/settings/test_settings.py b/backend/api/tests/settings/test_settings.py new file mode 100644 index 000000000000..9125590dc89d --- /dev/null +++ b/backend/api/tests/settings/test_settings.py @@ -0,0 +1,38 @@ +from unittest.mock import patch, MagicMock +from quivr_api.modules.dependencies import get_embedding_client +from langchain_community.embeddings.ollama import OllamaEmbeddings +from langchain_openai import AzureOpenAIEmbeddings + +def test_ollama_embedding(): + with patch("quivr_api.modules.dependencies.settings") as mock_settings: + mock_settings.ollama_api_base_url = "http://ollama.example.com" + mock_settings.azure_openai_embeddings_url = None + + embedding_client = get_embedding_client() + + assert isinstance(embedding_client, OllamaEmbeddings) + assert embedding_client.base_url == "http://ollama.example.com" + +def test_azure_embedding(): + with patch("quivr_api.modules.dependencies.settings") as mock_settings: + mock_settings.ollama_api_base_url = None + mock_settings.azure_openai_embeddings_url = "https://quivr-test.openai.azure.com/openai/deployments/embedding/embeddings?api-version=2023-05-15" + + embedding_client = get_embedding_client() + + assert isinstance(embedding_client, AzureOpenAIEmbeddings) + assert embedding_client.azure_endpoint == "https://quivr-test.openai.azure.com" + +def test_openai_embedding(): + with patch("quivr_api.modules.dependencies.settings") as mock_settings, \ + patch("quivr_api.modules.dependencies.OpenAIEmbeddings") as mock_openai_embeddings: + mock_settings.ollama_api_base_url = None + mock_settings.azure_openai_embeddings_url = None + + # Create a mock instance for OpenAIEmbeddings + mock_openai_instance = MagicMock() + mock_openai_embeddings.return_value = mock_openai_instance + + embedding_client = get_embedding_client() + + assert embedding_client == mock_openai_instance