From fad208c9b82975884d9d8c0048b7017a8c418c58 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Wed, 12 Jun 2024 10:28:15 -0700 Subject: [PATCH] feat: Add hybrid query example to vector search sample. FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/3932 from googleapis:release-please--branches--main 346f4c03f036a3a343b45723f0ba285fb5de2227 PiperOrigin-RevId: 642658354 --- samples/model-builder/test_constants.py | 20 ++++++++++++++ .../vector_search_find_neighbors_sample.py | 27 +++++++++++++++++++ ...ector_search_find_neighbors_sample_test.py | 20 +++++++++++--- 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index cc62f17144..43e6b7b1de 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -350,6 +350,26 @@ VECTOR_SEARCH_INDEX_ENDPOINT = "456" VECTOR_SEARCH_DEPLOYED_INDEX_ID = "789" VECTOR_SERACH_INDEX_QUERIES = [[0.1]] +VECTOR_SERACH_INDEX_HYBRID_QUERIES = [ + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[1.0, 1.0, 1.0], + rrf_ranking_alpha=0.5, + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3] + ), +] VECTOR_SEARCH_INDEX_DISPLAY_NAME = "my-vector-search-index" VECTOR_SEARCH_GCS_URI = "gs://fake-dir" VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint" diff --git a/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py b/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py index 8afef4a82d..8cb46a5176 100644 --- a/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py +++ b/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py @@ -55,5 +55,32 @@ def vector_search_find_neighbors( ) print(resp) + # Query hybrid datapoints, sparse-only datapoints, and dense-only datapoints. + hybrid_queries = [ + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[1.0, 1.0, 1.0], + rrf_ranking_alpha=0.5, + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3] + ), + ] + + hybrid_resp = my_index_endpoint.find_neighbors( + deployed_index_id=deployed_index_id, + queries=hybrid_queries, + num_neighbors=num_neighbors,) + print(hybrid_resp) # [END aiplatform_sdk_vector_search_find_neighbors_sample] diff --git a/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py b/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py index d4ba3143cb..4fabaa5aef 100644 --- a/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py +++ b/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import call + import test_constants as constants from vector_search import vector_search_find_neighbors_sample @@ -38,8 +40,18 @@ def test_vector_search_find_neighbors_sample( index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT) # Check index_endpoint.find_neighbors is called with right params. - mock_index_endpoint_find_neighbors.assert_called_with( - deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, - queries=constants.VECTOR_SERACH_INDEX_QUERIES, - num_neighbors=10 + mock_index_endpoint_find_neighbors.assert_has_calls( + [ + call( + deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, + queries=constants.VECTOR_SERACH_INDEX_QUERIES, + num_neighbors=10, + ), + call( + deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, + queries=constants.VECTOR_SERACH_INDEX_HYBRID_QUERIES, + num_neighbors=10, + ), + ], + any_order=False, )