From 880b674db2a572a2063d4ca3fe33bb7327ec3cd0 Mon Sep 17 00:00:00 2001 From: Xinyuan Lu Date: Fri, 24 Jan 2025 22:24:04 +0800 Subject: [PATCH] Check before delete (#3209) (#3431) * add logic to detect agent before deleting * add logic to detect agent before deleting * add logic to detect pipelines before delete model * check pipeline before deleting * apply spotless * remove useless file * rename functions * fix failure test * add UT * apply spotless * renam * refactor to parallel check * concate error message * move logic after user access check * change agent model searcher map to set * rename and remove useless method * fix bug to fetch all pipelines * apply spotless * apply spotless * remove and add comment * rename and add more UTs * use correct key * simplify function * change to a better class * apply spotless * change compareAndSet to set * apply comment * change name and reformat logic * change name * remove useless line * change to a better method * change name * apply spotless * add java doc for function * add another interface * apply java spotless * change interface to with model * apply spot less * add settings * apply spot less * add test for cluster setting * apply spotless * recover useless change * change default value of cluster setting * rename setting and add comment * apply spot * remove logic for hidden model * reorder code * reorder code * reorder code * apply spot * add UT * add more UT * remove search for hidden agent * fix logic and apply spot * add exist for UT * change dsl to query index * change query logic * remove useless ut * rebert * apply spot * rechange code * apply spot * remove useless should * apply spot * fix final dsl logic and ut --------- (cherry picked from commit 570edaf8ca48f08491706175755330bd24075143) Signed-off-by: xinyual --- .../org/opensearch/ml/common/CommonValue.java | 1 + .../ml/engine/tools/MLModelTool.java | 11 +- .../ml/engine/utils/AgentModelsSearcher.java | 66 +++ .../ml/engine/tools/MLModelToolTests.java | 2 + .../engine/utils/AgentModelSearcherTests.java | 96 ++++ .../models/DeleteModelTransportAction.java | 212 ++++++++- .../ml/plugin/MachineLearningPlugin.java | 6 + .../ml/settings/MLCommonsSettings.java | 3 + .../DeleteModelTransportActionTests.java | 449 ++++++++++++++++-- .../ml/common/spi/tools/WithModelTool.java | 26 + 10 files changed, 825 insertions(+), 47 deletions(-) create mode 100644 ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java create mode 100644 ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/AgentModelSearcherTests.java create mode 100644 spi/src/main/java/org/opensearch/ml/common/spi/tools/WithModelTool.java diff --git a/common/src/main/java/org/opensearch/ml/common/CommonValue.java b/common/src/main/java/org/opensearch/ml/common/CommonValue.java index ef6c067b05..7110309e88 100644 --- a/common/src/main/java/org/opensearch/ml/common/CommonValue.java +++ b/common/src/main/java/org/opensearch/ml/common/CommonValue.java @@ -45,6 +45,7 @@ public class CommonValue { public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message"; public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words"; public static final Set stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words"); + public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters."; // Index mapping paths public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml-model-group.json"; diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java index 1bcf6c9ef0..4fb33680f1 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/tools/MLModelTool.java @@ -17,8 +17,8 @@ import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; import org.opensearch.ml.common.spi.tools.Parser; -import org.opensearch.ml.common.spi.tools.Tool; import org.opensearch.ml.common.spi.tools.ToolAnnotation; +import org.opensearch.ml.common.spi.tools.WithModelTool; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; import org.opensearch.ml.common.utils.StringUtils; @@ -33,7 +33,7 @@ */ @Log4j2 @ToolAnnotation(MLModelTool.TYPE) -public class MLModelTool implements Tool { +public class MLModelTool implements WithModelTool { public static final String TYPE = "MLModelTool"; public static final String RESPONSE_FIELD = "response_field"; public static final String MODEL_ID_FIELD = "model_id"; @@ -127,7 +127,7 @@ public boolean validate(Map parameters) { return true; } - public static class Factory implements Tool.Factory { + public static class Factory implements WithModelTool.Factory { private Client client; private static Factory INSTANCE; @@ -172,5 +172,10 @@ public String getDefaultType() { public String getDefaultVersion() { return null; } + + @Override + public List getAllModelKeys() { + return List.of(MODEL_ID_FIELD); + } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java new file mode 100644 index 0000000000..86b2393e67 --- /dev/null +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/utils/AgentModelsSearcher.java @@ -0,0 +1,66 @@ +package org.opensearch.ml.engine.utils; + +import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX; +import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.WithModelTool; +import org.opensearch.search.builder.SearchSourceBuilder; + +public class AgentModelsSearcher { + private final Set relatedModelIdSet; + + public AgentModelsSearcher(Map toolFactories) { + relatedModelIdSet = new HashSet<>(); + for (Map.Entry entry : toolFactories.entrySet()) { + Tool.Factory toolFactory = entry.getValue(); + if (toolFactory instanceof WithModelTool.Factory) { + WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory; + relatedModelIdSet.addAll(withModelTool.getAllModelKeys()); + } + } + } + + /** + * Construct a should query to search all agent which containing candidate model Id + + @param candidateModelId the candidate model Id + @return a should search request towards agent index. + */ + public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) { + SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX); + // Two conditions here + // 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and + // 2. Any model field contains candidate ID + BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery(); + + BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery(); + // not exist hidden + hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD))); + // exist but equal to false + BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery(); + existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false)); + existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)); + hiddenFieldQuery.should(existHiddenFieldQuery); + + // + BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery(); + for (String keyField : relatedModelIdSet) { + modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId)); + } + + searchAgentQuery.must(hiddenFieldQuery); + searchAgentQuery.must(modelIdQuery); + searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery)); + return searchRequest; + } + +} diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java index 3aa76cd554..beeb370d52 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/tools/MLModelToolTests.java @@ -14,6 +14,7 @@ import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.verify; import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION; +import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD; import java.util.Arrays; import java.util.Collections; @@ -218,5 +219,6 @@ public void testTool() { assertTrue(tool.validate(otherParams)); assertFalse(tool.validate(emptyParams)); assertEquals(DEFAULT_DESCRIPTION, tool.getDescription()); + assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys()); } } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/AgentModelSearcherTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/AgentModelSearcherTests.java new file mode 100644 index 0000000000..c2f2ec1f1f --- /dev/null +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/utils/AgentModelSearcherTests.java @@ -0,0 +1,96 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.ml.engine.utils; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.ExistsQueryBuilder; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.ml.common.agent.MLAgent; +import org.opensearch.ml.common.spi.tools.Tool; +import org.opensearch.ml.common.spi.tools.WithModelTool; + +public class AgentModelSearcherTests { + + @Test + public void testConstructor_CollectsModelIds() { + // Arrange + WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class); + when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2")); + + WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class); + when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey")); + + // This tool factory does not implement WithModelTool.Factory + Tool.Factory regularToolFactory = mock(Tool.Factory.class); + + Map toolFactories = new HashMap<>(); + toolFactories.put("withModelTool1", withModelToolFactory1); + toolFactories.put("withModelTool2", withModelToolFactory2); + toolFactories.put("regularTool", regularToolFactory); + + // Act + AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories); + + // (Optional) We can't directly access relatedModelIdSet, + // but we can test the behavior indirectly using the search call: + SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId"); + + // Assert + // Verify the searchRequest uses all keys from the WithModelTool factories + BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query(); + // We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses + assertEquals(2, boolQueryBuilder.must().size()); + for (QueryBuilder query : boolQueryBuilder.must()) { + BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query; + assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3); + if (subBoolQueryBuilder.should().size() == 3) { + boolQueryBuilder.should().forEach(subQuery -> { + assertTrue(subQuery instanceof TermsQueryBuilder); + TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery; + // Each TermsQueryBuilder should contain candidateModelId + assertTrue(termsQuery.values().contains("candidateId")); + }); + } else { + boolQueryBuilder.should().forEach(subQuery -> { + assertTrue(subQuery instanceof BoolQueryBuilder); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery; + assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1); + if (boolQuery.must().size() == 2) { + boolQuery.must().forEach(existSubQuery -> { + assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder); + if (existSubQuery instanceof TermsQueryBuilder) { + TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery; + assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD)); + assertTrue(termsQuery.values().contains(false)); + } else { + ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery; + assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD)); + } + }); + } else { + QueryBuilder mustNotQuery = boolQuery.mustNot().get(0); + assertTrue(mustNotQuery instanceof ExistsQueryBuilder); + assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName()); + } + }); + } + } + + } +} diff --git a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java index b65940399d..ae33ebb6c4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/models/DeleteModelTransportAction.java @@ -15,20 +15,36 @@ import static org.opensearch.ml.common.MLModel.IS_HIDDEN_FIELD; import static org.opensearch.ml.common.MLModel.MODEL_ID_FIELD; import static org.opensearch.ml.common.utils.StringUtils.getErrorMessage; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_SAFE_DELETE_WITH_USAGE_CHECK; import static org.opensearch.ml.utils.RestActionUtils.getFetchSourceContext; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Deque; +import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Supplier; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchStatusException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionType; import org.opensearch.action.delete.DeleteRequest; import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.get.GetResponse; +import org.opensearch.action.ingest.GetPipelineAction; +import org.opensearch.action.ingest.GetPipelineRequest; +import org.opensearch.action.search.GetSearchPipelineAction; +import org.opensearch.action.search.GetSearchPipelineRequest; +import org.opensearch.action.search.SearchRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -37,11 +53,14 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.common.xcontent.LoggingDeprecationHandler; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.index.IndexNotFoundException; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.index.query.TermsQueryBuilder; @@ -54,6 +73,7 @@ import org.opensearch.ml.common.transport.model.MLModelDeleteAction; import org.opensearch.ml.common.transport.model.MLModelDeleteRequest; import org.opensearch.ml.common.transport.model.MLModelGetRequest; +import org.opensearch.ml.engine.utils.AgentModelsSearcher; import org.opensearch.ml.helper.ModelAccessControlHelper; import org.opensearch.ml.settings.MLFeatureEnabledSetting; import org.opensearch.ml.utils.RestActionUtils; @@ -62,6 +82,7 @@ import org.opensearch.remote.metadata.client.GetDataObjectRequest; import org.opensearch.remote.metadata.client.SdkClient; import org.opensearch.remote.metadata.common.SdkClientUtils; +import org.opensearch.search.SearchHit; import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -80,6 +101,10 @@ public class DeleteModelTransportAction extends HandledTransportAction isSafeDelete = it); this.mlFeatureEnabledSetting = mlFeatureEnabledSetting; } @@ -193,7 +225,19 @@ protected void doExecute(Task task, ActionRequest request, ActionListener