Skip to content

Commit

Permalink
Check before delete (#3209)
Browse files Browse the repository at this point in the history
* add logic to detect agent before deleting

Signed-off-by: xinyual <xinyual@amazon.com>

* add logic to detect agent before deleting

Signed-off-by: xinyual <xinyual@amazon.com>

* add logic to detect pipelines before delete model

Signed-off-by: xinyual <xinyual@amazon.com>

* check pipeline before deleting

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless file

Signed-off-by: xinyual <xinyual@amazon.com>

* rename functions

Signed-off-by: xinyual <xinyual@amazon.com>

* fix failure test

Signed-off-by: xinyual <xinyual@amazon.com>

* add UT

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* renam

Signed-off-by: xinyual <xinyual@amazon.com>

* refactor to parallel check

Signed-off-by: xinyual <xinyual@amazon.com>

* concate error message

Signed-off-by: xinyual <xinyual@amazon.com>

* move logic after user access check

Signed-off-by: xinyual <xinyual@amazon.com>

* change agent model searcher map to set

Signed-off-by: xinyual <xinyual@amazon.com>

* rename and remove useless method

Signed-off-by: xinyual <xinyual@amazon.com>

* fix bug to fetch all pipelines

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* remove and add comment

Signed-off-by: xinyual <xinyual@amazon.com>

* rename and add more UTs

Signed-off-by: xinyual <xinyual@amazon.com>

* use correct key

Signed-off-by: xinyual <xinyual@amazon.com>

* simplify function

Signed-off-by: xinyual <xinyual@amazon.com>

* change to a better class

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* change compareAndSet to set

Signed-off-by: xinyual <xinyual@amazon.com>

* apply comment

Signed-off-by: xinyual <xinyual@amazon.com>

* change name and reformat logic

Signed-off-by: xinyual <xinyual@amazon.com>

* change name

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless line

Signed-off-by: xinyual <xinyual@amazon.com>

* change to a better method

Signed-off-by: xinyual <xinyual@amazon.com>

* change name

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* add java doc for function

Signed-off-by: xinyual <xinyual@amazon.com>

* add another interface

Signed-off-by: xinyual <xinyual@amazon.com>

* apply java spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* change interface to with model

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot less

Signed-off-by: xinyual <xinyual@amazon.com>

* add settings

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot less

Signed-off-by: xinyual <xinyual@amazon.com>

* add test for cluster setting

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spotless

Signed-off-by: xinyual <xinyual@amazon.com>

* recover useless change

Signed-off-by: xinyual <xinyual@amazon.com>

* change default value of cluster setting

Signed-off-by: xinyual <xinyual@amazon.com>

* rename setting and add comment

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* remove logic for hidden model

Signed-off-by: xinyual <xinyual@amazon.com>

* reorder code

Signed-off-by: xinyual <xinyual@amazon.com>

* reorder code

Signed-off-by: xinyual <xinyual@amazon.com>

* reorder code

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* add UT

Signed-off-by: xinyual <xinyual@amazon.com>

* add more UT

Signed-off-by: xinyual <xinyual@amazon.com>

* remove search for hidden agent

Signed-off-by: xinyual <xinyual@amazon.com>

* fix logic and apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* add exist for UT

Signed-off-by: xinyual <xinyual@amazon.com>

* change dsl to query index

Signed-off-by: xinyual <xinyual@amazon.com>

* change query logic

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless ut

Signed-off-by: xinyual <xinyual@amazon.com>

* rebert

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* rechange code

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* remove useless should

Signed-off-by: xinyual <xinyual@amazon.com>

* apply spot

Signed-off-by: xinyual <xinyual@amazon.com>

* fix final dsl logic and ut

Signed-off-by: xinyual <xinyual@amazon.com>

---------

Signed-off-by: xinyual <xinyual@amazon.com>
  • Loading branch information
xinyual authored Jan 24, 2025
1 parent af96fe0 commit 570edaf
Show file tree
Hide file tree
Showing 10 changed files with 825 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public class CommonValue {
public static final String ML_MEMORY_MESSAGE_INDEX = ".plugins-ml-memory-message";
public static final String ML_STOP_WORDS_INDEX = ".plugins-ml-stop-words";
public static final Set<String> stopWordsIndices = ImmutableSet.of(".plugins-ml-stop-words");
public static final String TOOL_PARAMETERS_PREFIX = "tools.parameters.";

// Index mapping paths
public static final String ML_MODEL_GROUP_INDEX_MAPPING_PATH = "index-mappings/ml_model_group.json";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.spi.tools.Parser;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.ToolAnnotation;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -33,7 +33,7 @@
*/
@Log4j2
@ToolAnnotation(MLModelTool.TYPE)
public class MLModelTool implements Tool {
public class MLModelTool implements WithModelTool {
public static final String TYPE = "MLModelTool";
public static final String RESPONSE_FIELD = "response_field";
public static final String MODEL_ID_FIELD = "model_id";
Expand Down Expand Up @@ -127,7 +127,7 @@ public boolean validate(Map<String, String> parameters) {
return true;
}

public static class Factory implements Tool.Factory<MLModelTool> {
public static class Factory implements WithModelTool.Factory<MLModelTool> {
private Client client;

private static Factory INSTANCE;
Expand Down Expand Up @@ -172,5 +172,10 @@ public String getDefaultType() {
public String getDefaultVersion() {
return null;
}

@Override
public List<String> getAllModelKeys() {
return List.of(MODEL_ID_FIELD);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package org.opensearch.ml.engine.utils;

import static org.opensearch.ml.common.CommonValue.ML_AGENT_INDEX;
import static org.opensearch.ml.common.CommonValue.TOOL_PARAMETERS_PREFIX;

import java.util.HashSet;
import java.util.Map;
import java.util.Set;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;
import org.opensearch.search.builder.SearchSourceBuilder;

public class AgentModelsSearcher {
private final Set<String> relatedModelIdSet;

public AgentModelsSearcher(Map<String, Tool.Factory> toolFactories) {
relatedModelIdSet = new HashSet<>();
for (Map.Entry<String, Tool.Factory> entry : toolFactories.entrySet()) {
Tool.Factory toolFactory = entry.getValue();
if (toolFactory instanceof WithModelTool.Factory) {
WithModelTool.Factory withModelTool = (WithModelTool.Factory) toolFactory;
relatedModelIdSet.addAll(withModelTool.getAllModelKeys());
}
}
}

/**
* Construct a should query to search all agent which containing candidate model Id
@param candidateModelId the candidate model Id
@return a should search request towards agent index.
*/
public SearchRequest constructQueryRequestToSearchModelIdInsideAgent(String candidateModelId) {
SearchRequest searchRequest = new SearchRequest(ML_AGENT_INDEX);
// Two conditions here
// 1. {[(exists hidden field) and (hidden field = false)] or (not exist hidden field)} and
// 2. Any model field contains candidate ID
BoolQueryBuilder searchAgentQuery = QueryBuilders.boolQuery();

BoolQueryBuilder hiddenFieldQuery = QueryBuilders.boolQuery();
// not exist hidden
hiddenFieldQuery.should(QueryBuilders.boolQuery().mustNot(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD)));
// exist but equal to false
BoolQueryBuilder existHiddenFieldQuery = QueryBuilders.boolQuery();
existHiddenFieldQuery.must(QueryBuilders.termsQuery(MLAgent.IS_HIDDEN_FIELD, false));
existHiddenFieldQuery.must(QueryBuilders.existsQuery(MLAgent.IS_HIDDEN_FIELD));
hiddenFieldQuery.should(existHiddenFieldQuery);

//
BoolQueryBuilder modelIdQuery = QueryBuilders.boolQuery();
for (String keyField : relatedModelIdSet) {
modelIdQuery.should(QueryBuilders.termsQuery(TOOL_PARAMETERS_PREFIX + keyField, candidateModelId));
}

searchAgentQuery.must(hiddenFieldQuery);
searchAgentQuery.must(modelIdQuery);
searchRequest.source(new SearchSourceBuilder().query(searchAgentQuery));
return searchRequest;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;
import static org.opensearch.ml.engine.tools.MLModelTool.DEFAULT_DESCRIPTION;
import static org.opensearch.ml.engine.tools.MLModelTool.MODEL_ID_FIELD;

import java.util.Arrays;
import java.util.Collections;
Expand Down Expand Up @@ -218,5 +219,6 @@ public void testTool() {
assertTrue(tool.validate(otherParams));
assertFalse(tool.validate(emptyParams));
assertEquals(DEFAULT_DESCRIPTION, tool.getDescription());
assertEquals(List.of(MODEL_ID_FIELD), MLModelTool.Factory.getInstance().getAllModelKeys());
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.engine.utils;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import org.junit.Test;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.ExistsQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.TermsQueryBuilder;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.spi.tools.Tool;
import org.opensearch.ml.common.spi.tools.WithModelTool;

public class AgentModelSearcherTests {

@Test
public void testConstructor_CollectsModelIds() {
// Arrange
WithModelTool.Factory withModelToolFactory1 = mock(WithModelTool.Factory.class);
when(withModelToolFactory1.getAllModelKeys()).thenReturn(Arrays.asList("modelKey1", "modelKey2"));

WithModelTool.Factory withModelToolFactory2 = mock(WithModelTool.Factory.class);
when(withModelToolFactory2.getAllModelKeys()).thenReturn(Collections.singletonList("anotherModelKey"));

// This tool factory does not implement WithModelTool.Factory
Tool.Factory regularToolFactory = mock(Tool.Factory.class);

Map<String, Tool.Factory> toolFactories = new HashMap<>();
toolFactories.put("withModelTool1", withModelToolFactory1);
toolFactories.put("withModelTool2", withModelToolFactory2);
toolFactories.put("regularTool", regularToolFactory);

// Act
AgentModelsSearcher searcher = new AgentModelsSearcher(toolFactories);

// (Optional) We can't directly access relatedModelIdSet,
// but we can test the behavior indirectly using the search call:
SearchRequest request = searcher.constructQueryRequestToSearchModelIdInsideAgent("candidateId");

// Assert
// Verify the searchRequest uses all keys from the WithModelTool factories
BoolQueryBuilder boolQueryBuilder = (BoolQueryBuilder) request.source().query();
// We expect modelKey1, modelKey2, anotherModelKey => total 3 "should" clauses
assertEquals(2, boolQueryBuilder.must().size());
for (QueryBuilder query : boolQueryBuilder.must()) {
BoolQueryBuilder subBoolQueryBuilder = (BoolQueryBuilder) query;
assertTrue(subBoolQueryBuilder.should().size() == 2 || subBoolQueryBuilder.should().size() == 3);
if (subBoolQueryBuilder.should().size() == 3) {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof TermsQueryBuilder);
TermsQueryBuilder termsQuery = (TermsQueryBuilder) subQuery;
// Each TermsQueryBuilder should contain candidateModelId
assertTrue(termsQuery.values().contains("candidateId"));
});
} else {
boolQueryBuilder.should().forEach(subQuery -> {
assertTrue(subQuery instanceof BoolQueryBuilder);
BoolQueryBuilder boolQuery = (BoolQueryBuilder) subQuery;
assertTrue(boolQuery.must().size() == 2 || boolQuery.mustNot().size() == 1);
if (boolQuery.must().size() == 2) {
boolQuery.must().forEach(existSubQuery -> {
assertTrue(existSubQuery instanceof ExistsQueryBuilder || existSubQuery instanceof TermsQueryBuilder);
if (existSubQuery instanceof TermsQueryBuilder) {
TermsQueryBuilder termsQuery = (TermsQueryBuilder) existSubQuery;
assertTrue(termsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
assertTrue(termsQuery.values().contains(false));
} else {
ExistsQueryBuilder existsQuery = (ExistsQueryBuilder) existSubQuery;
assertTrue(existsQuery.fieldName().equals(MLAgent.IS_HIDDEN_FIELD));
}
});
} else {
QueryBuilder mustNotQuery = boolQuery.mustNot().get(0);
assertTrue(mustNotQuery instanceof ExistsQueryBuilder);
assertEquals(MLAgent.IS_HIDDEN_FIELD, ((ExistsQueryBuilder) mustNotQuery).fieldName());
}
});
}
}

}
}
Loading

0 comments on commit 570edaf

Please sign in to comment.