From dcb62051a7f3aeaa009b64165569c788d8c5ec44 Mon Sep 17 00:00:00 2001 From: Lingyin Wu Date: Wed, 8 Nov 2023 11:05:22 -0800 Subject: [PATCH] feat: add `index_update_method` to MatchingEngineIndex `create()` PiperOrigin-RevId: 580589542 --- .../matching_engine/matching_engine_index.py | 45 ++++++++++++------- .../aiplatform/test_matching_engine_index.py | 44 +++++++++++++++++- 2 files changed, 72 insertions(+), 17 deletions(-) diff --git a/google/cloud/aiplatform/matching_engine/matching_engine_index.py b/google/cloud/aiplatform/matching_engine/matching_engine_index.py index e320e6fc8b..ce11dbea99 100644 --- a/google/cloud/aiplatform/matching_engine/matching_engine_index.py +++ b/google/cloud/aiplatform/matching_engine/matching_engine_index.py @@ -108,6 +108,7 @@ def _create( credentials: Optional[auth_credentials.Credentials] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), sync: bool = True, + index_update_method: Optional[str] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource. @@ -153,20 +154,25 @@ def _create( credentials set in aiplatform.init. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. - encryption_spec (str): - Optional. Customer-managed encryption key - spec for data storage. If set, both of the - online and offline data storage will be secured - by this key. sync (bool): Optional. Whether to execute this creation synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + index_update_method (str): + Optional. The update method to use with this index. Choose + stream_update or batch_update. If not set, batch update will be + used by default. Returns: MatchingEngineIndex - Index resource object """ + index_update_method_enum = None + if index_update_method in _INDEX_UPDATE_METHOD_TO_ENUM_VALUE: + index_update_method_enum = _INDEX_UPDATE_METHOD_TO_ENUM_VALUE[ + index_update_method + ] + gapic_index = gca_matching_engine_index.Index( display_name=display_name, description=description, @@ -174,6 +180,7 @@ def _create( "config": config.as_dict(), "contentsDeltaUri": contents_delta_uri, }, + index_update_method=index_update_method_enum, ) if labels: @@ -386,6 +393,7 @@ def create_tree_ah_index( credentials: Optional[auth_credentials.Credentials] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), sync: bool = True, + index_update_method: Optional[str] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource that uses the tree-AH algorithm. @@ -456,15 +464,14 @@ def create_tree_ah_index( credentials set in aiplatform.init. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. - encryption_spec (str): - Optional. Customer-managed encryption key - spec for data storage. If set, both of the - online and offline data storage will be secured - by this key. sync (bool): Optional. Whether to execute this creation synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + index_update_method (str): + Optional. The update method to use with this index. Choose + STREAM_UPDATE or BATCH_UPDATE. If not set, batch update will be + used by default. Returns: MatchingEngineIndex - Index resource object @@ -494,6 +501,7 @@ def create_tree_ah_index( credentials=credentials, request_metadata=request_metadata, sync=sync, + index_update_method=index_update_method, ) @classmethod @@ -512,6 +520,7 @@ def create_brute_force_index( credentials: Optional[auth_credentials.Credentials] = None, request_metadata: Optional[Sequence[Tuple[str, str]]] = (), sync: bool = True, + index_update_method: Optional[str] = None, ) -> "MatchingEngineIndex": """Creates a MatchingEngineIndex resource that uses the brute force algorithm. @@ -571,15 +580,14 @@ def create_brute_force_index( credentials set in aiplatform.init. request_metadata (Sequence[Tuple[str, str]]): Optional. Strings which should be sent along with the request as metadata. - encryption_spec (str): - Optional. Customer-managed encryption key - spec for data storage. If set, both of the - online and offline data storage will be secured - by this key. sync (bool): Optional. Whether to execute this creation synchronously. If False, this method will be executed in concurrent Future and any downstream object will be immediately returned and synced when the Future has completed. + index_update_method (str): + Optional. The update method to use with this index. Choose + stream_update or batch_update. If not set, batch update will be + used by default. Returns: MatchingEngineIndex - Index resource object @@ -605,4 +613,11 @@ def create_brute_force_index( credentials=credentials, request_metadata=request_metadata, sync=sync, + index_update_method=index_update_method, ) + + +_INDEX_UPDATE_METHOD_TO_ENUM_VALUE = { + "STREAM_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.STREAM_UPDATE, + "BATCH_UPDATE": gca_matching_engine_index.Index.IndexUpdateMethod.BATCH_UPDATE, +} diff --git a/tests/unit/aiplatform/test_matching_engine_index.py b/tests/unit/aiplatform/test_matching_engine_index.py index a4693dc809..1472438b97 100644 --- a/tests/unit/aiplatform/test_matching_engine_index.py +++ b/tests/unit/aiplatform/test_matching_engine_index.py @@ -92,6 +92,18 @@ ), ] +# Index update method +_TEST_INDEX_BATCH_UPDATE_METHOD = "BATCH_UPDATE" +_TEST_INDEX_STREAM_UPDATE_METHOD = "STREAM_UPDATE" +_TEST_INDEX_EMPTY_UPDATE_METHOD = None +_TEST_INDEX_INVALID_UPDATE_METHOD = "INVALID_UPDATE_METHOD" +_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP = { + _TEST_INDEX_BATCH_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.BATCH_UPDATE, + _TEST_INDEX_STREAM_UPDATE_METHOD: gca_index.Index.IndexUpdateMethod.STREAM_UPDATE, + _TEST_INDEX_EMPTY_UPDATE_METHOD: None, + _TEST_INDEX_INVALID_UPDATE_METHOD: None, +} + def uuid_mock(): return uuid.UUID(int=1) @@ -273,7 +285,16 @@ def test_delete_index(self, delete_index_mock, sync): @pytest.mark.usefixtures("get_index_mock") @pytest.mark.parametrize("sync", [True, False]) - def test_create_tree_ah_index(self, create_index_mock, sync): + @pytest.mark.parametrize( + "index_update_method", + [ + _TEST_INDEX_STREAM_UPDATE_METHOD, + _TEST_INDEX_BATCH_UPDATE_METHOD, + _TEST_INDEX_EMPTY_UPDATE_METHOD, + _TEST_INDEX_INVALID_UPDATE_METHOD, + ], + ) + def test_create_tree_ah_index(self, create_index_mock, sync, index_update_method): aiplatform.init(project=_TEST_PROJECT) my_index = aiplatform.MatchingEngineIndex.create_tree_ah_index( @@ -287,6 +308,7 @@ def test_create_tree_ah_index(self, create_index_mock, sync): description=_TEST_INDEX_DESCRIPTION, labels=_TEST_LABELS, sync=sync, + index_update_method=index_update_method, ) if not sync: @@ -312,6 +334,9 @@ def test_create_tree_ah_index(self, create_index_mock, sync): }, description=_TEST_INDEX_DESCRIPTION, labels=_TEST_LABELS, + index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[ + index_update_method + ], ) create_index_mock.assert_called_once_with( @@ -322,7 +347,18 @@ def test_create_tree_ah_index(self, create_index_mock, sync): @pytest.mark.usefixtures("get_index_mock") @pytest.mark.parametrize("sync", [True, False]) - def test_create_brute_force_index(self, create_index_mock, sync): + @pytest.mark.parametrize( + "index_update_method", + [ + _TEST_INDEX_STREAM_UPDATE_METHOD, + _TEST_INDEX_BATCH_UPDATE_METHOD, + _TEST_INDEX_EMPTY_UPDATE_METHOD, + _TEST_INDEX_INVALID_UPDATE_METHOD, + ], + ) + def test_create_brute_force_index( + self, create_index_mock, sync, index_update_method + ): aiplatform.init(project=_TEST_PROJECT) my_index = aiplatform.MatchingEngineIndex.create_brute_force_index( @@ -333,6 +369,7 @@ def test_create_brute_force_index(self, create_index_mock, sync): description=_TEST_INDEX_DESCRIPTION, labels=_TEST_LABELS, sync=sync, + index_update_method=index_update_method, ) if not sync: @@ -353,6 +390,9 @@ def test_create_brute_force_index(self, create_index_mock, sync): }, description=_TEST_INDEX_DESCRIPTION, labels=_TEST_LABELS, + index_update_method=_TEST_INDEX_UPDATE_METHOD_EXPECTED_RESULT_MAP[ + index_update_method + ], ) create_index_mock.assert_called_once_with(