From 5015d25c5efdb9ba0a01bc60441f7eb8d5fddc52 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Thu, 28 Mar 2024 17:34:59 -0700 Subject: [PATCH] feat: Python SDK for Vertex Feature Store. PiperOrigin-RevId: 620105958 --- tests/unit/vertexai/conftest.py | 166 +++++- .../unit/vertexai/feature_store_constants.py | 264 +++++++++ .../vertexai/test_feature_online_store.py | 500 +++++++++++++++++ tests/unit/vertexai/test_feature_view.py | 508 +++++++++++++++++ vertexai/resources/preview/__init__.py | 21 + .../preview/feature_store/__init__.py | 48 ++ .../feature_store/feature_online_store.py | 524 ++++++++++++++++++ .../preview/feature_store/feature_view.py | 429 ++++++++++++++ .../resources/preview/feature_store/utils.py | 156 ++++++ 9 files changed, 2615 insertions(+), 1 deletion(-) create mode 100644 tests/unit/vertexai/feature_store_constants.py create mode 100644 tests/unit/vertexai/test_feature_online_store.py create mode 100644 tests/unit/vertexai/test_feature_view.py create mode 100644 vertexai/resources/preview/feature_store/__init__.py create mode 100644 vertexai/resources/preview/feature_store/feature_online_store.py create mode 100644 vertexai/resources/preview/feature_store/feature_view.py create mode 100644 vertexai/resources/preview/feature_store/utils.py diff --git a/tests/unit/vertexai/conftest.py b/tests/unit/vertexai/conftest.py index 8c3a4abbe2..3baa482abb 100644 --- a/tests/unit/vertexai/conftest.py +++ b/tests/unit/vertexai/conftest.py @@ -20,6 +20,7 @@ import tempfile from typing import Any from unittest import mock +from unittest.mock import patch import uuid from google import auth @@ -47,7 +48,20 @@ ResourceRuntimeSpec, ServiceAccountSpec, ) - +from google.cloud.aiplatform.compat.services import ( + feature_online_store_admin_service_client, +) +from feature_store_constants import ( + _TEST_BIGTABLE_FOS1, + _TEST_EMBEDDING_FV1, + _TEST_ESF_OPTIMIZED_FOS, + _TEST_ESF_OPTIMIZED_FOS2, + _TEST_FV1, + _TEST_OPTIMIZED_FV1, + _TEST_OPTIMIZED_FV2, + _TEST_PSC_OPTIMIZED_FOS, + _TEST_OPTIMIZED_EMBEDDING_FV, +) _TEST_PROJECT = "test-project" _TEST_PROJECT_NUMBER = "12345678" @@ -332,3 +346,153 @@ def create_persistent_resource_default_mock(): create_persistent_resource_lro_mock ) yield create_persistent_resource_default_mock + + +@pytest.fixture +def get_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_online_store", + ) as get_fos_mock: + get_fos_mock.return_value = _TEST_BIGTABLE_FOS1 + yield get_fos_mock + + +@pytest.fixture +def get_esf_optimized_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_online_store", + ) as get_fos_mock: + get_fos_mock.return_value = _TEST_ESF_OPTIMIZED_FOS + yield get_fos_mock + + +@pytest.fixture +def get_psc_optimized_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_online_store", + ) as get_fos_mock: + get_fos_mock.return_value = _TEST_PSC_OPTIMIZED_FOS + yield get_fos_mock + + +@pytest.fixture +def get_esf_optimized_fos_no_endpoint_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_online_store", + ) as get_fos_mock: + get_fos_mock.return_value = _TEST_ESF_OPTIMIZED_FOS2 + yield get_fos_mock + + +@pytest.fixture +def create_bigtable_fos_mock(): + with patch.object( + 
feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "create_feature_online_store", + ) as create_bigtable_fos_mock: + create_fos_lro_mock = mock.Mock(ga_operation.Operation) + create_fos_lro_mock.result.return_value = _TEST_BIGTABLE_FOS1 + create_bigtable_fos_mock.return_value = create_fos_lro_mock + yield create_bigtable_fos_mock + + +@pytest.fixture +def create_esf_optimized_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "create_feature_online_store", + ) as create_esf_optimized_fos_mock: + create_fos_lro_mock = mock.Mock(ga_operation.Operation) + create_fos_lro_mock.result.return_value = _TEST_ESF_OPTIMIZED_FOS + create_esf_optimized_fos_mock.return_value = create_fos_lro_mock + yield create_esf_optimized_fos_mock + + +@pytest.fixture +def create_psc_optimized_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "create_feature_online_store", + ) as create_psc_optimized_fos_mock: + create_fos_lro_mock = mock.Mock(ga_operation.Operation) + create_fos_lro_mock.result.return_value = _TEST_PSC_OPTIMIZED_FOS + create_psc_optimized_fos_mock.return_value = create_fos_lro_mock + yield create_psc_optimized_fos_mock + + +@pytest.fixture +def get_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view", + ) as get_fv_mock: + get_fv_mock.return_value = _TEST_FV1 + yield get_fv_mock + + +@pytest.fixture +def create_bq_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "create_feature_view", + ) as create_bq_fv_mock: + create_bq_fv_lro_mock = mock.Mock(ga_operation.Operation) + create_bq_fv_lro_mock.result.return_value = _TEST_FV1 + create_bq_fv_mock.return_value = create_bq_fv_lro_mock + yield create_bq_fv_mock + + +@pytest.fixture +def create_embedding_fv_from_bq_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "create_feature_view", + ) as create_embedding_fv_mock: + create_embedding_fv_mock_lro = mock.Mock(ga_operation.Operation) + create_embedding_fv_mock_lro.result.return_value = _TEST_OPTIMIZED_EMBEDDING_FV + create_embedding_fv_mock.return_value = create_embedding_fv_mock_lro + yield create_embedding_fv_mock + + +@pytest.fixture +def get_optimized_embedding_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view", + ) as get_fv_mock: + get_fv_mock.return_value = _TEST_OPTIMIZED_EMBEDDING_FV + yield get_fv_mock + + +@pytest.fixture +def get_optimized_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view", + ) as get_optimized_fv_mock: + get_optimized_fv_mock.return_value = _TEST_OPTIMIZED_FV1 + yield get_optimized_fv_mock + + +@pytest.fixture +def get_embedding_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view", + ) as get_embedding_fv_mock: + get_embedding_fv_mock.return_value = _TEST_EMBEDDING_FV1 + yield get_embedding_fv_mock + + +@pytest.fixture +def get_optimized_fv_no_endpointmock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view", + ) as get_optimized_fv_no_endpointmock: + 
get_optimized_fv_no_endpointmock.return_value = _TEST_OPTIMIZED_FV2 + yield get_optimized_fv_no_endpointmock diff --git a/tests/unit/vertexai/feature_store_constants.py b/tests/unit/vertexai/feature_store_constants.py new file mode 100644 index 0000000000..f13fc1ab0d --- /dev/null +++ b/tests/unit/vertexai/feature_store_constants.py @@ -0,0 +1,264 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from google.cloud.aiplatform.compat import types + +_TEST_PROJECT = "test-project" +_TEST_LOCATION = "us-central1" +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" + +# Test feature online store 1 +_TEST_BIGTABLE_FOS1_ID = "my_fos1" +_TEST_BIGTABLE_FOS1_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS1_ID}" +) +_TEST_BIGTABLE_FOS1_LABELS = {"my_key": "my_fos1"} +_TEST_BIGTABLE_FOS1 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS1_PATH, + bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=1, + max_node_count=2, + cpu_utilization_target=50, + ) + ), + labels=_TEST_BIGTABLE_FOS1_LABELS, +) + +# Test feature online store 2 +_TEST_BIGTABLE_FOS2_ID = "my_fos2" +_TEST_BIGTABLE_FOS2_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS2_ID}" +) +_TEST_BIGTABLE_FOS2_LABELS = {"my_key": "my_fos2"} +_TEST_BIGTABLE_FOS2 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS2_PATH, + bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=2, + max_node_count=3, + cpu_utilization_target=60, + ) + ), + labels=_TEST_BIGTABLE_FOS2_LABELS, +) + +# Test feature online store 3 +_TEST_BIGTABLE_FOS3_ID = "my_fos3" +_TEST_BIGTABLE_FOS3_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_BIGTABLE_FOS3_ID}" +) +_TEST_BIGTABLE_FOS3_LABELS = {"my_key": "my_fos3"} +_TEST_BIGTABLE_FOS3 = types.feature_online_store.FeatureOnlineStore( + name=_TEST_BIGTABLE_FOS3_PATH, + bigtable=types.feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=3, + max_node_count=4, + cpu_utilization_target=70, + ) + ), + labels=_TEST_BIGTABLE_FOS3_LABELS, +) + +# Test feature online store for optimized with esf endpoint +_TEST_ESF_OPTIMIZED_FOS_ID = "my_esf_optimized_fos" +_TEST_ESF_OPTIMIZED_FOS_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_ESF_OPTIMIZED_FOS_ID}" +) +_TEST_ESF_OPTIMIZED_FOS_LABELS = {"my_key": "my_esf_optimized_fos"} +_TEST_ESF_OPTIMIZED_FOS = types.feature_online_store.FeatureOnlineStore( + name=_TEST_ESF_OPTIMIZED_FOS_PATH, + optimized=types.feature_online_store.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint( + public_endpoint_domain_name="test-esf-endpoint", + 
), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, +) + +# Test feature online store for optimized with psc endpoint +_TEST_PSC_OPTIMIZED_FOS_ID = "my_psc_optimized_fos" +_TEST_PSC_OPTIMIZED_FOS_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_PSC_OPTIMIZED_FOS_ID}" +) +_TEST_PSC_OPTIMIZED_FOS_LABELS = {"my_key": "my_psc_optimized_fos"} +_TEST_PSC_PROJECT_ALLOWLIST = ["project-1", "project-2"] +_TEST_PSC_OPTIMIZED_FOS = types.feature_online_store_v1.FeatureOnlineStore( + name=_TEST_PSC_OPTIMIZED_FOS_PATH, + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, +) + +_TEST_FOS_LIST = [_TEST_BIGTABLE_FOS1, _TEST_BIGTABLE_FOS2, _TEST_BIGTABLE_FOS3] + +# Test feature online store for optimized with esf endpoint but sync has not run yet. +_TEST_ESF_OPTIMIZED_FOS2_ID = "my_esf_optimised_fos2" +_TEST_ESF_OPTIMIZED_FOS2_PATH = ( + f"{_TEST_PARENT}/featureOnlineStores/{_TEST_ESF_OPTIMIZED_FOS2_ID}" +) +_TEST_ESF_OPTIMIZED_FOS2_LABELS = {"my_key": "my_esf_optimized_fos2"} +_TEST_ESF_OPTIMIZED_FOS2 = types.feature_online_store_v1.FeatureOnlineStore( + name=_TEST_ESF_OPTIMIZED_FOS2_PATH, + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, +) + + +# Test feature view 1 +_TEST_FV1_ID = "my_fv1" +_TEST_FV1_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv1" +_TEST_FV1_LABELS = {"my_key": "my_fv1"} +_TEST_FV1_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table" +_TEST_FV1_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FV1 = types.feature_view.FeatureView( + name=_TEST_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test feature view 2 +_TEST_FV2_ID = "my_fv2" +_TEST_FV2_PATH = f"{_TEST_BIGTABLE_FOS1_PATH}/featureViews/my_fv2" +_TEST_FV2_LABELS = {"my_key": "my_fv2"} +_TEST_FV2_BQ_URI = f"bq://{_TEST_PROJECT}.my_dataset.my_table" +_TEST_FV2_ENTITY_ID_COLUMNS = ["entity_id"] +_TEST_FV2 = types.feature_view.FeatureView( + name=_TEST_FV2_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV2_BQ_URI, + entity_id_columns=_TEST_FV2_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV2_LABELS, +) + +_TEST_FV_LIST = [_TEST_FV1, _TEST_FV2] + +# Test feature view sync 1 +_TEST_FV_SYNC1_ID = "my_fv_sync1" +_TEST_FV_SYNC1_PATH = f"{_TEST_FV1_PATH}/featureViewSyncs/my_fv_sync1" +_TEST_FV_SYNC1 = types.feature_view_sync.FeatureViewSync( + name=_TEST_FV_SYNC1_PATH, +) +_TEST_FV_SYNC1_RESPONSE = ( + types.feature_online_store_admin_service.SyncFeatureViewResponse( + feature_view_sync=_TEST_FV_SYNC1_PATH, + ) +) + +# Test feature view sync 2 +_TEST_FV_SYNC2_ID = "my_fv_sync2" +_TEST_FV_SYNC2_PATH = f"{_TEST_FV2_PATH}/featureViewSyncs/my_fv_sync2" +_TEST_FV_SYNC2 = types.feature_view_sync.FeatureViewSync( + name=_TEST_FV_SYNC2_PATH, +) + +_TEST_FV_SYNC_LIST = [_TEST_FV_SYNC1, _TEST_FV_SYNC2] + +# Test optimized feature view 1 +_TEST_OPTIMIZED_FV1_ID = "optimized_fv1" +_TEST_OPTIMIZED_FV1_PATH = f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/optimized_fv1" +_TEST_OPTIMIZED_FV1 = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + 
entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test optimized feature view 2 +_TEST_OPTIMIZED_FV2_ID = "optimized_fv2" +_TEST_OPTIMIZED_FV2_PATH = f"{_TEST_ESF_OPTIMIZED_FOS2_PATH}/featureViews/optimized_fv2" +_TEST_OPTIMIZED_FV2 = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_FV2_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +# Test embedding feature view 1 +_TEST_EMBEDDING_FV1_ID = "embedding_fv1" +_TEST_EMBEDDING_FV1_PATH = f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/embedding_fv1" +_TEST_EMBEDDING_FV1 = types.feature_view.FeatureView( + name=_TEST_EMBEDDING_FV1_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, +) + +_TEST_STRING_FILTER = ( + types.feature_online_store_service.NearestNeighborQuery.StringFilter( + name="filter_name", + allow_tokens=["allow_token_1", "allow_token_2"], + ) +) + +# Test optimized embedding feature view +_TEST_OPTIMIZED_EMBEDDING_FV_ID = "optimized_embedding_fv" +_TEST_OPTIMIZED_EMBEDDING_FV_PATH = ( + f"{_TEST_ESF_OPTIMIZED_FOS_PATH}/featureViews/{_TEST_OPTIMIZED_EMBEDDING_FV_ID}" +) +_TEST_OPTIMIZED_EMBEDDING_FV = types.feature_view.FeatureView( + name=_TEST_OPTIMIZED_EMBEDDING_FV_PATH, + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, + index_config=types.feature_view.FeatureView.IndexConfig( + embedding_column="embedding_column", + filter_columns=["col1", "col2"], + crowding_column="crowding_column", + embedding_dimension=123, + distance_measure_type=types.feature_view.FeatureView.IndexConfig.DistanceMeasureType.DOT_PRODUCT_DISTANCE, + ), +) + +# Response for FetchFeatureValues +_TEST_FV_FETCH1 = types.feature_online_store_service_v1.FetchFeatureValuesResponse( + key_values=types.feature_online_store_service_v1.FetchFeatureValuesResponse.FeatureNameValuePairList( + features=[ + types.feature_online_store_service_v1.FetchFeatureValuesResponse.FeatureNameValuePairList.FeatureNameValuePair( + name="key1", + value=types.featurestore_online_service.FeatureValue( + string_value="value1", + ), + ), + ] + ) +) + +# Response for SearchNearestEntitiesResponse +_TEST_FV_SEARCH1 = types.feature_online_store_service_v1.SearchNearestEntitiesResponse( + nearest_neighbors=types.feature_online_store_service_v1.NearestNeighbors( + neighbors=[ + types.feature_online_store_service_v1.NearestNeighbors.Neighbor( + entity_id="neighbor_entity_id_1", + distance=0.1, + ), + ] + ) +) diff --git a/tests/unit/vertexai/test_feature_online_store.py b/tests/unit/vertexai/test_feature_online_store.py new file mode 100644 index 0000000000..de26dd42b2 --- /dev/null +++ b/tests/unit/vertexai/test_feature_online_store.py @@ -0,0 +1,500 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# + +from unittest.mock import call +import re +from unittest import mock +from unittest.mock import patch +from typing import Dict + +from google.api_core import operation as ga_operation +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform.compat import types +from vertexai.resources.preview import ( + FeatureOnlineStore, + FeatureOnlineStoreType, + FeatureView, + IndexConfig, + DistanceMeasureType, + TreeAhConfig, +) +from vertexai.resources.preview.feature_store import ( + feature_online_store, +) +from google.cloud.aiplatform.compat.services import ( + feature_online_store_admin_service_client, +) +import pytest + +from test_feature_view import fv_eq +from feature_store_constants import ( + _TEST_PROJECT, + _TEST_LOCATION, + _TEST_PARENT, + _TEST_BIGTABLE_FOS1_ID, + _TEST_BIGTABLE_FOS1_PATH, + _TEST_BIGTABLE_FOS1_LABELS, + _TEST_BIGTABLE_FOS2_ID, + _TEST_BIGTABLE_FOS2_PATH, + _TEST_BIGTABLE_FOS2_LABELS, + _TEST_BIGTABLE_FOS3_ID, + _TEST_BIGTABLE_FOS3_PATH, + _TEST_BIGTABLE_FOS3_LABELS, + _TEST_ESF_OPTIMIZED_FOS_ID, + _TEST_ESF_OPTIMIZED_FOS_PATH, + _TEST_ESF_OPTIMIZED_FOS_LABELS, + _TEST_PSC_OPTIMIZED_FOS_ID, + _TEST_PSC_OPTIMIZED_FOS_LABELS, + _TEST_PSC_PROJECT_ALLOWLIST, + _TEST_FOS_LIST, + _TEST_FV1_ID, + _TEST_FV1_PATH, + _TEST_FV1_LABELS, + _TEST_FV1_BQ_URI, + _TEST_FV1_ENTITY_ID_COLUMNS, + _TEST_OPTIMIZED_EMBEDDING_FV_ID, + _TEST_OPTIMIZED_EMBEDDING_FV_PATH, +) + + +@pytest.fixture +def fos_logger_mock(): + with patch.object( + feature_online_store._LOGGER, + "info", + wraps=feature_online_store._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def list_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "list_feature_online_stores", + ) as list_fos_mock: + list_fos_mock.return_value = _TEST_FOS_LIST + yield list_fos_mock + + +@pytest.fixture +def delete_fos_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "delete_feature_online_store", + ) as delete_fos_mock: + delete_fos_lro_mock = mock.Mock(ga_operation.Operation) + delete_fos_mock.return_value = delete_fos_lro_mock + yield delete_fos_mock + + +def fos_eq( + fos_to_check: FeatureOnlineStore, + name: str, + resource_name: str, + project: str, + location: str, + labels: Dict[str, str], + type: FeatureOnlineStoreType, +): + """Check if a FeatureOnlineStore has the appropriate values set.""" + assert fos_to_check.name == name + assert fos_to_check.resource_name == resource_name + assert fos_to_check.project == project + assert fos_to_check.location == location + assert fos_to_check.labels == labels + assert fos_to_check.feature_online_store_type == type + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.mark.parametrize( + "online_store_name", + [_TEST_BIGTABLE_FOS1_ID, _TEST_BIGTABLE_FOS1_PATH], +) +def test_init(online_store_name, get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore(online_store_name) + + get_fos_mock.assert_called_once_with( + name=_TEST_BIGTABLE_FOS1_PATH, retry=base._DEFAULT_RETRY + ) + + fos_eq( + fos, + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + 
+@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create( + create_request_timeout, + create_bigtable_fos_mock, + get_fos_mock, + fos_logger_mock, + sync=True, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore.create_bigtable_store( + _TEST_BIGTABLE_FOS1_ID, + labels=_TEST_BIGTABLE_FOS1_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureOnlineStore object doesn't have the path set. + expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + bigtable=types.feature_online_store_v1.FeatureOnlineStore.Bigtable( + auto_scaling=types.feature_online_store_v1.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=1, + max_node_count=1, + cpu_utilization_target=50, + ) + ), + labels=_TEST_BIGTABLE_FOS1_LABELS, + ) + create_bigtable_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_BIGTABLE_FOS1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + f"Create FeatureOnlineStore backing LRO: {create_bigtable_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. Resource name: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store = aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create_esf_optimized_store( + create_request_timeout, + create_esf_optimized_fos_mock, + get_esf_optimized_fos_mock, + fos_logger_mock, + sync=True, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore.create_optimized_store( + _TEST_ESF_OPTIMIZED_FOS_ID, + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fos.wait() + + expected_feature_online_store = types.feature_online_store_v1.FeatureOnlineStore( + optimized=types.feature_online_store_v1.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=types.feature_online_store_v1.FeatureOnlineStore.DedicatedServingEndpoint(), + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + ) + create_esf_optimized_fos_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_online_store=expected_feature_online_store, + feature_online_store_id=_TEST_ESF_OPTIMIZED_FOS_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureOnlineStore"), + call( + "Create FeatureOnlineStore backing LRO:" + f" {create_esf_optimized_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore created. 
Resource name:" + " projects/test-project/locations/us-central1/featureOnlineStores/my_esf_optimized_fos" + ), + call("To use this FeatureOnlineStore in another session:"), + call( + "feature_online_store =" + " aiplatform.FeatureOnlineStore('projects/test-project/locations/us-central1/featureOnlineStores/my_esf_optimized_fos')" + ), + ] + ) + + fos_eq( + fos, + name=_TEST_ESF_OPTIMIZED_FOS_ID, + resource_name=_TEST_ESF_OPTIMIZED_FOS_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_ESF_OPTIMIZED_FOS_LABELS, + type=FeatureOnlineStoreType.OPTIMIZED, + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +def test_create_psc_optimized_store( + create_request_timeout, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + with pytest.raises( + ValueError, + match=re.escape("private_service_connect is not supported"), + ): + FeatureOnlineStore.create_optimized_store( + _TEST_PSC_OPTIMIZED_FOS_ID, + labels=_TEST_PSC_OPTIMIZED_FOS_LABELS, + create_request_timeout=create_request_timeout, + enable_private_service_connect=True, + project_allowlist=_TEST_PSC_PROJECT_ALLOWLIST, + ) + + +def test_list(list_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + online_stores = FeatureOnlineStore.list() + + list_fos_mock.assert_called_once_with(request={"parent": _TEST_PARENT}) + assert len(online_stores) == len(_TEST_FOS_LIST) + fos_eq( + online_stores[0], + name=_TEST_BIGTABLE_FOS1_ID, + resource_name=_TEST_BIGTABLE_FOS1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS1_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + fos_eq( + online_stores[1], + name=_TEST_BIGTABLE_FOS2_ID, + resource_name=_TEST_BIGTABLE_FOS2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS2_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + fos_eq( + online_stores[2], + name=_TEST_BIGTABLE_FOS3_ID, + resource_name=_TEST_BIGTABLE_FOS3_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_BIGTABLE_FOS3_LABELS, + type=FeatureOnlineStoreType.BIGTABLE, + ) + + +@pytest.mark.parametrize("force", [True, False]) +def test_delete(force, delete_fos_mock, get_fos_mock, fos_logger_mock, sync=True): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + fos.delete(force=force, sync=sync) + + if not sync: + fos.wait() + + delete_fos_mock.assert_called_once_with( + name=_TEST_BIGTABLE_FOS1_PATH, + force=force, + ) + + fos_logger_mock.assert_has_calls( + [ + call( + "Deleting FeatureOnlineStore resource: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1" + ), + call( + f"Delete FeatureOnlineStore backing LRO: {delete_fos_mock.return_value.operation.name}" + ), + call( + "FeatureOnlineStore resource projects/test-project/locations/us-central1/featureOnlineStores/my_fos1 deleted." 
+ ), + ] + ) + + +def test_create_bq_fv_none_source_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify valid big_query_source."), + ): + fos.create_feature_view_from_big_query("bq_fv", None) + + +def test_create_bq_fv_bad_uri_raises_error(get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in big_query_source."), + ): + fos.create_feature_view_from_big_query( + "bq_fv", + FeatureView.BigQuerySource(uri=None, entity_id_columns=["entity_id"]), + ) + + +@pytest.mark.parametrize("entity_id_columns", [None, []]) +def test_create_bq_fv_bad_entity_id_columns_raises_error( + entity_id_columns, get_fos_mock +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + with pytest.raises( + ValueError, + match=re.escape("Please specify entity ID columns in big_query_source."), + ): + fos.create_feature_view_from_big_query( + "bq_fv", + FeatureView.BigQuerySource(uri="hi", entity_id_columns=entity_id_columns), + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_bq_fv( + create_request_timeout, + sync, + get_fos_mock, + create_bq_fv_mock, + get_fv_mock, + fos_logger_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_BIGTABLE_FOS1_ID) + + fv = fos.create_feature_view_from_big_query( + _TEST_FV1_ID, + FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS + ), + labels=_TEST_FV1_LABELS, + create_request_timeout=create_request_timeout, + ) + + if not sync: + fos.wait() + + # When creating, the FeatureView object doesn't have the path set. + expected_fv = types.feature_view.FeatureView( + big_query_source=types.feature_view.FeatureView.BigQuerySource( + uri=_TEST_FV1_BQ_URI, + entity_id_columns=_TEST_FV1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FV1_LABELS, + ) + create_bq_fv_mock.assert_called_with( + parent=_TEST_BIGTABLE_FOS1_PATH, + feature_view=expected_fv, + feature_view_id=_TEST_FV1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + fos_logger_mock.assert_has_calls( + [ + call("Creating FeatureView"), + call( + f"Create FeatureView backing LRO: {create_bq_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView created. 
Resource name: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1" + ), + call("To use this FeatureView in another session:"), + call( + "feature_view = aiplatform.FeatureView('projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1')" + ), + ] + ) + + +def test_create_embedding_fv( + get_esf_optimized_fos_mock, + create_embedding_fv_from_bq_mock, + get_optimized_embedding_fv_mock, +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + fos = FeatureOnlineStore(_TEST_ESF_OPTIMIZED_FOS_ID) + + embedding_fv = fos.create_feature_view_from_big_query( + _TEST_OPTIMIZED_EMBEDDING_FV_ID, + FeatureView.BigQuerySource(uri="hi", entity_id_columns=["entity_id"]), + index_config=IndexConfig( + embedding_column="embedding", + filter_column=["currency_code", "gender", "shipping_country_codes"], + crowding_column="crowding", + dimentions=1536, + distance_measure_type=DistanceMeasureType.SQUARED_L2_DISTANCE, + algorithm_config=TreeAhConfig(), + ), + ) + fv_eq( + fv_to_check=embedding_fv, + name=_TEST_OPTIMIZED_EMBEDDING_FV_ID, + resource_name=_TEST_OPTIMIZED_EMBEDDING_FV_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) diff --git a/tests/unit/vertexai/test_feature_view.py b/tests/unit/vertexai/test_feature_view.py new file mode 100644 index 0000000000..a5b9b146dd --- /dev/null +++ b/tests/unit/vertexai/test_feature_view.py @@ -0,0 +1,508 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import re +from typing import Dict +from unittest import mock +from unittest.mock import call, patch +from google.api_core import operation as ga_operation + +from google.cloud import aiplatform +from google.cloud.aiplatform import base +from vertexai.resources.preview import ( + FeatureView, +) +import vertexai.resources.preview.feature_store.utils as fv_utils +import pytest +from google.cloud.aiplatform.compat.services import ( + feature_online_store_admin_service_client, + feature_online_store_service_client, +) +from vertexai.resources.preview.feature_store import ( + feature_view, +) + +from feature_store_constants import ( + _TEST_BIGTABLE_FOS1_ID, + _TEST_BIGTABLE_FOS1_PATH, + _TEST_EMBEDDING_FV1_PATH, + _TEST_STRING_FILTER, + _TEST_FV1_ID, + _TEST_FV1_LABELS, + _TEST_FV1_PATH, + _TEST_FV2_ID, + _TEST_FV2_LABELS, + _TEST_FV2_PATH, + _TEST_FV_FETCH1, + _TEST_FV_LIST, + _TEST_FV_SEARCH1, + _TEST_FV_SYNC1, + _TEST_FV_SYNC1_ID, + _TEST_FV_SYNC1_PATH, + _TEST_FV_SYNC2_ID, + _TEST_FV_SYNC2_PATH, + _TEST_FV_SYNC_LIST, + _TEST_LOCATION, + _TEST_OPTIMIZED_FV1_PATH, + _TEST_OPTIMIZED_FV2_PATH, + _TEST_PROJECT, + _TEST_FV_SYNC1_RESPONSE, +) + + +pytestmark = pytest.mark.usefixtures("google_auth_mock") + + +@pytest.fixture +def fv_logger_mock(): + with patch.object( + feature_view._LOGGER, + "info", + wraps=feature_view._LOGGER.info, + ) as logger_mock: + yield logger_mock + + +@pytest.fixture +def list_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "list_feature_views", + ) as list_fv: + list_fv.return_value = _TEST_FV_LIST + yield list_fv + + +@pytest.fixture +def delete_fv_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "delete_feature_view", + ) as delete_fv: + delete_fv_lro_mock = mock.Mock(ga_operation.Operation) + delete_fv.return_value = delete_fv_lro_mock + yield delete_fv + + +@pytest.fixture +def get_fv_sync_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "get_feature_view_sync", + ) as get_fv_sync_mock: + get_fv_sync_mock.return_value = _TEST_FV_SYNC1 + yield get_fv_sync_mock + + +@pytest.fixture +def list_fv_syncs_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "list_feature_view_syncs", + ) as list_fv_syncs_mock: + list_fv_syncs_mock.return_value = _TEST_FV_SYNC_LIST + yield list_fv_syncs_mock + + +@pytest.fixture +def sync_fv_sync_mock(): + with patch.object( + feature_online_store_admin_service_client.FeatureOnlineStoreAdminServiceClient, + "sync_feature_view", + ) as sync_fv_sync_mock: + sync_fv_sync_mock.return_value = _TEST_FV_SYNC1_RESPONSE + yield sync_fv_sync_mock + + +@pytest.fixture +def fetch_feature_values_mock(): + with patch.object( + feature_online_store_service_client.FeatureOnlineStoreServiceClient, + "fetch_feature_values", + ) as fetch_feature_values_mock: + fetch_feature_values_mock.return_value = _TEST_FV_FETCH1 + yield fetch_feature_values_mock + + +@pytest.fixture +def search_nearest_entities_mock(): + with patch.object( + feature_online_store_service_client.FeatureOnlineStoreServiceClient, + "search_nearest_entities", + ) as search_nearest_entities_mock: + search_nearest_entities_mock.return_value = _TEST_FV_SEARCH1 + yield search_nearest_entities_mock + + +def fv_eq( + fv_to_check: FeatureView, + name: str, + resource_name: str, + project: str, + location: str, + labels: Dict[str, str], +): 
+ """Check if a FeatureView has the appropriate values set.""" + assert fv_to_check.name == name + assert fv_to_check.resource_name == resource_name + assert fv_to_check.project == project + assert fv_to_check.location == location + assert fv_to_check.labels == labels + + +def fv_sync_eq( + fv_sync_to_check: FeatureView.FeatureViewSync, + name: str, + resource_name: str, + project: str, + location: str, +): + """Check if a FeatureViewSync has the appropriate values set.""" + assert fv_sync_to_check.name == name + assert fv_sync_to_check.resource_name == resource_name + assert fv_sync_to_check.project == project + assert fv_sync_to_check.location == location + + +def test_init_with_fv_id_and_no_fos_id_raises_error(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape( + "Since feature view is not provided as a path, please specify" + + " feature_online_store_id." + ), + ): + FeatureView(_TEST_FV1_ID) + + +def test_init_with_fv_id(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(_TEST_FV1_ID, feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + +def test_init_with_fv_path(get_fv_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(_TEST_FV1_PATH) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_eq( + fv_to_check=fv, + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + + +def test_list(list_fv_mock, get_fos_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + feature_views = FeatureView.list(feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + + list_fv_mock.assert_called_once_with(request={"parent": _TEST_BIGTABLE_FOS1_PATH}) + assert len(feature_views) == len(_TEST_FV_LIST) + + fv_eq( + feature_views[0], + name=_TEST_FV1_ID, + resource_name=_TEST_FV1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV1_LABELS, + ) + fv_eq( + feature_views[1], + name=_TEST_FV2_ID, + resource_name=_TEST_FV2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FV2_LABELS, + ) + + +def test_delete(delete_fv_mock, fv_logger_mock, get_fos_mock, get_fv_mock, sync=True): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv = FeatureView(name=_TEST_FV1_ID, feature_online_store_id=_TEST_BIGTABLE_FOS1_ID) + fv.delete() + + if not sync: + fv.wait() + + delete_fv_mock.assert_called_once_with(name=_TEST_FV1_PATH) + + fv_logger_mock.assert_has_calls( + [ + call( + "Deleting FeatureView resource: projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1" + ), + call( + f"Delete FeatureView backing LRO: {delete_fv_mock.return_value.operation.name}" + ), + call( + "FeatureView resource projects/test-project/locations/us-central1/featureOnlineStores/my_fos1/featureViews/my_fv1 deleted." 
+ ), + ] + ) + + +def test_get_sync(get_fv_mock, get_fv_sync_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_sync = FeatureView(_TEST_FV1_PATH).get_sync(_TEST_FV_SYNC1_ID) + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + get_fv_sync_mock.assert_called_once_with( + name=_TEST_FV_SYNC1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_sync_eq( + fv_sync_to_check=fv_sync, + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +def test_list_syncs(get_fv_mock, list_fv_syncs_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_syncs = FeatureView(_TEST_FV1_PATH).list_syncs() + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + list_fv_syncs_mock.assert_called_once_with(request={"parent": _TEST_FV1_PATH}) + assert len(fv_syncs) == len(_TEST_FV_SYNC_LIST) + + fv_sync_eq( + fv_sync_to_check=fv_syncs[0], + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + fv_sync_eq( + fv_sync_to_check=fv_syncs[1], + name=_TEST_FV_SYNC2_ID, + resource_name=_TEST_FV_SYNC2_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +def test_on_demand_sync(get_fv_mock, get_fv_sync_mock, sync_fv_sync_mock): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fv_sync = FeatureView(_TEST_FV1_PATH).sync() + + get_fv_mock.assert_called_once_with( + name=_TEST_FV1_PATH, + retry=base._DEFAULT_RETRY, + ) + + sync_fv_sync_mock.assert_called_once_with( + request={"feature_view": _TEST_FV1_PATH}, + ) + + get_fv_sync_mock.assert_called_once_with( + name=_TEST_FV_SYNC1_PATH, + retry=base._DEFAULT_RETRY, + ) + + fv_sync_eq( + fv_sync_to_check=fv_sync, + name=_TEST_FV_SYNC1_ID, + resource_name=_TEST_FV_SYNC1_PATH, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_fetch_feature_values_bigtable( + get_fos_mock, get_fv_mock, fetch_feature_values_mock, fv_logger_mock, output_type +): + if output_type == "dict": + fv_dict = FeatureView(_TEST_FV1_PATH).read(key=["key1"]).to_dict() + assert fv_dict == { + "features": [{"name": "key1", "value": {"string_value": "value1"}}] + } + elif output_type == "proto": + fv_proto = FeatureView(_TEST_FV1_PATH).read(key=["key1"]).to_proto() + assert fv_proto == _TEST_FV_FETCH1 + + fv_logger_mock.assert_has_calls( + [ + call("Connecting to Bigtable online store name my_fos1"), + ] + ) + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_fetch_feature_values_optimized( + get_esf_optimized_fos_mock, + get_optimized_fv_mock, + fetch_feature_values_mock, + fv_logger_mock, + output_type, +): + if output_type == "dict": + fv_dict = FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(key=["key1"]).to_dict() + assert fv_dict == { + "features": [{"name": "key1", "value": {"string_value": "value1"}}] + } + elif output_type == "proto": + fv_proto = FeatureView(_TEST_OPTIMIZED_FV1_PATH).read(key=["key1"]).to_proto() + assert fv_proto == _TEST_FV_FETCH1 + + fv_logger_mock.assert_has_calls( + [ + call( + "Public endpoint for the optimized online store my_esf_optimized_fos is test-esf-endpoint" + ), + ] + ) + + +def test_fetch_feature_values_optimized_no_endpoint( + get_esf_optimized_fos_no_endpoint_mock, + get_optimized_fv_no_endpointmock, + fetch_feature_values_mock, +): + """Tests that the public endpoint is not 
created for the optimized online store.""" + with pytest.raises( + fv_utils.PublicEndpointNotFoundError, + match=re.escape( + "Public endpoint is not created yet for the optimized online " + "store:my_esf_optimised_fos2. Please run sync and wait for it " + "to complete." + ), + ): + FeatureView(_TEST_OPTIMIZED_FV2_PATH).read(key=["key1"]).to_dict() + + +@pytest.mark.parametrize("output_type", ["dict", "proto"]) +def test_search_nearest_entities( + get_esf_optimized_fos_mock, + get_embedding_fv_mock, + search_nearest_entities_mock, + fv_logger_mock, + output_type, +): + if output_type == "dict": + fv_dict = ( + # Test with entity_id input. + FeatureView(_TEST_EMBEDDING_FV1_PATH) + .search( + entity_id="key1", + neighbor_count=2, + string_filters=[_TEST_STRING_FILTER], + per_crowding_attribute_neighbor_count=1, + return_full_entity=True, + approximate_neighbor_candidates=3, + leaf_nodes_search_fraction=0.5, + ) + .to_dict() + ) + assert fv_dict == { + "neighbors": [{"distance": 0.1, "entity_id": "neighbor_entity_id_1"}] + } + elif output_type == "proto": + fv_proto = ( + # Test with embedding_value input. + FeatureView(_TEST_EMBEDDING_FV1_PATH) + .search(embedding_value=[0.1, 0.2, 0.3]) + .to_proto() + ) + assert fv_proto == _TEST_FV_SEARCH1 + + fv_logger_mock.assert_has_calls( + [ + call( + "Public endpoint for the optimized online store my_esf_optimized_fos" + " is test-esf-endpoint" + ), + ] + ) + + +def test_search_nearest_entities_without_entity_id_or_embedding( + get_esf_optimized_fos_mock, + get_embedding_fv_mock, + search_nearest_entities_mock, + fv_logger_mock, +): + try: + FeatureView(_TEST_EMBEDDING_FV1_PATH).search().to_proto() + assert not search_nearest_entities_mock.called + except ValueError as e: + error_msg = ( + "Either entity_id or embedding_value needs to be provided for search." + ) + assert str(e) == error_msg + + +def test_search_nearest_entities_no_endpoint( + get_esf_optimized_fos_no_endpoint_mock, + get_optimized_fv_no_endpointmock, + fetch_feature_values_mock, +): + """Tests that the public endpoint is not created for the optimized online store.""" + try: + FeatureView(_TEST_OPTIMIZED_FV2_PATH).search(entity_id="key1").to_dict() + assert not fetch_feature_values_mock.called + except fv_utils.PublicEndpointNotFoundError as e: + assert isinstance(e, fv_utils.PublicEndpointNotFoundError) + error_msg = ( + "Public endpoint is not created yet for the optimized online " + "store:my_esf_optimised_fos2. Please run sync and wait for it " + "to complete." 
+ ) + assert str(e) == error_msg diff --git a/vertexai/resources/preview/__init__.py b/vertexai/resources/preview/__init__.py index c37347bbdd..c23b8a11f4 100644 --- a/vertexai/resources/preview/__init__.py +++ b/vertexai/resources/preview/__init__.py @@ -37,6 +37,18 @@ PipelineJobSchedule, ) +from vertexai.resources.preview.feature_store import ( + FeatureOnlineStore, + FeatureOnlineStoreType, + FeatureView, + FeatureViewReadResponse, + IndexConfig, + TreeAhConfig, + BruteForceConfig, + DistanceMeasureType, + AlgorithmConfig, +) + __all__ = ( "CustomJob", @@ -48,4 +60,13 @@ "PersistentResource", "EntityType", "PipelineJobSchedule", + "FeatureOnlineStoreType", + "FeatureOnlineStore", + "FeatureView", + "FeatureViewReadResponse", + "IndexConfig", + "TreeAhConfig", + "BruteForceConfig", + "DistanceMeasureType", + "AlgorithmConfig", ) diff --git a/vertexai/resources/preview/feature_store/__init__.py b/vertexai/resources/preview/feature_store/__init__.py new file mode 100644 index 0000000000..ce9724d6e8 --- /dev/null +++ b/vertexai/resources/preview/feature_store/__init__.py @@ -0,0 +1,48 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""The vertexai resources preview module.""" + +from vertexai.resources.preview.feature_store.feature_online_store import ( + FeatureOnlineStore, + FeatureOnlineStoreType, +) + +from vertexai.resources.preview.feature_store.feature_view import ( + FeatureView, +) + +from vertexai.resources.preview.feature_store.utils import ( + FeatureViewReadResponse, + IndexConfig, + TreeAhConfig, + BruteForceConfig, + DistanceMeasureType, + AlgorithmConfig, +) + +__all__ = ( + FeatureOnlineStoreType, + FeatureOnlineStore, + FeatureView, + FeatureViewReadResponse, + IndexConfig, + IndexConfig, + TreeAhConfig, + BruteForceConfig, + DistanceMeasureType, + AlgorithmConfig, +) diff --git a/vertexai/resources/preview/feature_store/feature_online_store.py b/vertexai/resources/preview/feature_store/feature_online_store.py new file mode 100644 index 0000000000..8b62593ef4 --- /dev/null +++ b/vertexai/resources/preview/feature_store/feature_online_store.py @@ -0,0 +1,524 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import enum +from typing import ( + Dict, + Optional, + Sequence, + Tuple, +) + +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import ( + base, + initializer, + utils, +) +from google.cloud.aiplatform.compat.types import ( + feature_online_store as gca_feature_online_store, + feature_view as gca_feature_view, +) +from vertexai.resources.preview.feature_store.feature_view import ( + FeatureView, +) +from vertexai.resources.preview.feature_store.utils import ( + IndexConfig, +) + + +_LOGGER = base.Logger(__name__) + + +@enum.unique +class FeatureOnlineStoreType(enum.Enum): + UNKNOWN = 0 + BIGTABLE = 1 + OPTIMIZED = 2 + + +class FeatureOnlineStore(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature Online Store resources.""" + + client_class = utils.FeatureOnlineStoreAdminClientWithOverride + + _resource_noun = "feature_online_stores" + _getter_method = "get_feature_online_store" + _list_method = "list_feature_online_stores" + _delete_method = "delete_feature_online_store" + _parse_resource_name_method = "parse_feature_online_store_path" + _format_resource_name_method = "feature_online_store_path" + _gca_resource: gca_feature_online_store.FeatureOnlineStore + + def __init__( + self, + name: str, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature online store. + + Args: + name: + The resource name + (`projects/.../locations/.../featureOnlineStores/...`) or ID. + project: + Project to retrieve feature online store from. If unset, the + project set in aiplatform.init will be used. + location: + Location to retrieve feature online store from. If not set, + location set in aiplatform.init will be used. + credentials: + Custom credentials to use to retrieve this feature online store. + Overrides credentials set in aiplatform.init. + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + self._gca_resource = self._get_gca_resource(resource_name=name) + + @classmethod + @base.optional_sync() + def create_bigtable_store( + cls, + name: str, + min_node_count: Optional[int] = 1, + max_node_count: Optional[int] = 1, + cpu_utilization_target: Optional[int] = 50, + labels: Optional[Dict[str, str]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureOnlineStore": + """Creates a Bigtable online store. + + Example Usage: + + my_fos = vertexai.preview.FeatureOnlineStore.create_bigtable_store('my_fos') + + Args: + name: The name of the feature online store. + min_node_count: + The minimum number of Bigtable nodes to scale down to. Must be + greater than or equal to 1. + max_node_count: + The maximum number of Bigtable nodes to scale up to. Must + satisfy min_node_count <= max_node_count <= (10 * + min_node_count). + cpu_utilization_target: + A percentage of the cluster's CPU capacity. Can be from 10% to + 80%. When a cluster's CPU utilization exceeds the target that + you have set, Bigtable immediately adds nodes to the cluster. + When CPU utilization is substantially lower than the target, + Bigtable removes nodes. If not set will default to 50%. 
+ labels: + The labels with user-defined metadata to organize your feature + online store. Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information on and examples of labels. No more than 64 + user labels can be associated with one feature online store + (System labels are excluded)." System reserved label keys are + prefixed with "aiplatform.googleapis.com/" and are immutable. + project: + Project to create feature online store in. If unset, the project + set in aiplatform.init will be used. + location: + Location to create feature online store in. If not set, location + set in aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature online store. + Overrides credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + FeatureOnlineStore - the FeatureOnlineStore resource object. + """ + + if min_node_count < 1: + raise ValueError("min_node_count must be greater than or equal to 1") + + if max_node_count < min_node_count: + raise ValueError( + "max_node_count must be greater than or equal to min_node_count" + ) + elif 10 * min_node_count < max_node_count: + raise ValueError( + "max_node_count must be less than or equal to 10 * min_node_count" + ) + + if cpu_utilization_target < 10 or cpu_utilization_target > 80: + raise ValueError("cpu_utilization_target must be between 10 and 80") + + gapic_feature_online_store = gca_feature_online_store.FeatureOnlineStore( + bigtable=gca_feature_online_store.FeatureOnlineStore.Bigtable( + auto_scaling=gca_feature_online_store.FeatureOnlineStore.Bigtable.AutoScaling( + min_node_count=min_node_count, + max_node_count=max_node_count, + cpu_utilization_target=cpu_utilization_target, + ), + ), + ) + + if labels: + utils.validate_labels(labels) + gapic_feature_online_store.labels = labels + + if request_metadata is None: + request_metadata = () + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + create_online_store_lro = api_client.create_feature_online_store( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + feature_online_store=gapic_feature_online_store, + feature_online_store_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(cls, create_online_store_lro) + + created_online_store = create_online_store_lro.result() + + _LOGGER.log_create_complete(cls, created_online_store, "feature_online_store") + + online_store_obj = cls( + name=created_online_store.name, + project=project, + location=location, + credentials=credentials, + ) + + return online_store_obj + + @classmethod + @base.optional_sync() + def create_optimized_store( + cls, + name: str, + enable_private_service_connect: bool = False, + project_allowlist: Optional[Sequence[str]] = None, + labels: Optional[Dict[str, str]] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + 
request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureOnlineStore": + """Creates an Optimized online store. + + Example Usage: + + my_fos = vertexai.preview.FeatureOnlineStore.create_optimized_store('my_fos') + + Args: + name: The name of the feature online store. + enable_private_service_connect (bool): + Optional. If true, expose the optimized online store + via private service connect. Otherwise the optimized online + store will be accessible through public endpoint + project_allowlist (MutableSequence[str]): + A list of Projects from which the forwarding + rule will target the service attachment. Only needed when + enable_private_service_connect is set to true. + labels: + The labels with user-defined metadata to organize your feature + online store. Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only contain lowercase + letters, numeric characters, underscores and dashes. + International characters are allowed. See https://goo.gl/xmQnxf + for more information on and examples of labels. No more than 64 + user labels can be associated with one feature online store + (System labels are excluded)." System reserved label keys are + prefixed with "aiplatform.googleapis.com/" and are immutable. + project: + Project to create feature online store in. If unset, the project + set in aiplatform.init will be used. + location: + Location to create feature online store in. If not set, location + set in aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature online store. + Overrides credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + FeatureOnlineStore - the FeatureOnlineStore resource object. 
+ """ + if enable_private_service_connect: + raise ValueError("private_service_connect is not supported") + else: + dedicated_serving_endpoint = ( + gca_feature_online_store.FeatureOnlineStore.DedicatedServingEndpoint() + ) + + gapic_feature_online_store = gca_feature_online_store.FeatureOnlineStore( + optimized=gca_feature_online_store.FeatureOnlineStore.Optimized(), + dedicated_serving_endpoint=dedicated_serving_endpoint, + ) + + if labels: + utils.validate_labels(labels) + gapic_feature_online_store.labels = labels + + if request_metadata is None: + request_metadata = () + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + create_online_store_lro = api_client.create_feature_online_store( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + feature_online_store=gapic_feature_online_store, + feature_online_store_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(cls, create_online_store_lro) + + created_online_store = create_online_store_lro.result() + + _LOGGER.log_create_complete(cls, created_online_store, "feature_online_store") + + online_store_obj = cls( + name=created_online_store.name, + project=project, + location=location, + credentials=credentials, + ) + + return online_store_obj + + @base.optional_sync() + def delete(self, force: bool = False, sync: bool = True) -> None: + """Deletes this online store. + + WARNING: This deletion is permanent. + + Args: + force: + If set to True, all feature views under this online store will + be deleted prior to online store deletion. Otherwise, deletion + will only succeed if the online store has no FeatureViews. + sync: + Whether to execute this deletion synchronously. If False, this + method will be executed in concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + """ + + lro = getattr(self.api_client, self._delete_method)( + name=self.resource_name, + force=force, + ) + _LOGGER.log_delete_with_lro(self, lro) + lro.result() + _LOGGER.log_delete_complete(self) + + @property + def feature_online_store_type(self) -> FeatureOnlineStoreType: + if self._gca_resource.bigtable: + return FeatureOnlineStoreType.BIGTABLE + # Optimized is an empty proto, so self._gca_resource.optimized is always false. + elif hasattr(self.gca_resource, "optimized"): + return FeatureOnlineStoreType.OPTIMIZED + else: + raise ValueError( + f"Online store does not have type or is unsupported by SDK: {self._gca_resource}." + ) + + @property + def labels(self) -> Dict[str, str]: + return self._gca_resource.labels + + @base.optional_sync() + def create_feature_view_from_big_query( + self, + name: str, + big_query_source: FeatureView.BigQuerySource, + labels: Optional[Dict[str, str]] = None, + sync_config: Optional[str] = None, + index_config: Optional[IndexConfig] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureView": + """Creates a FeatureView from a BigQuery source. 
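A short illustrative sketch of the `delete()` method and `feature_online_store_type` property defined above. The store ID is a placeholder, and the imports assume the package-level re-exports added elsewhere in this patch.

```python
from vertexai.resources.preview import (  # assumed re-exports
    FeatureOnlineStore,
    FeatureOnlineStoreType,
)

fos = FeatureOnlineStore("my_fos")

# Branch on the serving backend before wiring up serving code.
if fos.feature_online_store_type == FeatureOnlineStoreType.BIGTABLE:
    print("Bigtable-backed store with labels:", fos.labels)

# force=True also deletes any feature views still attached to this store.
fos.delete(force=True)
```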
+
+        Example Usage:
+        ```
+        existing_fos = FeatureOnlineStore('my_fos')
+        new_fv = existing_fos.create_feature_view_from_big_query(
+            'my_fv',
+            BigQuerySource(
+                uri='bq://my-proj/dataset/table',
+                entity_id_columns=['entity_id'],
+            )
+        )
+        # Example of creating an embedding FeatureView.
+        embedding_fv = existing_fos.create_feature_view_from_big_query(
+            'my_embedding_fv',
+            BigQuerySource(
+                uri='bq://my-proj/dataset/table',
+                entity_id_columns=['entity_id'],
+            ),
+            index_config=IndexConfig(
+                embedding_column="embedding",
+                filter_column=["currency_code", "gender"],
+                crowding_column="crowding",
+                dimentions=1536,
+                distance_measure_type=DistanceMeasureType.SQUARED_L2_DISTANCE,
+                algorithm_config=TreeAhConfig(),
+            ),
+        )
+        ```
+        Args:
+            name: The name of the feature view.
+            big_query_source:
+                The BigQuery source to load data from when a feature view sync
+                runs.
+            labels:
+                The labels with user-defined metadata to organize your
+                FeatureViews.
+
+                Label keys and values can be no longer than 64 characters
+                (Unicode codepoints), can only contain lowercase letters,
+                numeric characters, underscores and dashes. International
+                characters are allowed.
+
+                See https://goo.gl/xmQnxf for more information on and examples
+                of labels. No more than 64 user labels can be associated with
+                one FeatureOnlineStore (system labels are excluded). System
+                reserved label keys are prefixed with
+                "aiplatform.googleapis.com/" and are immutable.
+            sync_config:
+                Configures when data is to be synced/updated for this
+                FeatureView. At the end of the sync the latest feature values
+                for each entity ID of this FeatureView are made ready for online
+                serving. Example format: "TZ=America/New_York 0 9 * * *" (sync
+                daily at 9 AM EST).
+            index_config:
+                Optional. Configuration of the embedding index used for vector
+                similarity search on this FeatureView.
+            project:
+                Project to create feature view in. If unset, the project set in
+                aiplatform.init will be used.
+            location:
+                Location to create feature view in. If not set, location set in
+                aiplatform.init will be used.
+            credentials:
+                Custom credentials to use to create this feature view.
+                Overrides credentials set in aiplatform.init.
+            request_metadata:
+                Strings which should be sent along with the request as metadata.
+            create_request_timeout:
+                The timeout for the create request in seconds.
+            sync:
+                Whether to execute this creation synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+
+        Returns:
+            FeatureView - the FeatureView resource object.
+ """ + if not big_query_source: + raise ValueError("Please specify valid big_query_source.") + elif not big_query_source.uri: + raise ValueError("Please specify URI in big_query_source.") + elif not big_query_source.entity_id_columns: + raise ValueError("Please specify entity ID columns in big_query_source.") + + gapic_feature_view = gca_feature_view.FeatureView( + big_query_source=gca_feature_view.FeatureView.BigQuerySource( + uri=big_query_source.uri, + entity_id_columns=big_query_source.entity_id_columns, + ), + sync_config=gca_feature_view.FeatureView.SyncConfig(cron=sync_config) + if sync_config + else None, + ) + + if labels: + utils.validate_labels(labels) + gapic_feature_view.labels = labels + + if request_metadata is None: + request_metadata = () + + if index_config: + gapic_feature_view.index_config = gca_feature_view.FeatureView.IndexConfig( + index_config.as_dict() + ) + + api_client = self.__class__._instantiate_client( + location=location, credentials=credentials + ) + + create_feature_view_lro = api_client.create_feature_view( + parent=self.resource_name, + feature_view=gapic_feature_view, + feature_view_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(FeatureView, create_feature_view_lro) + + created_feature_view = create_feature_view_lro.result() + + _LOGGER.log_create_complete(FeatureView, created_feature_view, "feature_view") + + feature_view_obj = FeatureView( + name=created_feature_view.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_view_obj diff --git a/vertexai/resources/preview/feature_store/feature_view.py b/vertexai/resources/preview/feature_store/feature_view.py new file mode 100644 index 0000000000..a8475c396b --- /dev/null +++ b/vertexai/resources/preview/feature_store/feature_view.py @@ -0,0 +1,429 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
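Before moving on to feature_view.py, a hedged usage sketch of `create_feature_view_from_big_query` as implemented above. The store and view IDs, BigQuery URI, and cron schedule are placeholders, and the imports assume the package-level re-exports added elsewhere in this patch.

```python
from vertexai.resources.preview import FeatureOnlineStore, FeatureView  # assumed re-exports

fos = FeatureOnlineStore("my_fos")

fv = fos.create_feature_view_from_big_query(
    name="my_fv",
    big_query_source=FeatureView.BigQuerySource(
        uri="bq://my-project.my_dataset.my_table",
        entity_id_columns=["entity_id"],
    ),
    # Cron schedule wrapped into FeatureView.SyncConfig: sync daily at 9 AM
    # Eastern time.
    sync_config="TZ=America/New_York 0 9 * * *",
)
```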
+# + +from dataclasses import dataclass +import re +from typing import List, Optional +from google.cloud.aiplatform import initializer +from google.auth import credentials as auth_credentials +from google.cloud.aiplatform import base +from google.cloud.aiplatform import utils +from google.cloud.aiplatform.compat.types import ( + feature_view_sync as gca_feature_view_sync, + feature_view as gca_feature_view, + feature_online_store_service as fos_service, +) +import vertexai.resources.preview.feature_store.utils as fv_utils + +_LOGGER = base.Logger(__name__) + + +class FeatureView(base.VertexAiResourceNounWithFutureManager): + """Class for managing Feature View resources.""" + + client_class = utils.FeatureOnlineStoreAdminClientWithOverride + + _resource_noun = "featureViews" + _getter_method = "get_feature_view" + _list_method = "list_feature_views" + _delete_method = "delete_feature_view" + _parse_resource_name_method = "parse_feature_view_path" + _format_resource_name_method = "feature_view_path" + _gca_resource: gca_feature_view.FeatureView + _online_store_client: utils.FeatureOnlineStoreClientWithOverride + + def __init__( + self, + name: str, + feature_online_store_id: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + ): + """Retrieves an existing managed feature view. + + Args: + name: + The resource name + (`projects/.../locations/.../featureOnlineStores/.../featureViews/...`) + or ID. + feature_online_store_id: + The feature online store ID. Must be passed in if name is an ID + and not a resource path. + project: + Project to retrieve the feature view from. If unset, the project + set in aiplatform.init will be used. + location: + Location to retrieve the feature view from. If not set, location + set in aiplatform.init will be used. + credentials: + Custom credentials to use to retrieve this feature view. + Overrides credentials set in aiplatform.init. + """ + + super().__init__( + project=project, + location=location, + credentials=credentials, + resource_name=name, + ) + + if re.fullmatch( + r"projects/.+/locations/.+/featureOnlineStores/.+/featureViews/.+", + name, + ): + feature_view = name + else: + from .feature_online_store import FeatureOnlineStore + + # Construct the feature view path using feature online store ID if + # only the feature view ID is provided. + if not feature_online_store_id: + raise ValueError( + "Since feature view is not provided as a path, please specify" + + " feature_online_store_id." + ) + + feature_online_store_path = utils.full_resource_name( + resource_name=feature_online_store_id, + resource_noun=FeatureOnlineStore._resource_noun, + parse_resource_name_method=FeatureOnlineStore._parse_resource_name, + format_resource_name_method=FeatureOnlineStore._format_resource_name, + ) + + feature_view = f"{feature_online_store_path}/featureViews/{name}" + + self._gca_resource = self._get_gca_resource(resource_name=feature_view) + + @property + def _get_online_store_client(self) -> utils.FeatureOnlineStoreClientWithOverride: + if getattr(self, "_online_store_client", None): + return self._online_store_client + + fos_name = fv_utils.get_feature_online_store_name(self.resource_name) + from .feature_online_store import FeatureOnlineStore + + fos = FeatureOnlineStore(name=fos_name) + + if fos._gca_resource.bigtable.auto_scaling: + # This is Bigtable online store. 
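+            # Bigtable-backed stores are served through the regional
+            # FeatureOnlineStoreService endpoint, so the client below needs no
+            # api_path_override (unlike the optimized-store path further down).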
+            _LOGGER.info(f"Connecting to Bigtable online store {fos_name}.")
+            self._online_store_client = initializer.global_config.create_client(
+                client_class=utils.FeatureOnlineStoreClientWithOverride,
+                credentials=self.credentials,
+                location_override=self.location,
+            )
+            return self._online_store_client
+
+        # From here, optimized serving.
+        if not fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name:
+            raise fv_utils.PublicEndpointNotFoundError(
+                "Public endpoint is not created yet for the optimized online store: "
+                f"{fos_name}. Please run sync and wait for it to complete."
+            )
+
+        _LOGGER.info(
+            f"Public endpoint for the optimized online store {fos_name} is"
+            f" {fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name}"
+        )
+        self._online_store_client = initializer.global_config.create_client(
+            client_class=utils.FeatureOnlineStoreClientWithOverride,
+            credentials=self.credentials,
+            location_override=self.location,
+            prediction_client=True,
+            api_path_override=fos._gca_resource.dedicated_serving_endpoint.public_endpoint_domain_name,
+        )
+        return self._online_store_client
+
+    @classmethod
+    def list(
+        cls,
+        feature_online_store_id: str,
+        filter: Optional[str] = None,
+        credentials: Optional[auth_credentials.Credentials] = None,
+    ) -> List["FeatureView"]:
+        """List all feature views under feature_online_store_id.
+
+        Example Usage:
+        ```
+        feature_views = vertexai.preview.FeatureView.list(
+            feature_online_store_id="my_fos",
+            filter='labels.label_key=label_value')
+        ```
+        Args:
+            feature_online_store_id:
+                Parent feature online store ID.
+            filter:
+                Filter to apply on the returned feature views.
+            credentials:
+                Custom credentials to use to get a list of feature views.
+                Overrides credentials set in aiplatform.init.
+
+        Returns:
+            List[FeatureView] - list of FeatureView resource objects.
+        """
+        from .feature_online_store import FeatureOnlineStore
+
+        fos = FeatureOnlineStore(name=feature_online_store_id)
+        return cls._list(
+            filter=filter, credentials=credentials, parent=fos.resource_name
+        )
+
+    @base.optional_sync()
+    def delete(self, sync: bool = True) -> None:
+        """Deletes this feature view.
+
+        WARNING: This deletion is permanent.
+
+        Args:
+            sync:
+                Whether to execute this deletion synchronously. If False, this
+                method will be executed in concurrent Future and any downstream
+                object will be immediately returned and synced when the Future
+                has completed.
+        """
+        lro = getattr(self.api_client, self._delete_method)(name=self.resource_name)
+        _LOGGER.log_delete_with_lro(self, lro)
+        lro.result()
+        _LOGGER.log_delete_complete(self)
+
+    def sync(self) -> "FeatureViewSync":
+        """Starts an on-demand sync for this FeatureView.
+
+        Returns:
+            "FeatureViewSync" - FeatureViewSync instance
+        """
+        sync_method = getattr(self.api_client, self.FeatureViewSync.sync_method())
+
+        sync_request = {
+            "feature_view": self.resource_name,
+        }
+        sync_response = sync_method(request=sync_request)
+
+        return self.FeatureViewSync(name=sync_response.feature_view_sync)
+
+    def get_sync(self, name) -> "FeatureViewSync":
+        """Gets the FeatureViewSync resource for the given name.
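As an aside, a minimal sketch of the on-demand sync flow provided by `sync()` and `get_sync()` above. The IDs are placeholders, and the import assumes the package-level re-export added elsewhere in this patch.

```python
from vertexai.resources.preview import FeatureView  # assumed re-export

fv = FeatureView("my_fv", feature_online_store_id="my_fos")

# Kick off an on-demand sync; the returned handle wraps the new
# FeatureViewSync resource.
fv_sync = fv.sync()

# The same sync can be retrieved later by its resource ID.
sync_id = fv_sync.resource_name.split("/")[-1]
same_sync = fv.get_sync(sync_id)
```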
+
+        Args:
+            name: The resource ID of the FeatureViewSync.
+
+        Returns:
+            "FeatureViewSync" - FeatureViewSync instance
+        """
+        feature_view_path = self.resource_name
+        feature_view_sync = f"{feature_view_path}/featureViewSyncs/{name}"
+        return self.FeatureViewSync(name=feature_view_sync)
+
+    def list_syncs(
+        self,
+        filter: Optional[str] = None,
+    ) -> List["FeatureViewSync"]:
+        """List all feature view syncs under this FeatureView.
+
+        Args:
+            filter: Filter to apply on the returned feature view syncs.
+
+        Returns:
+            List[FeatureViewSync] - list of FeatureViewSync resource objects.
+        """
+
+        return self.FeatureViewSync._list(
+            filter=filter, credentials=self.credentials, parent=self.resource_name
+        )
+
+    def read(
+        self,
+        key: List[str],
+        request_timeout: Optional[float] = None,
+    ) -> fv_utils.FeatureViewReadResponse:
+        """Read the feature values from this FeatureView.
+
+        Example Usage:
+        ```
+        data = (
+            vertexai.preview.FeatureView(
+                name='feature_view_name', feature_online_store_id='fos_name')
+            .read(key=['12345', '6789'])
+            .to_dict()
+        )
+        ```
+        Args:
+            key: The request key to read feature values for.
+            request_timeout: The timeout for the read request in seconds.
+
+        Returns:
+            "FeatureViewReadResponse" - FeatureViewReadResponse object. It is an
+            intermediate class that can be further converted by to_dict() or
+            to_proto().
+        """
+        self.wait()
+        response = self._get_online_store_client.fetch_feature_values(
+            feature_view=self.resource_name,
+            data_key=fos_service.FeatureViewDataKey(
+                composite_key=fos_service.FeatureViewDataKey.CompositeKey(parts=key)
+            ),
+            timeout=request_timeout,
+        )
+        return fv_utils.FeatureViewReadResponse(response)
+
+    def search(
+        self,
+        entity_id: Optional[str] = None,
+        embedding_value: Optional[List[float]] = None,
+        neighbor_count: Optional[int] = None,
+        string_filters: Optional[
+            List[fos_service.NearestNeighborQuery.StringFilter]
+        ] = None,
+        per_crowding_attribute_neighbor_count: Optional[int] = None,
+        return_full_entity: bool = False,
+        approximate_neighbor_candidates: Optional[int] = None,
+        leaf_nodes_search_fraction: Optional[float] = None,
+        request_timeout: Optional[float] = None,
+    ) -> fv_utils.SearchNearestEntitiesResponse:
+        """Search for the nearest entities in this FeatureView.
+
+        Example Usage:
+        ```
+        data = (
+            vertexai.preview.FeatureView(
+                name='feature_view_name', feature_online_store_id='fos_name')
+            .search(entity_id='sample_entity')
+            .to_dict()
+        )
+        ```
+        Args:
+            entity_id: The entity ID whose similar entities should be searched
+              for.
+            embedding_value: The embedding vector to be used for similarity
+              search.
+            neighbor_count: The number of similar entities to be retrieved
+              from the feature view for each query.
+            string_filters: The list of string filters.
+            per_crowding_attribute_neighbor_count: Crowding is a constraint on a
+              neighbor list produced by nearest neighbor search requiring that
+              no more than per_crowding_attribute_neighbor_count of the k
+              neighbors returned have the same value of crowding_attribute.
+              It's used for improving result diversity.
+            return_full_entity: If true, return full entities including the
+              features other than embeddings.
+            approximate_neighbor_candidates: The number of neighbors to find via
+              approximate search before exact reordering is performed; if set,
+              this value must be > neighbor_count.
+            leaf_nodes_search_fraction: The fraction of leaf nodes to search,
+              set at query time to let the user tune search performance.
+              Increasing this value increases both search accuracy and latency.
+              The value should be between 0.0 and 1.0.
+            request_timeout: The timeout for the search request in seconds.
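A hedged sketch of an embedding-based query using the parameters documented above. The vector, IDs, and tuning values are placeholders, and the import assumes the package-level re-export added elsewhere in this patch.

```python
from vertexai.resources.preview import FeatureView  # assumed re-export

fv = FeatureView("my_embedding_fv", feature_online_store_id="my_fos")

neighbors = fv.search(
    embedding_value=[0.1] * 1536,  # query vector; must match the index dimension
    neighbor_count=10,
    return_full_entity=True,       # also return the non-embedding features
    leaf_nodes_search_fraction=0.05,
).to_dict()
```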
+
+        Returns:
+            "SearchNearestEntitiesResponse" - SearchNearestEntitiesResponse
+            object. It is an intermediate class that can be further converted
+            by to_dict() or to_proto().
+        """
+        self.wait()
+        if entity_id:
+            embedding = None
+        elif embedding_value:
+            embedding = fos_service.NearestNeighborQuery.Embedding(
+                value=embedding_value
+            )
+        else:
+            raise ValueError(
+                "Either entity_id or embedding_value needs to be provided for"
+                " search."
+            )
+        response = self._get_online_store_client.search_nearest_entities(
+            request=fos_service.SearchNearestEntitiesRequest(
+                feature_view=self.resource_name,
+                query=fos_service.NearestNeighborQuery(
+                    entity_id=entity_id,
+                    embedding=embedding,
+                    neighbor_count=neighbor_count,
+                    string_filters=string_filters,
+                    per_crowding_attribute_neighbor_count=per_crowding_attribute_neighbor_count,  # pylint: disable=line-too-long
+                    parameters=fos_service.NearestNeighborQuery.Parameters(
+                        approximate_neighbor_candidates=approximate_neighbor_candidates,
+                        leaf_nodes_search_fraction=leaf_nodes_search_fraction,
+                    ),
+                ),
+                return_full_entity=return_full_entity,
+            ),
+            timeout=request_timeout,
+        )
+        return fv_utils.SearchNearestEntitiesResponse(response)
+
+    @dataclass
+    class BigQuerySource:
+        uri: str
+        entity_id_columns: List[str]
+
+    class FeatureViewSync(base.VertexAiResourceNounWithFutureManager):
+        """Class for managing Feature View Sync resources."""
+
+        client_class = utils.FeatureOnlineStoreAdminClientWithOverride
+
+        _resource_noun = "featureViewSyncs"
+        _getter_method = "get_feature_view_sync"
+        _list_method = "list_feature_view_syncs"
+        _delete_method = "delete_feature_view"
+        _sync_method = "sync_feature_view"
+        _parse_resource_name_method = "parse_feature_view_sync_path"
+        _format_resource_name_method = "feature_view_sync_path"
+        _gca_resource: gca_feature_view_sync.FeatureViewSync
+
+        def __init__(
+            self,
+            name: str,
+            project: Optional[str] = None,
+            location: Optional[str] = None,
+            credentials: Optional[auth_credentials.Credentials] = None,
+        ):
+            """Retrieves an existing managed feature view sync.
+
+            Args:
+                name: The resource name
+                  (`projects/.../locations/.../featureOnlineStores/.../featureViews/.../featureViewSyncs/...`)
+                project: Project to retrieve the feature view sync from. If
+                  unset, the project set in aiplatform.init will be used.
+                location: Location to retrieve the feature view sync from. If
+                  not set, location set in aiplatform.init will be used.
+                credentials: Custom credentials to use to retrieve this feature
+                  view sync. Overrides credentials set in aiplatform.init.
+            """
+            super().__init__(
+                project=project,
+                location=location,
+                credentials=credentials,
+                resource_name=name,
+            )
+
+            if not re.fullmatch(
+                r"projects/.+/locations/.+/featureOnlineStores/.+/featureViews/.+/featureViewSyncs/.+",
+                name,
+            ):
+                raise ValueError(
+                    "name must specify the fully qualified"
+                    + " feature_view_sync resource path."
+ ) + + self._gca_resource = getattr(self.api_client, self._getter_method)( + name=name, retry=base._DEFAULT_RETRY + ) + + @classmethod + def sync_method(cls) -> str: + """Returns the sync method.""" + return cls._sync_method diff --git a/vertexai/resources/preview/feature_store/utils.py b/vertexai/resources/preview/feature_store/utils.py new file mode 100644 index 0000000000..f9090d3854 --- /dev/null +++ b/vertexai/resources/preview/feature_store/utils.py @@ -0,0 +1,156 @@ +# -*- coding: utf-8 -*- + +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import abc +from dataclasses import dataclass +import enum +import proto +from typing_extensions import override +from typing import Any, Dict, List, Optional +from google.cloud.aiplatform.compat.types import ( + feature_online_store_service as fos_service, +) + + +def get_feature_online_store_name(online_store_name: str) -> str: + """Extract Feature Online Store's name from FeatureView's full resource name. + + Args: + online_store_name: Full resource name is projects/project_number/ + locations/us-central1/featureOnlineStores/fos_name/featureViews/fv_name + + Returns: + str: feature online store name. + """ + arr = online_store_name.split("/") + return arr[5] + + +class PublicEndpointNotFoundError(RuntimeError): + """Public endpoint has not been created yet.""" + + +@dataclass +class FeatureViewReadResponse: + _response: fos_service.FetchFeatureValuesResponse + + def __init__(self, response: fos_service.FetchFeatureValuesResponse): + self._response = response + + def to_dict(self) -> Dict[str, Any]: + return proto.Message.to_dict(self._response.key_values) + + def to_proto(self) -> fos_service.FetchFeatureValuesResponse: + return self._response + + +@dataclass +class SearchNearestEntitiesResponse: + _response: fos_service.SearchNearestEntitiesResponse + + def __init__(self, response: fos_service.SearchNearestEntitiesResponse): + self._response = response + + def to_dict(self) -> Dict[str, Any]: + return proto.Message.to_dict(self._response.nearest_neighbors) + + def to_proto(self) -> fos_service.SearchNearestEntitiesResponse: + return self._response + + +class DistanceMeasureType(enum.Enum): + """The distance measure used in nearest neighbor search.""" + + DISTANCE_MEASURE_TYPE_UNSPECIFIED = 0 + # Euclidean (L_2) Distance. + SQUARED_L2_DISTANCE = 1 + # Cosine Distance. Defined as 1 - cosine similarity. + COSINE_DISTANCE = 2 + # Dot Product Distance. Defined as a negative of the dot product. + DOT_PRODUCT_DISTANCE = 3 + + +class AlgorithmConfig(abc.ABC): + """Base class for configuration options for matching algorithm.""" + + def as_dict(self) -> Dict: + """Returns the configuration as a dictionary. + + Returns: + Dict[str, Any] + """ + pass + + +@dataclass +class TreeAhConfig(AlgorithmConfig): + """Configuration options for using the tree-AH algorithm (Shallow tree + Asymmetric Hashing). 
+ Please refer to this paper for more details: https://arxiv.org/abs/1908.10396 + + Args: + leaf_node_embedding_count (int): + Optional. Number of embeddings on each leaf node. The default value is 1000 if not set. + """ + + leaf_node_embedding_count: Optional[int] = None + + @override + def as_dict(self) -> Dict: + return {"leaf_node_embedding_count": self.leaf_node_embedding_count} + + +@dataclass +class BruteForceConfig(AlgorithmConfig): + """Configuration options for using brute force search. + It simply implements the standard linear search in the database for + each query. + """ + + @override + def as_dict(self) -> Dict[str, Any]: + return {"bruteForceConfig": {}} + + +@dataclass +class IndexConfig: + """Configuration options for the Vertex FeatureView for embedding.""" + + embedding_column: str + filter_column: List[str] + crowding_column: str + dimentions: Optional[int] + distance_measure_type: DistanceMeasureType + algorithm_config: AlgorithmConfig + + def as_dict(self) -> Dict[str, Any]: + """Returns the configuration as a dictionary. + + Returns: + Dict[str, Any] + """ + config = { + "embedding_column": self.embedding_column, + "filter_columns": self.filter_column, + "crowding_column": self.crowding_column, + "embedding_dimension": self.dimentions, + "distance_measure_type": self.distance_measure_type.value, + } + if isinstance(self.algorithm_config, TreeAhConfig): + config["tree_ah_config"] = self.algorithm_config.as_dict() + else: + config["brute_force_config"] = self.algorithm_config.as_dict() + return config
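To make the mapping performed by `IndexConfig.as_dict()` concrete, a small sketch that imports directly from the new `utils` module. The column names and dimension are placeholders, and `dimentions` follows the field spelling in the dataclass above.

```python
from vertexai.resources.preview.feature_store.utils import (
    DistanceMeasureType,
    IndexConfig,
    TreeAhConfig,
)

index_config = IndexConfig(
    embedding_column="embedding",
    filter_column=["category"],
    crowding_column="crowding",
    dimentions=1536,  # spelled as in the dataclass definition above
    distance_measure_type=DistanceMeasureType.DOT_PRODUCT_DISTANCE,
    algorithm_config=TreeAhConfig(leaf_node_embedding_count=1000),
)

print(index_config.as_dict())
# {'embedding_column': 'embedding', 'filter_columns': ['category'],
#  'crowding_column': 'crowding', 'embedding_dimension': 1536,
#  'distance_measure_type': 3, 'tree_ah_config': {'leaf_node_embedding_count': 1000}}
```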