diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index ff2b914f88..33eec35f58 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -16,7 +16,7 @@ # from dataclasses import dataclass, field -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union from google.auth import credentials as auth_credentials from google.cloud.aiplatform import base @@ -148,6 +148,37 @@ def __post_init__(self): ) +@dataclass +class HybridQuery: + """ + Hybrid query. Could be used for dense-only or sparse-only or hybrid queries. + + dense_embedding (List[float]): + Optional. The dense part of the hybrid queries. + sparse_embedding_values (List[float]): + Optional. The sparse values of the sparse part of the queries. + + sparse_embedding_dimensions (List[int]): + Optional. The corresponding dimensions of the sparse values. + For example, values [1,2,3] with dimensions [4,5,6] means value 1 is of the + 4th dimension, value 2 is of the 5th dimension, and value 3 is of the 6th + dimension. + + rrf_ranking_alpha (float): + Optional. This should not be specified for dense-only or sparse-only queries. + A value between 0 and 1 for ranking algorithm RRF, representing + the ratio for sparse vs. dense embeddings returned in the query result. + If the alpha is 0, only sparse embeddings are being returned, and no dense + embedding is being returned. When alpha is 1, only dense embeddings are being + returned, and no sparse embedding is being returned. 
+ """ + + dense_embedding: Optional[List[float]] = None + sparse_embedding_values: Optional[List[float]] = None + sparse_embedding_dimensions: Optional[List[int]] = None + rrf_ranking_alpha: Optional[float] = None + + @dataclass class MatchNeighbor: """The id and distance of a nearest neighbor match for a given query embedding. @@ -157,7 +188,7 @@ class MatchNeighbor: Required. The id of the neighbor. distance (float): Required. The distance to the query embedding. - feature_vector (List(float)): + feature_vector (List[float]): Optional. The feature vector of the matching datapoint. crowding_tag (Optional[str]): Optional. Crowding tag of the datapoint, the @@ -167,6 +198,14 @@ class MatchNeighbor: Optional. The restricts of the matching datapoint. numeric_restricts: Optional. The numeric restricts of the matching datapoint. + sparse_embedding_values (List[float]): + Optional. The sparse values of the sparse part of the matching + datapoint. + sparse_embedding_dimensions (List[int]): + Optional. The corresponding dimensions of the sparse values. + For example, values [1,2,3] with dimensions [4,5,6] means value 1 is + of the 4th dimension, value 2 is of the 5th dimension, and value 3 is + of the 6th dimension. 
""" @@ -176,6 +215,8 @@ class MatchNeighbor: crowding_tag: Optional[str] = None restricts: Optional[List[Namespace]] = None numeric_restricts: Optional[List[NumericNamespace]] = None + sparse_embedding_values: Optional[List[float]] = None + sparse_embedding_dimensions: Optional[List[int]] = None def from_index_datapoint( self, index_datapoint: gca_index_v1beta1.IndexDatapoint @@ -207,22 +248,31 @@ def from_index_datapoint( ] if index_datapoint.numeric_restricts is not None: self.numeric_restricts = [] - for restrict in index_datapoint.numeric_restricts: - numeric_namespace = None - restrict_value_type = restrict._pb.WhichOneof("Value") - if restrict_value_type == "value_int": - numeric_namespace = NumericNamespace( - name=restrict.namespace, value_int=restrict.value_int - ) - elif restrict_value_type == "value_float": - numeric_namespace = NumericNamespace( - name=restrict.namespace, value_float=restrict.value_float - ) - elif restrict_value_type == "value_double": - numeric_namespace = NumericNamespace( - name=restrict.namespace, value_double=restrict.value_double - ) - self.numeric_restricts.append(numeric_namespace) + for restrict in index_datapoint.numeric_restricts: + numeric_namespace = None + restrict_value_type = restrict._pb.WhichOneof("Value") + if restrict_value_type == "value_int": + numeric_namespace = NumericNamespace( + name=restrict.namespace, value_int=restrict.value_int + ) + elif restrict_value_type == "value_float": + numeric_namespace = NumericNamespace( + name=restrict.namespace, value_float=restrict.value_float + ) + elif restrict_value_type == "value_double": + numeric_namespace = NumericNamespace( + name=restrict.namespace, value_double=restrict.value_double + ) + self.numeric_restricts.append(numeric_namespace) + # sparse embeddings + if ( + index_datapoint.sparse_embedding is not None + and index_datapoint.sparse_embedding.values is not None + ): + self.sparse_embedding_values = index_datapoint.sparse_embedding.values + 
self.sparse_embedding_dimensions = ( + index_datapoint.sparse_embedding.dimensions + ) return self def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor": @@ -250,22 +300,22 @@ def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighb ] if embedding.numeric_restricts: self.numeric_restricts = [] - for restrict in embedding.numeric_restricts: - numeric_namespace = None - restrict_value_type = restrict.WhichOneof("Value") - if restrict_value_type == "value_int": - numeric_namespace = NumericNamespace( - name=restrict.name, value_int=restrict.value_int - ) - elif restrict_value_type == "value_float": - numeric_namespace = NumericNamespace( - name=restrict.name, value_float=restrict.value_float - ) - elif restrict_value_type == "value_double": - numeric_namespace = NumericNamespace( - name=restrict.name, value_double=restrict.value_double - ) - self.numeric_restricts.append(numeric_namespace) + for restrict in embedding.numeric_restricts: + numeric_namespace = None + restrict_value_type = restrict.WhichOneof("Value") + if restrict_value_type == "value_int": + numeric_namespace = NumericNamespace( + name=restrict.name, value_int=restrict.value_int + ) + elif restrict_value_type == "value_float": + numeric_namespace = NumericNamespace( + name=restrict.name, value_float=restrict.value_float + ) + elif restrict_value_type == "value_double": + numeric_namespace = NumericNamespace( + name=restrict.name, value_double=restrict.value_double + ) + self.numeric_restricts.append(numeric_namespace) return self @@ -1322,7 +1372,7 @@ def find_neighbors( self, *, deployed_index_id: str, - queries: Optional[List[List[float]]] = None, + queries: Optional[Union[List[List[float]], List[HybridQuery]]] = None, num_neighbors: int = 10, filter: Optional[List[Namespace]] = None, per_crowding_attribute_neighbor_count: Optional[int] = None, @@ -1346,8 +1396,15 @@ def find_neighbors( Args: deployed_index_id (str): Required. 
The ID of the DeployedIndex to match the queries against. - queries (List[List[float]]): - Required. A list of queries. Each query is a list of floats, representing a single embedding. + queries (Union[List[List[float]], List[HybridQuery]]): + Optional. A list of queries. + + For regular dense-only queries, each query is a list of floats, + representing a single embedding. + + For hybrid queries, each query is a hybrid query of type + aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery. + num_neighbors (int): Required. The number of nearest neighbors to be retrieved from database for each query. @@ -1381,7 +1438,7 @@ def find_neighbors( Note that returning full datapoint will significantly increase the latency and cost of the query. - numeric_filter (list[NumericNamespace]): + numeric_filter (List[NumericNamespace]): Optional. A list of NumericNamespaces for filtering the matching results. For example: [NumericNamespace(name="cost", value_int=5, op="GREATER")] @@ -1437,30 +1494,54 @@ def find_neighbors( numeric_restrict.value_double = numeric_namespace.value_double numeric_restricts.append(numeric_restrict) # Queries - query_by_id = False if queries else True - queries = queries if queries else embedding_ids - if queries: - for query in queries: - find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query( - neighbor_count=num_neighbors, - per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count, - approximate_neighbor_count=approx_num_neighbors, - fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, - ) - datapoint = gca_index_v1beta1.IndexDatapoint( - datapoint_id=query if query_by_id else None, - feature_vector=None if query_by_id else query, - ) - datapoint.restricts.extend(restricts) - datapoint.numeric_restricts.extend(numeric_restricts) - find_neighbors_query.datapoint = datapoint - find_neighbors_request.queries.append(find_neighbors_query) + query_by_id = False + 
query_is_hybrid = False + if queries: # keep the pre-existing precedence: `queries` wins over `embedding_ids` + query_is_hybrid = isinstance(queries[0], HybridQuery) + query_iterators = queries + elif embedding_ids: + query_by_id = True + query_iterators = embedding_ids else: raise ValueError( "To find neighbors using matching engine," - "please specify `queries` or `embedding_ids`" + " please specify `queries` or `embedding_ids`" ) + for query in query_iterators: + find_neighbors_query = gca_match_service_v1beta1.FindNeighborsRequest.Query( + neighbor_count=num_neighbors, + per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count, + approximate_neighbor_count=approx_num_neighbors, + fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, + ) + if query_by_id: + datapoint = gca_index_v1beta1.IndexDatapoint( + datapoint_id=query, + ) + elif query_is_hybrid: + datapoint = gca_index_v1beta1.IndexDatapoint( + feature_vector=query.dense_embedding, + sparse_embedding=gca_index_v1beta1.IndexDatapoint.SparseEmbedding( + values=query.sparse_embedding_values, + dimensions=query.sparse_embedding_dimensions, + ), + ) + if query.rrf_ranking_alpha is not None: # 0.0 is a valid alpha (sparse-only) + find_neighbors_query.rrf = ( + gca_match_service_v1beta1.FindNeighborsRequest.Query.RRF( + alpha=query.rrf_ranking_alpha, + ) + ) + else: + datapoint = gca_index_v1beta1.IndexDatapoint( + feature_vector=query, + ) + datapoint.restricts.extend(restricts) + datapoint.numeric_restricts.extend(numeric_restricts) + find_neighbors_query.datapoint = datapoint + find_neighbors_request.queries.append(find_neighbors_query) + response = self._public_match_client.find_neighbors(find_neighbors_request) # Wrap the results in MatchNeighbor objects and return @@ -1543,7 +1624,6 @@ def read_index_datapoints( read_index_datapoints_request ) - # Wrap the results and return return response.datapoints def _batch_get_embeddings( diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py 
b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index 8fc693a23c..a2658700b5 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -29,6 +29,7 @@ Namespace, NumericNamespace, MatchNeighbor, + HybridQuery, ) from google.cloud.aiplatform.compat.types import ( matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref, @@ -232,6 +233,18 @@ ] ] _TEST_QUERY_IDS = ["1", "2"] +_TEST_HYBRID_QUERIES = [ + HybridQuery( + sparse_embedding_dimensions=[1, 2, 3], + sparse_embedding_values=[0.1, 0.2, 0.3], + rrf_ranking_alpha=0.2, + ), + HybridQuery( + dense_embedding=_TEST_QUERIES[0], + sparse_embedding_dimensions=[1, 2, 3], + sparse_embedding_values=[0.1, 0.2, 0.3], + ), +] _TEST_NUM_NEIGHBOURS = 1 _TEST_FILTER = [ Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"]) @@ -1278,7 +1291,7 @@ def test_index_private_service_connect_endpoint_match_queries( index_endpoint_match_queries_mock.assert_called_with(batch_request) @pytest.mark.usefixtures("get_index_public_endpoint_mock") - def test_index_public_endpoint_find_neighbors_queries( + def test_index_public_endpoint_find_neighbors_queries_backward_compatibility( self, index_public_endpoint_match_queries_mock ): aiplatform.init(project=_TEST_PROJECT) @@ -1326,6 +1339,79 @@ def test_index_public_endpoint_find_neighbors_queries( find_neighbors_request ) + @pytest.mark.usefixtures("get_index_public_endpoint_mock") + def test_index_public_endpoint_find_neighbors_queries( + self, index_public_endpoint_match_queries_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_public_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_public_index_endpoint.find_neighbors( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + 
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, + queries=_TEST_HYBRID_QUERIES, + ) + + find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest( + index_endpoint=my_public_index_endpoint.resource_name, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=[ + gca_match_service_v1beta1.FindNeighborsRequest.Query( + neighbor_count=_TEST_NUM_NEIGHBOURS, + datapoint=gca_index_v1beta1.IndexDatapoint( + restricts=[ + gca_index_v1beta1.IndexDatapoint.Restriction( + namespace="class", + allow_list=["token_1"], + deny_list=["token_2"], + ) + ], + sparse_embedding=gca_index_v1beta1.IndexDatapoint.SparseEmbedding( + values=[0.1, 0.2, 0.3], dimensions=[1, 2, 3] + ), + ), + rrf=gca_match_service_v1beta1.FindNeighborsRequest.Query.RRF( + alpha=0.2, + ), + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + ), + gca_match_service_v1beta1.FindNeighborsRequest.Query( + neighbor_count=_TEST_NUM_NEIGHBOURS, + datapoint=gca_index_v1beta1.IndexDatapoint( + feature_vector=_TEST_QUERIES[0], + restricts=[ + gca_index_v1beta1.IndexDatapoint.Restriction( + namespace="class", + allow_list=["token_1"], + deny_list=["token_2"], + ) + ], + sparse_embedding=gca_index_v1beta1.IndexDatapoint.SparseEmbedding( + values=[0.1, 0.2, 0.3], dimensions=[1, 2, 3] + ), + ), + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + ), + ], + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, + ) + + 
index_public_endpoint_match_queries_mock.assert_called_with( + find_neighbors_request + ) + @pytest.mark.usefixtures("get_index_public_endpoint_mock") def test_index_public_endpoint_find_neiggbor_query_by_id( self, index_public_endpoint_match_queries_mock