Skip to content

Commit

Permalink
chore: Add vector search samples for VPC peering and PSC
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691561705
  • Loading branch information
vertex-sdk-bot authored and copybara-github committed Oct 30, 2024
1 parent c0718e1 commit 7edeedb
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,7 @@ def create(
Private services access must already be configured for the network.
If left unspecified, the network set with aiplatform.init will be used.
`Format <https://cloud.google.com/compute/docs/reference/rest/v1/networks/insert>`__:
projects/{project}/global/networks/{network}. Where
Format: projects/{project}/global/networks/{network}, where
{project} is a project number, as in '12345', and {network}
is network name.
public_endpoint_enabled (bool):
Expand Down
2 changes: 2 additions & 0 deletions samples/model-builder/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,3 +408,5 @@
VECTOR_SEARCH_GCS_URI = "gs://fake-dir"
VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME = "my-vector-search-index-endpoint"
VECTOR_SEARCH_PRIVATE_ENDPOINT_SIGNED_JWT = "fake-signed-jwt"
VECTOR_SEARCH_VPC_NETWORK = "vpc-network"
VECTOR_SEARCH_PSC_PROJECT_ALLOWLIST = ["test-project", "test-project-2"]
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,68 @@ def vector_search_create_index_endpoint(


# [END aiplatform_sdk_vector_search_create_index_endpoint_sample]


# [START aiplatform_sdk_vector_search_create_index_endpoint_vpc_sample]
def vector_search_create_index_endpoint_vpc(
project: str, location: str, display_name: str, network: str
) -> aiplatform.MatchingEngineIndexEndpoint:
"""Create a vector search index endpoint within a VPC network.
Args:
project (str): Required. Project ID
location (str): Required. The region name
display_name (str): Required. The index endpoint display name
network(str): Required. The VPC network name, in the format of
projects/{project number}/global/networks/{network name}.
Returns:
aiplatform.MatchingEngineIndexEndpoint - The created index endpoint.
"""
# Initialize the Vertex AI client
aiplatform.init(project=project, location=location)

# Create Index Endpoint
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=display_name,
network=network,
description="Matching Engine VPC Index Endpoint",
)

return index_endpoint


# [END aiplatform_sdk_vector_search_create_index_endpoint_vpc_sample]


# [START aiplatform_sdk_vector_search_create_index_endpoint_psc_sample]
def vector_search_create_index_endpoint_private_service_connect(
project: str, location: str, display_name: str, project_allowlist: list[str]
) -> aiplatform.MatchingEngineIndexEndpoint:
"""Create a vector search index endpoint with Private Service Connect enabled.
Args:
project (str): Required. Project ID
location (str): Required. The region name
display_name (str): Required. The index endpoint display name
project_allowlist (list[str]): Required. A list of projects from which
the forwarding rule will be able to target the service attachment.
Returns:
aiplatform.MatchingEngineIndexEndpoint - The created index endpoint.
"""
# Initialize the Vertex AI client
aiplatform.init(project=project, location=location)

# Create Index Endpoint
index_endpoint = aiplatform.MatchingEngineIndexEndpoint.create(
display_name=display_name,
description="Matching Engine VPC Index Endpoint",
enable_private_service_connect=True,
project_allowlist=project_allowlist,
)

return index_endpoint


# [END aiplatform_sdk_vector_search_create_index_endpoint_psc_sample]
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest.mock import ANY
from unittest import mock

import test_constants as constants
from vector_search import vector_search_create_index_endpoint_sample
Expand All @@ -38,5 +38,56 @@ def test_vector_search_create_index_endpoint_sample(
mock_index_endpoint_create.assert_called_with(
display_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME,
public_endpoint_enabled=True,
description=ANY,
description=mock.ANY,
)


def test_vector_search_create_index_endpoint_vpc_sample(
mock_sdk_init,
mock_index_endpoint_create,
):
vector_search_create_index_endpoint_sample.vector_search_create_index_endpoint_vpc(
project=constants.PROJECT,
location=constants.LOCATION,
display_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME,
network=constants.VECTOR_SEARCH_VPC_NETWORK,
)

# Check client initialization
mock_sdk_init.assert_called_with(
project=constants.PROJECT,
location=constants.LOCATION,
)

# Check index creation
mock_index_endpoint_create.assert_called_with(
display_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME,
description=mock.ANY,
network=constants.VECTOR_SEARCH_VPC_NETWORK,
)


def test_vector_search_create_index_endpoint_psc_sample(
mock_sdk_init,
mock_index_endpoint_create,
):
vector_search_create_index_endpoint_sample.vector_search_create_index_endpoint_private_service_connect(
project=constants.PROJECT,
location=constants.LOCATION,
display_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME,
project_allowlist=constants.VECTOR_SEARCH_PSC_PROJECT_ALLOWLIST,
)

# Check client initialization
mock_sdk_init.assert_called_with(
project=constants.PROJECT,
location=constants.LOCATION,
)

# Check index creation
mock_index_endpoint_create.assert_called_with(
display_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT_DISPLAY_NAME,
description=mock.ANY,
enable_private_service_connect=True,
project_allowlist=constants.VECTOR_SEARCH_PSC_PROJECT_ALLOWLIST,
)
64 changes: 64 additions & 0 deletions samples/model-builder/vector_search/vector_search_match_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,70 @@
from google.cloud import aiplatform


# [START aiplatform_sdk_vector_search_match_hybrid_queries_sample]
def vector_search_match_hybrid_queries(
project: str,
location: str,
index_endpoint_name: str,
deployed_index_id: str,
num_neighbors: int,
) -> List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]]:
"""Query the vector search index.
Args:
project (str): Required. Project ID
location (str): Required. The region name
index_endpoint_name (str): Required. Index endpoint to run the query
against. The endpoint must be a private endpoint.
deployed_index_id (str): Required. The ID of the DeployedIndex to run
the queries against.
num_neighbors (int): Required. The number of neighbors to return.
Returns:
List[List[aiplatform.matching_engine.matching_engine_index_endpoint.MatchNeighbor]] - A list of nearest neighbors for each query.
"""
# Initialize the Vertex AI client
aiplatform.init(project=project, location=location)

# Create the index endpoint instance from an existing endpoint.
my_index_endpoint = aiplatform.MatchingEngineIndexEndpoint(
index_endpoint_name=index_endpoint_name
)

# Example queries containing hybrid datapoints, sparse-only datapoints, and
# dense-only datapoints.
hybrid_queries = [
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3],
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[1.0, 1.0, 1.0],
rrf_ranking_alpha=0.5,
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3],
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[0.1, 0.2, 0.3],
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
sparse_embedding_dimensions=[10, 20, 30],
sparse_embedding_values=[0.1, 0.2, 0.3],
),
aiplatform.matching_engine.matching_engine_index_endpoint.HybridQuery(
dense_embedding=[1, 2, 3]
),
]

# Query the index endpoint for matches.
resp = my_index_endpoint.match(
deployed_index_id=deployed_index_id,
queries=hybrid_queries,
num_neighbors=num_neighbors,
)
return resp

# [END aiplatform_sdk_vector_search_match_hybrid_queries_sample]


# [START aiplatform_sdk_vector_search_match_jwt_sample]
def vector_search_match_jwt(
project: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from unittest import mock

import test_constants as constants
from vector_search import vector_search_match_sample


def test_vector_search_match_hybrid_queries_sample(
mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_match
):
vector_search_match_sample.vector_search_match_hybrid_queries(
project=constants.PROJECT,
location=constants.LOCATION,
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT,
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
num_neighbors=10,
)

# Check client initialization
mock_sdk_init.assert_called_with(
project=constants.PROJECT, location=constants.LOCATION
)

# Check index endpoint initialization with right index endpoint name
mock_index_endpoint_init.assert_called_with(
index_endpoint_name=constants.VECTOR_SEARCH_INDEX_ENDPOINT)

# Check index_endpoint.match is called with right params.
mock_index_endpoint_match.assert_called_with(
deployed_index_id=constants.VECTOR_SEARCH_DEPLOYED_INDEX_ID,
queries=mock.ANY,
num_neighbors=10,
)


def test_vector_search_match_jwt_sample(
mock_sdk_init, mock_index_endpoint_init, mock_index_endpoint_match
):
Expand Down

0 comments on commit 7edeedb

Please sign in to comment.