From ccdd10315c3eaddda785b365d4ed84236166bada Mon Sep 17 00:00:00 2001 From: iceyao Date: Tue, 15 Oct 2024 20:07:18 +0800 Subject: [PATCH] feat: Enable baiduvector intergration test --- .../vdb/__mock/baiduvectordb.py | 71 +++++++++++-------- .../integration_tests/vdb/baidu/test_baidu.py | 3 - dev/pytest/pytest_vdb.sh | 3 +- 3 files changed, 43 insertions(+), 34 deletions(-) diff --git a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py index a8eaf42b7de1de..f20833ea5666de 100644 --- a/api/tests/integration_tests/vdb/__mock/baiduvectordb.py +++ b/api/tests/integration_tests/vdb/__mock/baiduvectordb.py @@ -1,4 +1,5 @@ import os +from unittest.mock import MagicMock import pytest from _pytest.monkeypatch import MonkeyPatch @@ -10,26 +11,31 @@ from requests.adapters import HTTPAdapter +class AttrDict(dict): + def __getattr__(self, item): + return self.get(item) + + class MockBaiduVectorDBClass: def mock_vector_db_client( self, config=None, adapter: HTTPAdapter = None, ): - self._conn = None - self._config = None + self.conn = MagicMock() + self._config = MagicMock() def list_databases(self, config=None) -> list[Database]: return [ Database( - conn=self._conn, + conn=self.conn, database_name="dify", config=self._config, ) ] def create_database(self, database_name: str, config=None) -> Database: - return Database(conn=self._conn, database_name=database_name, config=config) + return Database(conn=self.conn, database_name=database_name, config=config) def list_table(self, config=None) -> list[Table]: return [] @@ -88,16 +94,18 @@ def query( read_consistency=ReadConsistency.EVENTUAL, config=None, ): - return { - "row": { - "id": "doc_id_001", - "vector": [0.23432432, 0.8923744, 0.89238432], - "text": "text", - "metadata": {"doc_id": "doc_id_001"}, - }, - "code": 0, - "msg": "Success", - } + return AttrDict( + { + "row": { + "id": primary_key.get("id"), + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": '{"doc_id": "doc_id_001"}', + }, + "code": 0, + "msg": "Success", + } + ) def delete(self, primary_key=None, partition_key=None, filter=None, config=None): return {"code": 0, "msg": "Success"} @@ -111,22 +119,24 @@ def search( read_consistency=ReadConsistency.EVENTUAL, config=None, ): - return { - "rows": [ - { - "row": { - "id": "doc_id_001", - "vector": [0.23432432, 0.8923744, 0.89238432], - "text": "text", - "metadata": {"doc_id": "doc_id_001"}, - }, - "distance": 0.1, - "score": 0.5, - } - ], - "code": 0, - "msg": "Success", - } + return AttrDict( + { + "rows": [ + { + "row": { + "id": "doc_id_001", + "vector": [0.23432432, 0.8923744, 0.89238432], + "text": "text", + "metadata": '{"doc_id": "doc_id_001"}', + }, + "distance": 0.1, + "score": 0.5, + } + ], + "code": 0, + "msg": "Success", + } + ) MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true" @@ -146,6 +156,7 @@ def setup_baiduvectordb_mock(request, monkeypatch: MonkeyPatch): monkeypatch.setattr(Table, "rebuild_index", MockBaiduVectorDBClass.rebuild_index) monkeypatch.setattr(Table, "describe_index", MockBaiduVectorDBClass.describe_index) monkeypatch.setattr(Table, "delete", MockBaiduVectorDBClass.delete) + monkeypatch.setattr(Table, "query", MockBaiduVectorDBClass.query) monkeypatch.setattr(Table, "search", MockBaiduVectorDBClass.search) yield diff --git a/api/tests/integration_tests/vdb/baidu/test_baidu.py b/api/tests/integration_tests/vdb/baidu/test_baidu.py index 01a7f8853ac367..5dc2ce4f82e18d 100644 --- a/api/tests/integration_tests/vdb/baidu/test_baidu.py +++ b/api/tests/integration_tests/vdb/baidu/test_baidu.py @@ -4,9 +4,6 @@ from tests.integration_tests.vdb.__mock.baiduvectordb import setup_baiduvectordb_mock from tests.integration_tests.vdb.test_vector_store import AbstractVectorTest, get_example_text, setup_mock_redis -mock_client = MagicMock() -mock_client.list_databases.return_value = [{"name": "test"}] - class BaiduVectorTest(AbstractVectorTest): def __init__(self): diff --git a/dev/pytest/pytest_vdb.sh b/dev/pytest/pytest_vdb.sh index 6809ef7c6f74c2..ea1e7d4c806a5e 100755 --- a/dev/pytest/pytest_vdb.sh +++ b/dev/pytest/pytest_vdb.sh @@ -8,4 +8,5 @@ pytest api/tests/integration_tests/vdb/chroma \ api/tests/integration_tests/vdb/qdrant \ api/tests/integration_tests/vdb/weaviate \ api/tests/integration_tests/vdb/elasticsearch \ - api/tests/integration_tests/vdb/vikingdb + api/tests/integration_tests/vdb/vikingdb \ + api/tests/integration_tests/vdb/baidu