From 3ab39a4536dc72b8a93d30c89bff04f25c724ef5 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Mon, 18 Nov 2024 15:15:24 -0800 Subject: [PATCH] feat: Add PSC automation support to matching engine index endpoint `deploy_index()`, `find_neighbors()`, `match()`, and `read_index_datapoints()`. PiperOrigin-RevId: 697774175 --- .../matching_engine_index_endpoint.py | 260 ++++++++-- .../test_matching_engine_index_endpoint.py | 473 +++++++++++++++++- 2 files changed, 701 insertions(+), 32 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py index 9aeb53b344..6c51b07281 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index_endpoint.py @@ -426,7 +426,7 @@ def create( Private services access must already be configured for the network. If left unspecified, the network set with aiplatform.init will be used. - Format: projects/{project}/global/networks/{network}, where + Format: 'projects/{project}/global/networks/{network}', where {project} is a project number, as in '12345', and {network} is network name. public_endpoint_enabled (bool): @@ -564,8 +564,8 @@ def _create( to which the IndexEndpoint should be peered. Private services access must already be configured for the network. - `Format `__: - projects/{project}/global/networks/{network}. Where + Format: + projects/{project}/global/networks/{network}. Where {project} is a project number, as in '12345', and {network} is network name. 
public_endpoint_enabled (bool): @@ -693,10 +693,76 @@ def _instantiate_public_match_client( api_path_override=self.public_endpoint_domain_name, ) + def _get_psc_automated_deployed_index_ip_address( + self, + deployed_index_id: Optional[str] = None, + deployed_index: Optional[ + gca_matching_engine_index_endpoint.DeployedIndex + ] = None, + psc_network: Optional[str] = None, + ) -> str: + """Helper method to get the ip address for a psc automated endpoint. + Args: + deployed_index_id (str): + Optional. Required for private service access endpoint. + The user specified ID of the DeployedIndex. + deployed_index (gca_matching_engine_index_endpoint.DeployedIndex): + Optional. Required for private service access endpoint. + The DeployedIndex resource. + psc_network (str): + Optional. Required for private service automation enabled + deployed index. This network is the PSC network the match + query is executed in. This (project, network) + pair must already be specified for psc automation when the index + is deployed. The format is + `projects/{project_id}/global/networks/{network_name}`. + + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. + + Raises: + RuntimeError: No valid ip found for deployed index with id + deployed_index_id within network psc_network, invalid PSC + network provided. A list of valid PSC networks are: + psc_network_list. 
+ """ + if deployed_index_id: + deployed_indexes = [ + deployed_index + for deployed_index in self.deployed_indexes + if deployed_index.id == deployed_index_id + ] + deployed_index = deployed_indexes[0] + else: + deployed_index_id = deployed_index.id + + ip_address = None + # PSC Automation, iterate through psc automation configs to find + # the ip address for the given network + psc_network_list = [ + endpoint.network + for endpoint in deployed_index.private_endpoints.psc_automated_endpoints + ] + for endpoint in deployed_index.private_endpoints.psc_automated_endpoints: + if psc_network == endpoint.network: + ip_address = endpoint.match_address + break + if not ip_address: + raise RuntimeError( + f"No valid ip found for deployed index with id " + f"'{deployed_index_id}' within network '{psc_network}', " + "invalid PSC network provided. A list of valid PSC networks" + f"are: '{psc_network_list}'." + ) + return ip_address + def _instantiate_private_match_service_stub( self, deployed_index_id: Optional[str] = None, ip_address: Optional[str] = None, + psc_network: Optional[str] = None, ) -> match_service_pb2_grpc.MatchServiceStub: """Helper method to instantiate private match service stub. Args: @@ -706,6 +772,18 @@ def _instantiate_private_match_service_stub( ip_address (str): Optional. Required for private service connect. The ip address the forwarding rule makes use of. + psc_network (str): + Optional. Required for private service automation enabled deployed + index. This network is the PSC + network the match query is executed in. This (project, network) + pair must already be specified for psc automation when the index + is deployed. The format is + `projects/{project_id}/global/networks/{network_name}`. + + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. Returns: stub (match_service_pb2_grpc.MatchServiceStub): Initialized match service stub. 
@@ -714,6 +792,7 @@ def _instantiate_private_match_service_stub( ValueError: Should not set ip address for networks other than private service connect. """ + if ip_address: # Should only set for Private Service Connect if self.public_endpoint_domain_name: @@ -729,7 +808,6 @@ def _instantiate_private_match_service_stub( "connection using provided ip address", ) else: - # Private Service Access, find server ip for deployed index deployed_indexes = [ deployed_index for deployed_index in self.deployed_indexes @@ -741,8 +819,14 @@ def _instantiate_private_match_service_stub( f"No deployed index with id '{deployed_index_id}' found" ) - # Retrieve server ip from deployed index - ip_address = deployed_indexes[0].private_endpoints.match_grpc_address + if deployed_indexes[0].private_endpoints.psc_automated_endpoints: + ip_address = self._get_psc_automated_deployed_index_ip_address( + deployed_index_id=deployed_index_id, + psc_network=psc_network, + ) + else: + # Private Service Access, find server ip for deployed index + ip_address = deployed_indexes[0].private_endpoints.match_grpc_address if ip_address not in self._match_grpc_stub_cache: # Set up channel and stub @@ -859,6 +943,7 @@ def _build_deployed_index( deployment_group: Optional[str] = None, auth_config_audiences: Optional[Sequence[str]] = None, auth_config_allowed_issuers: Optional[Sequence[str]] = None, + psc_automation_configs: Optional[Sequence[Tuple[str, str]]] = None, ) -> gca_matching_engine_index_endpoint.DeployedIndex: """Builds a DeployedIndex. @@ -950,6 +1035,18 @@ def _build_deployed_index( ``service-account-name@project-id.iam.gserviceaccount.com`` request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. + + psc_automation_configs (Sequence[Tuple[str, str]]): + Optional. A list of (project_id, network) pairs for Private + Service Connection endpoints to be setup for the deployed index. 
+ The project_id is the project number of the project that the + network is in, and network is the name of the network. + Network is the full name of the Google Compute Engine + network to which the index should be deployed: + projects/{project}/global/networks/{network}, where + {project} is a project number, as in '12345', and {network} + is network name. + """ deployed_index = gca_matching_engine_index_endpoint.DeployedIndex( @@ -989,6 +1086,16 @@ def _build_deployed_index( max_replica_count=max_replica_count, ) ) + + if psc_automation_configs: + deployed_index.psc_automation_configs = [ + gca_service_networking.PSCAutomationConfig( + project_id=psc_automation_config[0], + network=psc_automation_config[1], + ) + for psc_automation_config in psc_automation_configs + ] + return deployed_index def deploy_index( @@ -1006,6 +1113,7 @@ def deploy_index( auth_config_allowed_issuers: Optional[Sequence[str]] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), deploy_request_timeout: Optional[float] = None, + psc_automation_configs: Optional[Sequence[Tuple[str, str]]] = None, ) -> "MatchingEngineIndexEndpoint": """Deploys an existing index resource to this endpoint resource. @@ -1102,6 +1210,24 @@ def deploy_index( deploy_request_timeout (float): Optional. The timeout for the request in seconds. + + psc_automation_configs (Sequence[Tuple[str, str]]): + Optional. A list of (project_id, network) pairs for Private + Service Connection endpoints to be setup for the deployed index. + The project_id is the project number of the project that the + network is in, and network is the name of the network. + Network is the full name of the Google Compute Engine + `network <https://cloud.google.com/compute/docs/networks-and-firewalls#networks>`__ + to which the index should be deployed. + + Format: projects/{project}/global/networks/{network}. Where + {project} is a project number, as in '12345', and {network} + is network name. 
+ + For example: + [(project_id_1, network_1), (project_id_1, network_2)] will enable + PSC automation for the index to be deployed to project_id_1's network_1 + and network_2 and can be queried within these networks. Returns: MatchingEngineIndexEndpoint - IndexEndpoint resource object """ @@ -1126,6 +1252,7 @@ def deploy_index( deployment_group=deployment_group, auth_config_audiences=auth_config_audiences, auth_config_allowed_issuers=auth_config_allowed_issuers, + psc_automation_configs=psc_automation_configs, ) deploy_lro = self.api_client.deploy_index( @@ -1387,6 +1514,7 @@ def find_neighbors( numeric_filter: Optional[List[NumericNamespace]] = None, embedding_ids: Optional[List[str]] = None, signed_jwt: Optional[str] = None, + psc_network: Optional[str] = None, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index which is deployed to either public or private @@ -1412,13 +1540,17 @@ def find_neighbors( aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery. num_neighbors (int): - Required. The number of nearest neighbors to be retrieved from database for - each query. + Required. The number of nearest neighbors to be retrieved from + database for each query. filter (List[Namespace]): Optional. A list of Namespaces for filtering the matching results. - For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints - that satisfy "red color" but not include datapoints with "squared shape". - Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + For example, [Namespace("color", ["red"], []), + Namespace("shape", [], ["squared"])] will match datapoints that + satisfy "red color" but not include datapoints with "squared + shape". + Please refer to + https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json + for more detail. 
per_crowding_attribute_neighbor_count (int): Optional. Crowding is a constraint on a neighbor list produced @@ -1459,6 +1591,18 @@ def find_neighbors( signed_jwt (str): Optional. A signed JWT for accessing the private endpoint. + psc_network (Optional[str]): + Optional. Required for private service automation enabled + deployed index. This network is the PSC network the match query + is executed in. This (project, network) pair must already be + specified for psc automation when the index is deployed. The + format is `projects/{project_id}/global/networks/{network_name}`. + + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. + Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. """ @@ -1475,6 +1619,7 @@ def find_neighbors( fraction_leaf_nodes_to_search_override=fraction_leaf_nodes_to_search_override, numeric_filter=numeric_filter, signed_jwt=signed_jwt, + psc_network=psc_network, ) # Create the FindNeighbors request @@ -1575,6 +1720,7 @@ def read_index_datapoints( deployed_index_id: str, ids: List[str] = [], signed_jwt: Optional[str] = None, + psc_network: Optional[str] = None, ) -> List[gca_index_v1beta1.IndexDatapoint]: """Reads the datapoints/vectors of the given IDs on the specified deployed index which is deployed to public or private endpoint. @@ -1594,6 +1740,19 @@ def read_index_datapoints( Required. IDs of the datapoints to be searched for. signed_jwt (str): Optional. A signed JWT for accessing the private endpoint. + psc_network (Optional[str]): + Optional. Required for private service automation enabled deployed + index. This network is the PSC + network the match query is executed in. This (project, network) + pair must already be specified for psc automation when the index + is deployed. The format is + `projects/{project_id}/global/networks/{network_name}`. 
+ + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. + Returns: List[gca_index_v1beta1.IndexDatapoint] - A list of datapoints/vectors of the given IDs. """ @@ -1603,6 +1762,7 @@ def read_index_datapoints( deployed_index_id=deployed_index_id, ids=ids, signed_jwt=signed_jwt, + psc_network=psc_network, ) response = [] @@ -1650,6 +1810,7 @@ def _batch_get_embeddings( deployed_index_id: str, ids: List[str] = [], signed_jwt: Optional[str] = None, + psc_network: Optional[str] = None, ) -> List[match_service_pb2.Embedding]: """ Reads the datapoints/vectors of the given IDs on the specified index @@ -1662,13 +1823,30 @@ def _batch_get_embeddings( Required. IDs of the datapoints to be searched for. signed_jwt: Optional. A signed JWT for accessing the private endpoint. + psc_network (Optional[str]): + Optional. Required for private service automation enabled deployed + index. This network is the PSC + network the match query is executed in. This (project, network) + pair must already be specified for psc automation when the index + is deployed. The format is + `projects/{project_id}/global/networks/{network_name}`. + + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. Returns: List[match_service_pb2.Embedding] - A list of datapoints/vectors of the given IDs. 
""" - stub = self._instantiate_private_match_service_stub( - deployed_index_id=deployed_index_id, - ip_address=self._private_service_connect_ip_address, - ) + if psc_network: + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id, psc_network=psc_network + ) + else: + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id, + ip_address=self._private_service_connect_ip_address, + ) # Create the batch get embeddings request batch_request = match_service_pb2.BatchGetEmbeddingsRequest() @@ -1695,6 +1873,7 @@ def match( low_level_batch_size: int = 0, numeric_filter: Optional[List[NumericNamespace]] = None, signed_jwt: Optional[str] = None, + psc_network: Optional[str] = None, ) -> List[List[MatchNeighbor]]: """Retrieves nearest neighbors for the given embedding queries on the specified deployed index for private endpoint only. @@ -1711,22 +1890,27 @@ def match( For hybrid queries, each query is a hybrid query of type aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery. num_neighbors (int): - Required. The number of nearest neighbors to be retrieved from database for - each query. + Required. The number of nearest neighbors to be retrieved from + database for each query. filter (List[Namespace]): Optional. A list of Namespaces for filtering the matching results. - For example, [Namespace("color", ["red"], []), Namespace("shape", [], ["squared"])] will match datapoints - that satisfy "red color" but not include datapoints with "squared shape". - Please refer to https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json for more detail. + For example, [Namespace("color", ["red"], []), + Namespace("shape", [], ["squared"])] will match datapoints that + satisfy "red color" but not include datapoints with "squared + shape". Please refer to + https://cloud.google.com/vertex-ai/docs/matching-engine/filtering#json + for more detail. 
per_crowding_attribute_num_neighbors (int): - Optional. Crowding is a constraint on a neighbor list produced by nearest neighbor - search requiring that no more than some value k' of the k neighbors - returned have the same value of crowding_attribute. + Optional. Crowding is a constraint on a neighbor list produced + by nearest neighbor search requiring that no more than some + value k' of the k neighbors returned have the same value of + crowding_attribute. It's used for improving result diversity. This field is the maximum number of matches with the same crowding tag. approx_num_neighbors (int): - The number of neighbors to find via approximate search before exact reordering is performed. - If not set, the default value from scam config is used; if set, this value must be > 0. + The number of neighbors to find via approximate search before + exact reordering is performed. If not set, the default value + from scam config is used; if set, this value must be > 0. fraction_leaf_nodes_to_search_override (float): Optional. The fraction of the number of leaves to search, set at query time allows user to tune search performance. This value @@ -1746,14 +1930,32 @@ def match( will match datapoints that its cost is greater than 5. signed_jwt (str): Optional. A signed JWT for accessing the private endpoint. + psc_network (Optional[str]): + Optional. Required for private service automation enabled + deployed index. This network is the PSC + network the match query is executed in. This (project, network) + pair must already be specified for psc automation when the index + is deployed. The format is + `projects/{project_id}/global/networks/{network_name}`. + + An alternative is to set the + `private_service_connect_ip_address` field for this + MatchingEngineIndexEndpoint instance, if the ip address is + already known. Returns: List[List[MatchNeighbor]] - A list of nearest neighbors for each query. 
""" - stub = self._instantiate_private_match_service_stub( - deployed_index_id=deployed_index_id, - ip_address=self._private_service_connect_ip_address, - ) + if psc_network: + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id, + psc_network=psc_network, + ) + else: + stub = self._instantiate_private_match_service_stub( + deployed_index_id=deployed_index_id, + ip_address=self._private_service_connect_ip_address, + ) # Create the batch match request batch_request = match_service_pb2.BatchMatchRequest() diff --git a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py index afff2a2a98..bfd85c41b6 100644 --- a/tests/unit/aiplatform/test_matching_engine_index_endpoint.py +++ b/tests/unit/aiplatform/test_matching_engine_index_endpoint.py @@ -106,6 +106,15 @@ _TEST_SIGNED_JWT = "signed_jwt" _TEST_AUTHORIZATION_METADATA = (("authorization", f"Bearer: {_TEST_SIGNED_JWT}"),) +_TEST_PSC_NETWORK1 = "projects/project1/global/networks/network1" +_TEST_PSC_NETWORK2 = "projects/project2/global/networks/network2" +_TEST_PSC_NETWORK3 = "projects/project3/global/networks/network3" +_TEST_PSC_AUTOMATION_CONFIGS = [ + ("project1", _TEST_PSC_NETWORK1), + ("project2", _TEST_PSC_NETWORK2), + ("project3", _TEST_PSC_NETWORK3), +] + # deployment_updated _TEST_MIN_REPLICA_COUNT_UPDATED = 4 _TEST_MAX_REPLICA_COUNT_UPDATED = 4 @@ -275,9 +284,19 @@ _TEST_ENCRYPTION_SPEC_KEY_NAME = "kms_key_name" _TEST_PROJECT_ALLOWLIST = ["project-1", "project-2"] _TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS = "10.128.0.5" +_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_1 = "10.128.0.6" +_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_2 = "10.128.0.7" +_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_3 = "10.128.0.8" +_TEST_SERVICE_ATTACHMENT_URI = "projects/test-project/regions/test-region/serviceAttachments/test-service-attachment" _TEST_PRIVATE_SERVICE_CONNECT_URI = "{}:10000".format( 
_TEST_PRIVATE_SERVICE_CONNECT_IP_ADDRESS ) +_TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_1 = "{}:10000".format( + _TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_1 +) +_TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_3 = "{}:10000".format( + _TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_3 +) _TEST_READ_INDEX_DATAPOINTS_RESPONSE = [ gca_index_v1beta1.IndexDatapoint( datapoint_id="1", @@ -411,6 +430,205 @@ def get_index_endpoint_mock(): yield get_index_endpoint_mock +@pytest.fixture +def get_psa_index_endpoint_mock(): + with patch.object( + index_endpoint_service_client.IndexEndpointServiceClient, "get_index_endpoint" + ) as get_psa_index_endpoint_mock: + index_endpoint = gca_index_endpoint.IndexEndpoint( + name=_TEST_INDEX_ENDPOINT_NAME, + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + ) + index_endpoint.deployed_indexes = [ + gca_index_endpoint.DeployedIndex( + id=_TEST_DEPLOYED_INDEX_ID, + index=_TEST_INDEX_NAME, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + match_grpc_address="10.128.0.0", + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + ), + ), + gca_index_endpoint.DeployedIndex( + id=f"{_TEST_DEPLOYED_INDEX_ID}_2", + index=f"{_TEST_INDEX_NAME}_2", + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + 
deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + match_grpc_address="10.128.0.1", + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + ), + ), + ] + get_psa_index_endpoint_mock.return_value = index_endpoint + yield get_psa_index_endpoint_mock + + +@pytest.fixture +def get_manual_psc_index_endpoint_mock(): + with patch.object( + index_endpoint_service_client.IndexEndpointServiceClient, "get_index_endpoint" + ) as get_manual_psc_index_endpoint_mock: + index_endpoint = gca_index_endpoint.IndexEndpoint( + name=_TEST_INDEX_ENDPOINT_NAME, + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + ) + index_endpoint.deployed_indexes = [ + gca_index_endpoint.DeployedIndex( + id=_TEST_DEPLOYED_INDEX_ID, + index=_TEST_INDEX_NAME, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + ), + ), + gca_index_endpoint.DeployedIndex( + id=f"{_TEST_DEPLOYED_INDEX_ID}_2", + index=f"{_TEST_INDEX_NAME}_2", + 
display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + ), + ), + ] + get_manual_psc_index_endpoint_mock.return_value = index_endpoint + yield get_manual_psc_index_endpoint_mock + + +@pytest.fixture +def get_psc_automated_index_endpoint_mock(): + with patch.object( + index_endpoint_service_client.IndexEndpointServiceClient, + "get_index_endpoint", + ) as get_psc_automated_index_endpoint_mock: + index_endpoint = gca_index_endpoint.IndexEndpoint( + name=_TEST_INDEX_ENDPOINT_NAME, + display_name=_TEST_INDEX_ENDPOINT_DISPLAY_NAME, + description=_TEST_INDEX_ENDPOINT_DESCRIPTION, + ) + index_endpoint.deployed_indexes = [ + gca_index_endpoint.DeployedIndex( + id=_TEST_DEPLOYED_INDEX_ID, + index=_TEST_INDEX_NAME, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + psc_automated_endpoints=[ 
+ gca_service_networking.PscAutomatedEndpoints( + network=_TEST_PSC_NETWORK1, + project_id="test-project1", + match_address=_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_1, + ), + gca_service_networking.PscAutomatedEndpoints( + network=_TEST_PSC_NETWORK2, + project_id="test-project2", + match_address=_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_2, + ), + gca_service_networking.PscAutomatedEndpoints( + network=_TEST_PSC_NETWORK3, + project_id="test-project3", + match_address=_TEST_PRIVATE_SERVICE_CONNECT_IP_AUTOMATION_ADDRESS_3, + ), + ], + ), + ), + gca_index_endpoint.DeployedIndex( + id=f"{_TEST_DEPLOYED_INDEX_ID}_2", + index=f"{_TEST_INDEX_NAME}_2", + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + private_endpoints=gca_index_endpoint.IndexPrivateEndpoints( + service_attachment=_TEST_SERVICE_ATTACHMENT_URI, + psc_automated_endpoints=[ + gca_service_networking.PscAutomatedEndpoints( + network="test-network2", + project_id="test-project2", + match_address="10.128.0.8", + ) + ], + ), + ), + ] + get_psc_automated_index_endpoint_mock.return_value = index_endpoint + yield get_psc_automated_index_endpoint_mock + + @pytest.fixture def get_index_public_endpoint_mock(): with patch.object( @@ -1038,6 +1256,64 @@ def test_deploy_index(self, deploy_index_mock, undeploy_index_mock): timeout=None, ) + @pytest.mark.usefixtures("get_psc_automated_index_endpoint_mock", "get_index_mock") + def test_deploy_index_psc_automation_configs(self, deploy_index_mock): + aiplatform.init(project=_TEST_PROJECT) + + 
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + # Get index + my_index = aiplatform.MatchingEngineIndex(index_name=_TEST_INDEX_NAME) + + my_index_endpoint = my_index_endpoint.deploy_index( + index=my_index, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + min_replica_count=_TEST_MIN_REPLICA_COUNT, + max_replica_count=_TEST_MAX_REPLICA_COUNT, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + deployment_group=_TEST_DEPLOYMENT_GROUP, + auth_config_audiences=_TEST_AUTH_CONFIG_AUDIENCES, + auth_config_allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + psc_automation_configs=_TEST_PSC_AUTOMATION_CONFIGS, + request_metadata=_TEST_REQUEST_METADATA, + deploy_request_timeout=_TEST_TIMEOUT, + ) + + deploy_index_mock.assert_called_once_with( + index_endpoint=my_index_endpoint.resource_name, + deployed_index=gca_index_endpoint.DeployedIndex( + id=_TEST_DEPLOYED_INDEX_ID, + index=my_index.resource_name, + display_name=_TEST_DEPLOYED_INDEX_DISPLAY_NAME, + enable_access_logging=_TEST_ENABLE_ACCESS_LOGGING, + reserved_ip_ranges=_TEST_RESERVED_IP_RANGES, + deployment_group=_TEST_DEPLOYMENT_GROUP, + automatic_resources={ + "min_replica_count": _TEST_MIN_REPLICA_COUNT, + "max_replica_count": _TEST_MAX_REPLICA_COUNT, + }, + deployed_index_auth_config=gca_index_endpoint.DeployedIndexAuthConfig( + auth_provider=gca_index_endpoint.DeployedIndexAuthConfig.AuthProvider( + audiences=_TEST_AUTH_CONFIG_AUDIENCES, + allowed_issuers=_TEST_AUTH_CONFIG_ALLOWED_ISSUERS, + ) + ), + psc_automation_configs=[ + gca_service_networking.PSCAutomationConfig( + project_id=test_psc_automation_config[0], + network=test_psc_automation_config[1], + ) + for test_psc_automation_config in _TEST_PSC_AUTOMATION_CONFIGS + ], + ), + metadata=_TEST_REQUEST_METADATA, + timeout=_TEST_TIMEOUT, + ) + @pytest.mark.usefixtures("get_index_endpoint_mock", 
"get_index_mock") def test_mutate_deployed_index(self, mutate_deployed_index_mock): aiplatform.init(project=_TEST_PROJECT) @@ -1381,6 +1657,62 @@ def test_index_private_service_access_endpoint_find_neighbor_queries( batch_match_request, metadata=mock.ANY ) + @pytest.mark.usefixtures("get_psc_automated_index_endpoint_mock") + def test_index_private_service_connect_automation_endpoint_find_neighbor_queries( + self, index_endpoint_match_queries_mock, grpc_insecure_channel_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_private_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_private_index_endpoint.find_neighbors( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=_TEST_QUERIES, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + per_crowding_attribute_neighbor_count=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + return_full_datapoint=_TEST_RETURN_FULL_DATAPOINT, + numeric_filter=_TEST_NUMERIC_FILTER, + psc_network=_TEST_PSC_NETWORK1, + ) + + batch_match_request = match_service_pb2.BatchMatchRequest( + requests=[ + match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + requests=[ + match_service_pb2.MatchRequest( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + num_neighbors=_TEST_NUM_NEIGHBOURS, + float_val=test_query, + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + fraction_leaf_nodes_to_search_override=_TEST_FRACTION_LEAF_NODES_TO_SEARCH_OVERRIDE, + numeric_restricts=_TEST_NUMERIC_NAMESPACE, + ) + for test_query in _TEST_QUERIES + ], + ) + ] + ) + 
index_endpoint_match_queries_mock.BatchMatch.assert_called_with( + batch_match_request, metadata=mock.ANY + ) + grpc_insecure_channel_mock.assert_called_with( + _TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_1 + ) + @pytest.mark.usefixtures("get_index_endpoint_mock") def test_index_private_service_access_endpoint_find_neighbor_queries_with_jwt( self, index_endpoint_match_queries_mock @@ -1486,6 +1818,110 @@ def test_index_private_service_connect_endpoint_match_queries( grpc_insecure_channel_mock.assert_called_with(_TEST_PRIVATE_SERVICE_CONNECT_URI) + @pytest.mark.usefixtures("get_psc_automated_index_endpoint_mock") + def test_index_private_service_connect_automation_match_queries( + self, index_endpoint_match_queries_mock, grpc_insecure_channel_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint.match( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=_TEST_QUERIES, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + psc_network=_TEST_PSC_NETWORK1, + ) + + batch_request = match_service_pb2.BatchMatchRequest( + requests=[ + match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + requests=[ + match_service_pb2.MatchRequest( + num_neighbors=_TEST_NUM_NEIGHBOURS, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + float_val=_TEST_QUERIES[0], + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + ) + ], + ) + ] + ) + + index_endpoint_match_queries_mock.BatchMatch.assert_called_with( + batch_request, metadata=mock.ANY + ) + + 
grpc_insecure_channel_mock.assert_called_with( + _TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_1 + ) + + @pytest.mark.usefixtures("get_psc_automated_index_endpoint_mock") + def test_index_private_service_connect_automation_match_queries_find_ip_address( + self, index_endpoint_match_queries_mock, grpc_insecure_channel_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + my_index_endpoint.match( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + queries=_TEST_QUERIES, + num_neighbors=_TEST_NUM_NEIGHBOURS, + filter=_TEST_FILTER, + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + psc_network=_TEST_PSC_NETWORK3, + ) + + batch_request = match_service_pb2.BatchMatchRequest( + requests=[ + match_service_pb2.BatchMatchRequest.BatchMatchRequestPerIndex( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + requests=[ + match_service_pb2.MatchRequest( + num_neighbors=_TEST_NUM_NEIGHBOURS, + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + float_val=_TEST_QUERIES[0], + restricts=[ + match_service_pb2.Namespace( + name="class", + allow_tokens=["token_1"], + deny_tokens=["token_2"], + ) + ], + per_crowding_attribute_num_neighbors=_TEST_PER_CROWDING_ATTRIBUTE_NUM_NEIGHBOURS, + approx_num_neighbors=_TEST_APPROX_NUM_NEIGHBORS, + ) + ], + ) + ] + ) + + index_endpoint_match_queries_mock.BatchMatch.assert_called_with( + batch_request, metadata=mock.ANY + ) + + grpc_insecure_channel_mock.assert_called_with( + _TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_3 + ) + @pytest.mark.usefixtures("get_index_public_endpoint_mock") def test_index_public_endpoint_find_neighbors_queries_backward_compatibility( self, index_public_endpoint_match_queries_mock @@ -1784,7 +2220,7 @@ def test_index_public_endpoint_read_index_datapoints( read_index_datapoints_request ) - @pytest.mark.usefixtures("get_index_endpoint_mock") + 
@pytest.mark.usefixtures("get_psa_index_endpoint_mock") def test_index_endpoint_batch_get_embeddings( self, index_endpoint_batch_get_embeddings_mock ): @@ -1806,7 +2242,7 @@ def test_index_endpoint_batch_get_embeddings( batch_request, metadata=mock.ANY ) - @pytest.mark.usefixtures("get_index_endpoint_mock") + @pytest.mark.usefixtures("get_psa_index_endpoint_mock") def test_index_endpoint_read_index_datapoints_for_private_service_access( self, index_endpoint_batch_get_embeddings_mock ): @@ -1856,7 +2292,7 @@ def test_index_endpoint_read_index_datapoints_for_private_service_access_with_jw assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE - @pytest.mark.usefixtures("get_index_endpoint_mock") + @pytest.mark.usefixtures("get_manual_psc_index_endpoint_mock") def test_index_endpoint_read_index_datapoints_for_private_service_connect( self, grpc_insecure_channel_mock, index_endpoint_batch_get_embeddings_mock ): @@ -1886,6 +2322,37 @@ def test_index_endpoint_read_index_datapoints_for_private_service_connect( assert response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE + @pytest.mark.usefixtures("get_psc_automated_index_endpoint_mock") + def test_index_endpoint_read_index_datapoints_for_private_service_connect_automation( + self, index_endpoint_batch_get_embeddings_mock, grpc_insecure_channel_mock + ): + aiplatform.init(project=_TEST_PROJECT) + + my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint( + index_endpoint_name=_TEST_INDEX_ENDPOINT_ID + ) + + response = my_index_endpoint.read_index_datapoints( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + ids=["1", "2"], + psc_network=_TEST_PSC_NETWORK1, + ) + + batch_request = match_service_pb2.BatchGetEmbeddingsRequest( + deployed_index_id=_TEST_DEPLOYED_INDEX_ID, + id=["1", "2"], + ) + + index_endpoint_batch_get_embeddings_mock.BatchGetEmbeddings.assert_called_with( + batch_request, metadata=mock.ANY + ) + + grpc_insecure_channel_mock.assert_called_with( + _TEST_PRIVATE_SERVICE_CONNECT_AUTOMATION_URI_1 + ) + + assert 
response == _TEST_READ_INDEX_DATAPOINTS_RESPONSE + class TestMatchNeighbor: def test_from_index_datapoint(self):