Skip to content

Commit

Permalink
feat: add numeric_restricts to MatchingEngineIndex `find_neighbors(…
Browse files Browse the repository at this point in the history
…)` for querying

public endpoints.

PiperOrigin-RevId: 582893369
  • Loading branch information
lingyinw authored and copybara-github committed Nov 16, 2023
1 parent 3a8f22c commit 6c1f2cc
Show file tree
Hide file tree
Showing 2 changed files with 203 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,91 @@ class Namespace:
deny_tokens: list = field(default_factory=list)


@dataclass
class NumericNamespace:
"""NumericNamespace specifies the rules for determining the datapoints that
are eligible for each matching query, overall query is an AND across namespaces.
This uses numeric comparisons.
Args:
name (str):
Required. The name of this numeric namespace.
value_int (int):
Optional. 64 bit integer value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for intended
precision.
value_float (float):
Optional. 32 bit float value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for
intended precision.
value_double (float):
Optional. 64b bit float value for comparison. Must choose one among
`value_int`, `value_float` and `value_double` for
intended precision.
operator (str):
Optional. Should be specified for query only, not for a datapoints.
Specify one operator to use for comparison. Datapoints for which
comparisons with query's values are true for the operator and value
combination will be allowlisted. Choose among:
"LESS" for datapoints' values < query's value;
"LESS_EQUAL" for datapoints' values <= query's value;
"EQUAL" for datapoints' values = query's value;
"GREATER_EQUAL" for datapoints' values >= query's value;
"GREATER" for datapoints' values > query's value;
"""

name: str
value_int: Optional[int] = None
value_float: Optional[float] = None
value_double: Optional[float] = None
op: Optional[str] = None

def __post_init__(self):
"""Check NumericNamespace values are of correct types and values are
not all none.
Args:
None.
Raises:
ValueError: Numeric Namespace provided values must be of correct
types and one of value_int, value_float, value_double must exist.
"""
# Check one of
if (
self.value_int is None
and self.value_float is None
and self.value_double is None
):
raise ValueError(
"Must choose one among `value_int`,"
"`value_float` and `value_double` for "
"intended precision."
)

# Check value type
if self.value_int is not None and not isinstance(self.value_int, int):
raise ValueError(
"value_int must be of type int, got" f" { type(self.value_int)}."
)
if self.value_float is not None and not isinstance(self.value_float, float):
raise ValueError(
"value_float must be of type float, got " f"{ type(self.value_float)}."
)
if self.value_double is not None and not isinstance(self.value_double, float):
raise ValueError(
"value_double must be of type float, got "
f"{ type(self.value_double)}."
)
# Check operator validity
if (
self.op
not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
):
raise ValueError(
f"Invalid operator '{self.op}'," " must be one of the valid operators."
)


class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
"""Matching Engine index endpoint resource for Vertex AI."""

Expand Down Expand Up @@ -1034,6 +1119,7 @@ def find_neighbors(
approx_num_neighbors: Optional[int] = None,
fraction_leaf_nodes_to_search_override: Optional[float] = None,
return_full_datapoint: bool = False,
numeric_filter: Optional[List[NumericNamespace]] = [],
) -> List[List[MatchNeighbor]]:
"""Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to public endpoint.
Expand Down Expand Up @@ -1082,6 +1168,11 @@ def find_neighbors(
Note that returning full datapoint will significantly increase the
latency and cost of the query.
numeric_filter (Optional[list[NumericNamespace]]):
Optional. A list of NumericNamespaces for filtering the matching
results. For example:
[NumericNamespace(name="cost", value_int=5, op="GREATER")]
will match datapoints that its cost is greater than 5.
Returns:
List[List[MatchNeighbor]] - A list of nearest neighbors for each query.
"""
Expand Down Expand Up @@ -1110,12 +1201,22 @@ def find_neighbors(
fraction_leaf_nodes_to_search_override
)
datapoint = gca_index_v1beta1.IndexDatapoint(feature_vector=query)
# Token restricts
for namespace in filter:
restrict = gca_index_v1beta1.IndexDatapoint.Restriction()
restrict.namespace = namespace.name
restrict.allow_list.extend(namespace.allow_tokens)
restrict.deny_list.extend(namespace.deny_tokens)
datapoint.restricts.append(restrict)
# Numeric restricts
for numeric_namespace in numeric_filter:
numeric_restrict = gca_index_v1beta1.IndexDatapoint.NumericRestriction()
numeric_restrict.namespace = numeric_namespace.name
numeric_restrict.op = numeric_namespace.op
numeric_restrict.value_int = numeric_namespace.value_int
numeric_restrict.value_float = numeric_namespace.value_float
numeric_restrict.value_double = numeric_namespace.value_double
datapoint.numeric_restricts.append(numeric_restrict)
find_neighbors_query.datapoint = datapoint
find_neighbors_request.queries.append(find_neighbors_query)

Expand Down
102 changes: 102 additions & 0 deletions tests/unit/aiplatform/test_matching_engine_index_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from google.cloud.aiplatform.matching_engine._protos import match_service_pb2
from google.cloud.aiplatform.matching_engine.matching_engine_index_endpoint import (
Namespace,
NumericNamespace,
)
from google.cloud.aiplatform.compat.types import (
matching_engine_deployed_index_ref as gca_matching_engine_deployed_index_ref,
Expand Down Expand Up @@ -233,6 +234,11 @@
_TEST_FILTER = [
Namespace(name="class", allow_tokens=["token_1"], deny_tokens=["token_2"])
]
_TEST_NUMERIC_FILTER = [
NumericNamespace(name="cost", value_double=0.3, op="EQUAL"),
NumericNamespace(name="size", value_int=10, op="GREATER"),
NumericNamespace(name="seconds", value_float=20.5, op="LESS_EQUAL"),
]
_TEST_IDS = ["123", "456", "789"]
_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS = 3
_TEST_APPROX_NUM_NEIGHBORS = 2
Expand Down Expand Up @@ -1080,6 +1086,102 @@ def test_index_public_endpoint_match_queries(
find_neighbors_request
)

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_match_queries_with_numeric_filtering(
self, index_public_endpoint_match_queries_mock
):
aiplatform.init(project=_TEST_PROJECT)

my_pubic_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=_TEST_INDEX_ENDPOINT_ID
)

my_pubic_index_endpoint.find_neighbors(
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=_TEST_QUERIES,
num_neighbors=_TEST_NUM_NEIGHBOURS,
filter=_TEST_FILTER,
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
numeric_filter=_TEST_NUMERIC_FILTER,
)

find_neighbors_request = gca_match_service_v1beta1.FindNeighborsRequest(
index_endpoint=my_pubic_index_endpoint.resource_name,
deployed_index_id=_TEST_DEPLOYED_INDEX_ID,
queries=[
gca_match_service_v1beta1.FindNeighborsRequest.Query(
neighbor_count=_TEST_NUM_NEIGHBOURS,
datapoint=gca_index_v1beta1.IndexDatapoint(
feature_vector=_TEST_QUERIES[0],
restricts=[
gca_index_v1beta1.IndexDatapoint.Restriction(
namespace="class",
allow_list=["token_1"],
deny_list=["token_2"],
)
],
numeric_restricts=[
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
namespace="cost", value_double=0.3, op="EQUAL"
),
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
namespace="size", value_int=10, op="GREATER"
),
gca_index_v1beta1.IndexDatapoint.NumericRestriction(
namespace="seconds", value_float=20.5, op="LESS_EQUAL"
),
],
),
per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS,
approximate_neighbor_count=_TEST_APPROX_NUM_NEIGHBORS,
fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE,
)
],
return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT,
)

index_public_endpoint_match_queries_mock.assert_called_with(
find_neighbors_request
)

def test_post_init_numeric_filter_invalid_operator_throws_exception(
self,
):
expected_message = (
"Invalid operator 'NOT_EQ', must be one of the valid operators."
)
with pytest.raises(ValueError) as exception:
NumericNamespace(name="cost", value_int=3, op="NOT_EQ")

assert str(exception.value) == expected_message

def test_post_init_numeric_namespace_missing_value_throws_exception(self):
aiplatform.init(project=_TEST_PROJECT)

expected_message = (
"Must choose one among `value_int`,"
"`value_float` and `value_double` for "
"intended precision."
)

with pytest.raises(ValueError) as exception:
NumericNamespace(name="cost", op="EQUAL")

assert str(exception.value) == expected_message

def test_index_public_endpoint_match_queries_with_numeric_filtering_value_type_mismatch_throws_exception(
self,
):
expected_message = "value_int must be of type int, got <class 'float'>."

with pytest.raises(ValueError) as exception:
NumericNamespace(name="cost", value_int=0.3, op="EQUAL")

assert str(exception.value) == expected_message

@pytest.mark.usefixtures("get_index_public_endpoint_mock")
def test_index_public_endpoint_read_index_datapoints(
self, index_public_endpoint_read_index_datapoints_mock
Expand Down

0 comments on commit 6c1f2cc

Please sign in to comment.