Skip to content

Commit

Permalink
fix: Add restricts and crowding tag to MatchingEngineIndexEndpoint
Browse files Browse the repository at this point in the history
…query response.

PiperOrigin-RevId: 607012218
  • Loading branch information
lingyinw authored and copybara-github committed Feb 14, 2024
1 parent 14b41b5 commit 83cb52d
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 42 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,24 +42,6 @@
_LOGGER = base.Logger(__name__)


@dataclass
class MatchNeighbor:
"""The id and distance of a nearest neighbor match for a given query embedding.

Args:
id (str):
Required. The id of the neighbor.
distance (float):
Required. The distance to the query embedding.
feature_vector (List[float]):
Optional. The feature vector of the matching datapoint.
Defaults to None when it is not supplied.
"""

id: str
distance: float
feature_vector: Optional[List[float]] = None


@dataclass
class Namespace:
"""Namespace specifies the rules for determining the datapoints that are eligible for each matching query, overall query is an AND across namespaces.
Expand Down Expand Up @@ -156,14 +138,136 @@ def __post_init__(self):
)
# Check operator validity
if (
self.op
self.op is not None
and self.op
not in gca_index_v1beta1.IndexDatapoint.NumericRestriction.Operator._member_names_
):
raise ValueError(
f"Invalid operator '{self.op}'," " must be one of the valid operators."
)


@dataclass
class MatchNeighbor:
    """The id and distance of a nearest neighbor match for a given query embedding.

    Args:
        id (str):
            Required. The id of the neighbor.
        distance (float):
            Required. The distance to the query embedding.
        feature_vector (List[float]):
            Optional. The feature vector of the matching datapoint.
        crowding_tag (Optional[str]):
            Optional. Crowding tag of the datapoint, the
            number of neighbors to return in each crowding,
            can be configured during query.
        restricts (List[Namespace]):
            Optional. The restricts of the matching datapoint.
        numeric_restricts (List[NumericNamespace]):
            Optional. The numeric restricts of the matching datapoint.
    """

    id: str
    distance: float
    feature_vector: Optional[List[float]] = None
    crowding_tag: Optional[str] = None
    restricts: Optional[List[Namespace]] = None
    numeric_restricts: Optional[List[NumericNamespace]] = None

    @staticmethod
    def _to_numeric_namespace(
        name: str, value_type: Optional[str], restrict
    ) -> Optional["NumericNamespace"]:
        """Builds a NumericNamespace from a numeric restrict's populated value.

        Args:
            name (str):
                Required. The namespace name to give the result.
            value_type (Optional[str]):
                Required. The name of the populated member of the restrict's
                "Value" oneof, as returned by `WhichOneof("Value")`.
            restrict:
                Required. The numeric restrict message to read the value from.

        Returns:
            Optional[NumericNamespace] - The converted namespace, or None when
            no supported value field is populated.
        """
        if value_type not in ("value_int", "value_float", "value_double"):
            return None
        # The oneof member name doubles as the NumericNamespace keyword.
        return NumericNamespace(
            name=name, **{value_type: getattr(restrict, value_type)}
        )

    def from_index_datapoint(
        self, index_datapoint: gca_index_v1beta1.IndexDatapoint
    ) -> "MatchNeighbor":
        """Copies MatchNeighbor fields from an IndexDatapoint.

        This mutates the instance in place and returns it, so the call can be
        chained directly onto the constructor.

        Args:
            index_datapoint (gca_index_v1beta1.IndexDatapoint):
                Required. An index datapoint.

        Returns:
            MatchNeighbor
        """
        if not index_datapoint:
            return self
        self.feature_vector = index_datapoint.feature_vector
        if (
            index_datapoint.crowding_tag is not None
            and index_datapoint.crowding_tag.crowding_attribute is not None
        ):
            self.crowding_tag = index_datapoint.crowding_tag.crowding_attribute
        self.restricts = [
            Namespace(
                name=restrict.namespace,
                allow_tokens=restrict.allow_list,
                deny_tokens=restrict.deny_list,
            )
            for restrict in index_datapoint.restricts
        ]
        if index_datapoint.numeric_restricts is not None:
            self.numeric_restricts = []
            for restrict in index_datapoint.numeric_restricts:
                numeric_namespace = self._to_numeric_namespace(
                    restrict.namespace,
                    # The oneof is inspected on the underlying `_pb` message.
                    restrict._pb.WhichOneof("Value"),
                    restrict,
                )
                # Skip restricts with no populated value so the list never
                # contains None entries.
                if numeric_namespace is not None:
                    self.numeric_restricts.append(numeric_namespace)
        return self

    def from_embedding(self, embedding: match_service_pb2.Embedding) -> "MatchNeighbor":
        """Copies MatchNeighbor fields from an Embedding.

        This mutates the instance in place and returns it, so the call can be
        chained directly onto the constructor.

        Args:
            embedding (match_service_pb2.Embedding):
                Required. An embedding.

        Returns:
            MatchNeighbor
        """
        if not embedding:
            return self
        self.feature_vector = embedding.float_val
        # An already-set crowding tag is kept; otherwise the embedding's
        # attribute is stored as a string.
        # NOTE(review): if `crowding_attribute` is a plain protobuf scalar it is
        # presumably never None, so an unset value may be stored as its default
        # - confirm against the Embedding message definition.
        if not self.crowding_tag and embedding.crowding_attribute is not None:
            self.crowding_tag = str(embedding.crowding_attribute)
        self.restricts = [
            Namespace(
                name=restrict.name,
                allow_tokens=restrict.allow_tokens,
                deny_tokens=restrict.deny_tokens,
            )
            for restrict in embedding.restricts
        ]
        # With no numeric restricts on the embedding, `numeric_restricts` is
        # left at its current value.
        if embedding.numeric_restricts:
            self.numeric_restricts = []
            for restrict in embedding.numeric_restricts:
                numeric_namespace = self._to_numeric_namespace(
                    restrict.name, restrict.WhichOneof("Value"), restrict
                )
                # Skip restricts with no populated value so the list never
                # contains None entries.
                if numeric_namespace is not None:
                    self.numeric_restricts.append(numeric_namespace)
        return self


class MatchingEngineIndexEndpoint(base.VertexAiResourceNounWithFutureManager):
"""Matching Engine index endpoint resource for Vertex AI."""

Expand Down Expand Up @@ -1333,10 +1437,8 @@ def find_neighbors(
return [
[
MatchNeighbor(
id=neighbor.datapoint.datapoint_id,
distance=neighbor.distance,
feature_vector=neighbor.datapoint.feature_vector,
)
id=neighbor.datapoint.datapoint_id, distance=neighbor.distance
).from_index_datapoint(index_datapoint=neighbor.datapoint)
for neighbor in embedding_neighbors.neighbors
]
for embedding_neighbors in response.nearest_neighbors
Expand Down Expand Up @@ -1572,13 +1674,17 @@ def match(
response = stub.BatchMatch(batch_request)

# Wrap the results in MatchNeighbor objects and return
return [
[
MatchNeighbor(
id=embedding_neighbors.neighbor[i].id,
distance=embedding_neighbors.neighbor[i].distance,
match_neighbors_response = []
for resp in response.responses[0].responses:
match_neighbors_id_map = {}
for neighbor in resp.neighbor:
match_neighbors_id_map[neighbor.id] = MatchNeighbor(
id=neighbor.id, distance=neighbor.distance
)
for i in range(len(embedding_neighbors.neighbor))
]
for embedding_neighbors in response.responses[0].responses
]
for embedding in resp.embeddings:
if embedding.id in match_neighbors_id_map:
match_neighbors_id_map[embedding.id] = match_neighbors_id_map[
embedding.id
].from_embedding(embedding=embedding)
match_neighbors_response.append(list(match_neighbors_id_map.values()))
return match_neighbors_response
Loading

0 comments on commit 83cb52d

Please sign in to comment.