diff --git a/samples/model-builder/test_constants.py b/samples/model-builder/test_constants.py index cc62f171440..43e6b7b1dee 100644 --- a/samples/model-builder/test_constants.py +++ b/samples/model-builder/test_constants.py @@ -350,6 +350,26 @@ VECTOR_SEARCH_INDEX_ENDPOINT = "456" VECTOR_SEARCH_DEPLOYED_INDEX_ID = "789" VECTOR_SERACH_INDEX_QUERIES = [[0.1]] +VECTOR_SERACH_INDEX_HYBRID_QUERIES = [ + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[1.0, 1.0, 1.0], + rrf_ranking_alpha=0.5, + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3] + ), +] VECTOR_SEARCH_INDEX_DISPLAY_NAME = "my-vector-search-index" VECTOR_SEARCH_GCS_URI = "gs://fake-dir" VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint" diff --git a/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py b/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py index 8afef4a82df..8cb46a5176e 100644 --- a/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py +++ b/samples/model-builder/vector_search/vector_search_find_neighbors_sample.py @@ -55,5 +55,32 @@ def vector_search_find_neighbors( ) print(resp) + # Query hybrid datapoints, sparse-only datapoints, and dense-only datapoints. + hybrid_queries = [ + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[1.0, 1.0, 1.0], + rrf_ranking_alpha=0.5, + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3], + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + sparse_embedding_dimensions=[10, 20, 30], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery( + dense_embedding=[1, 2, 3] + ), + ] + + hybrid_resp = my_index_endpoint.find_neighbors( + deployed_index_id=deployed_index_id, + queries=hybrid_queries, + num_neighbors=num_neighbors,) + print(hybrid_resp) # [END aiplatform_sdk_vector_search_find_neighbors_sample] diff --git a/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py b/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py index d4ba3143cb1..4fabaa5aefe 100644 --- a/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py +++ b/samples/model-builder/vector_search/vector_search_find_neighbors_sample_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from unittest.mock import call + import test_constants as constants from vector_search import vector_search_find_neighbors_sample @@ -38,8 +40,18 @@ def test_vector_search_find_neighbors_sample( index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT) # Check index_endpoint.find_neighbors is called with right params. - mock_index_endpoint_find_neighbors.assert_called_with( - deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, - queries=constants.VECTOR_SERACH_INDEX_QUERIES, - num_neighbors=10 + mock_index_endpoint_find_neighbors.assert_has_calls( + [ + call( + deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, + queries=constants.VECTOR_SERACH_INDEX_QUERIES, + num_neighbors=10, + ), + call( + deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID, + queries=constants.VECTOR_SERACH_INDEX_HYBRID_QUERIES, + num_neighbors=10, + ), + ], + any_order=False, )