From aac0926656a14e6c32df3aa4d8441559caa261ef Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Tue, 21 Mar 2023 14:33:17 -0700 Subject: [PATCH] add exclude nodes setting (#813) * add exclude nodes setting Signed-off-by: Yaliang Wu * fix duplicate nodes if node has both ml and data roles Signed-off-by: Yaliang Wu --------- Signed-off-by: Yaliang Wu --- .../unload/TransportUnloadModelAction.java | 6 +- .../ml/cluster/DiscoveryNodeHelper.java | 30 +++++ .../opensearch/ml/model/MLModelManager.java | 33 +++++- .../ml/plugin/MachineLearningPlugin.java | 6 +- .../ml/settings/MLCommonsSettings.java | 3 + .../ml/task/MLPredictTaskRunner.java | 2 +- .../opensearch/ml/task/MLTaskDispatcher.java | 2 +- .../ml/cluster/DiscoveryNodeHelperTests.java | 104 ++++++++++++++++-- .../ml/model/MLModelManagerTests.java | 55 ++++++++- .../ml/rest/RestMLTrainAndPredictIT.java | 62 +++++++++++ 10 files changed, 281 insertions(+), 22 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java index 40250680c0..5a4a777d31 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java @@ -198,8 +198,10 @@ private UnloadModelNodeResponse createUnloadModelNodeResponse(UnloadModelNodesRe String[] modelIds = unloadModelNodesRequest.getModelIds(); Map modelWorkerNodeCounts = new HashMap<>(); - if (modelIds != null) { - for (String modelId : modelIds) { + boolean specifiedModelIds = modelIds != null && modelIds.length > 0; + String[] removedModelIds = specifiedModelIds ? modelIds : mlModelManager.getAllModelIds(); + if (removedModelIds != null) { + for (String modelId : removedModelIds) { String[] workerNodes = mlModelManager.getWorkerNodes(modelId); modelWorkerNodeCounts.put(modelId, workerNodes == null ? 0 : workerNodes.length); } diff --git a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java index 1a2ac43f26..0709490b7c 100644 --- a/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java +++ b/plugin/src/main/java/org/opensearch/ml/cluster/DiscoveryNodeHelper.java @@ -5,6 +5,7 @@ package org.opensearch.ml.cluster; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import java.util.ArrayList; @@ -19,6 +20,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Strings; import org.opensearch.common.settings.Settings; import org.opensearch.ml.common.CommonValue; import org.opensearch.ml.utils.MLNodeUtils; @@ -28,12 +30,17 @@ public class DiscoveryNodeHelper { private final ClusterService clusterService; private final HotDataNodePredicate eligibleNodeFilter; private volatile Boolean onlyRunOnMLNode; + private volatile Set excludedNodeNames; public DiscoveryNodeHelper(ClusterService clusterService, Settings settings) { this.clusterService = clusterService; eligibleNodeFilter = new HotDataNodePredicate(); onlyRunOnMLNode = ML_COMMONS_ONLY_RUN_ON_ML_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_ONLY_RUN_ON_ML_NODE, it -> onlyRunOnMLNode = it); + excludedNodeNames = Strings.commaDelimitedListToSet(ML_COMMONS_EXCLUDE_NODE_NAMES.get(settings)); + clusterService + .getClusterSettings() + .addSettingsUpdateConsumer(ML_COMMONS_EXCLUDE_NODE_NAMES, it -> excludedNodeNames = Strings.commaDelimitedListToSet(it)); } public String[] getEligibleNodeIds() { @@ -50,6 +57,9 @@ public DiscoveryNode[] getEligibleNodes() { final List eligibleMLNodes = new ArrayList<>(); final List eligibleDataNodes = new ArrayList<>(); for (DiscoveryNode node : state.nodes()) { + if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) { + continue; + } if (MLNodeUtils.isMLNode(node)) { eligibleMLNodes.add(node); } @@ -68,6 +78,26 @@ public DiscoveryNode[] getEligibleNodes() { } } + public String[] filterEligibleNodes(String[] nodeIds) { + if (nodeIds == null || nodeIds.length == 0) { + return nodeIds; + } + DiscoveryNode[] nodes = getNodes(nodeIds); + final Set eligibleNodes = new HashSet<>(); + for (DiscoveryNode node : nodes) { + if (excludedNodeNames != null && excludedNodeNames.contains(node.getName())) { + continue; + } + if (MLNodeUtils.isMLNode(node)) { + eligibleNodes.add(node.getId()); + } + if (!onlyRunOnMLNode && node.isDataNode() && isEligibleDataNode(node)) { + eligibleNodes.add(node.getId()); + } + } + return eligibleNodes.toArray(new String[0]); + } + public DiscoveryNode[] getAllNodes() { ClusterState state = this.clusterService.state(); final List nodes = new ArrayList<>(); diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java index 3c38d51299..d12905f80f 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLModelManager.java @@ -75,6 +75,7 @@ import org.opensearch.index.reindex.DeleteByQueryAction; import org.opensearch.index.reindex.DeleteByQueryRequest; import org.opensearch.ml.breaker.MLCircuitBreakerService; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; @@ -126,6 +127,7 @@ public class MLModelManager { private final MLIndicesHandler mlIndicesHandler; private final MLTaskManager mlTaskManager; private final MLEngine mlEngine; + private final DiscoveryNodeHelper nodeHelper; private volatile Integer maxModelPerNode; private volatile Integer maxUploadTasksPerNode; @@ -153,7 +155,8 @@ public MLModelManager( MLIndicesHandler mlIndicesHandler, MLTaskManager mlTaskManager, MLModelCacheHelper modelCacheHelper, - MLEngine mlEngine + MLEngine mlEngine, + DiscoveryNodeHelper nodeHelper ) { this.client = client; this.threadPool = threadPool; @@ -166,6 +169,7 @@ public MLModelManager( this.mlIndicesHandler = mlIndicesHandler; this.mlTaskManager = mlTaskManager; this.mlEngine = mlEngine; + this.nodeHelper = nodeHelper; this.maxModelPerNode = ML_COMMONS_MAX_MODELS_PER_NODE.get(settings); clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_MAX_MODELS_PER_NODE, it -> maxModelPerNode = it); @@ -726,13 +730,36 @@ private void removeModel(String modelId) { } /** - * Get worker nodes of specif model. + * Get worker nodes of specific model. + * + * @param modelId model id + * @param onlyEligibleNode return only eligible node + * @return list of worker node ids + */ + public String[] getWorkerNodes(String modelId, boolean onlyEligibleNode) { + String[] workerNodeIds = modelCacheHelper.getWorkerNodes(modelId); + if (!onlyEligibleNode) { + return workerNodeIds; + } + if (workerNodeIds == null || workerNodeIds.length == 0) { + return workerNodeIds; + } + + String[] eligibleNodeIds = nodeHelper.filterEligibleNodes(workerNodeIds); + if (eligibleNodeIds == null || eligibleNodeIds.length == 0) { + throw new IllegalArgumentException("No eligible worker node found"); + } + return eligibleNodeIds; + } + + /** + * Get worker node of specific model without filtering eligible node. * * @param modelId model id * @return list of worker node ids */ public String[] getWorkerNodes(String modelId) { - return modelCacheHelper.getWorkerNodes(modelId); + return getWorkerNodes(modelId, false); } /** diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 83cec6d1e6..fb97ffc36e 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -268,7 +268,8 @@ public Collection createComponents( mlIndicesHandler, mlTaskManager, modelCacheHelper, - mlEngine + mlEngine, + nodeHelper ); mlInputDatasetHandler = new MLInputDatasetHandler(client); @@ -513,7 +514,8 @@ public List> getSettings() { MLCommonsSettings.ML_COMMONS_MAX_ML_TASK_PER_NODE, MLCommonsSettings.ML_COMMONS_MAX_LOAD_MODEL_TASKS_PER_NODE, MLCommonsSettings.ML_COMMONS_TRUSTED_URL_REGEX, - MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD + MLCommonsSettings.ML_COMMONS_NATIVE_MEM_THRESHOLD, + MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES ); return settings; } diff --git a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java index 3bae6ea92a..f9b69872e0 100644 --- a/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java +++ b/plugin/src/main/java/org/opensearch/ml/settings/MLCommonsSettings.java @@ -56,4 +56,7 @@ private MLCommonsSettings() {} public static final Setting ML_COMMONS_NATIVE_MEM_THRESHOLD = Setting .intSetting("plugins.ml_commons.native_memory_threshold", 90, 0, 100, Setting.Property.NodeScope, Setting.Property.Dynamic); + + public static final Setting ML_COMMONS_EXCLUDE_NODE_NAMES = Setting + .simpleString("plugins.ml_commons.exclude_nodes._name", Setting.Property.NodeScope, Setting.Property.Dynamic); } diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java index 200a7d072b..07f48d4616 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLPredictTaskRunner.java @@ -126,7 +126,7 @@ public void dispatchTask(MLPredictionTaskRequest request, TransportService trans transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener)); } }, e -> { listener.onFailure(e); }); - String[] workerNodes = mlModelManager.getWorkerNodes(modelId); + String[] workerNodes = mlModelManager.getWorkerNodes(modelId, true); if (workerNodes == null || workerNodes.length == 0) { if (algorithm == FunctionName.TEXT_EMBEDDING) { listener.onFailure(new IllegalArgumentException("model not loaded")); diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index 3d46b16631..121da93528 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -74,7 +74,7 @@ public void dispatch(ActionListener actionListener) { public void dispatchPredictTask(String[] nodeIds, ActionListener actionListener) { if (nodeIds == null || nodeIds.length == 0) { - throw new IllegalArgumentException("Model not loaded yet"); + throw new IllegalArgumentException("no eligible node to run predict request"); } if (ROUND_ROBIN.equals(dispatchPolicy)) { dispatchTaskWithRoundRobin( diff --git a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java index f93d19216a..99d9e90a8c 100644 --- a/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java +++ b/plugin/src/test/java/org/opensearch/ml/cluster/DiscoveryNodeHelperTests.java @@ -9,12 +9,16 @@ import static java.util.Collections.emptySet; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.BUILT_IN_ROLES; +import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_EXCLUDE_NODE_NAMES; import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_ONLY_RUN_ON_ML_NODE; import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import static org.opensearch.ml.utils.TestHelper.clusterSetting; import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; import java.util.Set; +import java.util.stream.Collectors; import org.junit.Before; import org.mockito.Mock; @@ -36,10 +40,14 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase { private final String clusterManagerNodeId = "clusterManagerNode"; private final String dataNode1Id = "dataNode1"; + private final String dataNode1Name = "dataNodeName1"; private final String dataNode2Id = "dataNode2"; + private final String dataNode2Name = "dataNodeName2"; private final String warmDataNode1Id = "warmDataNode1"; private final String mlNode1Id = "mlNode1"; + private final String mlNode1Name = "mlNodeName1"; private final String mlNode2Id = "mlNode2"; + private final String mlNode2Name = "mlNodeName2"; private final String clusterName = "multi-node-cluster"; @Mock @@ -54,11 +62,13 @@ public class DiscoveryNodeHelperTests extends OpenSearchTestCase { private DiscoveryNode mlNode1; private DiscoveryNode mlNode2; private ClusterState clusterState; + private String nonExistingNodeName; @Before public void setup() throws IOException { MockitoAnnotations.openMocks(this); - mockSettings(true); + nonExistingNodeName = randomAlphaOfLength(5); + mockSettings(true, nonExistingNodeName); clusterManagerNode = new DiscoveryNode( clusterManagerNodeId, @@ -68,8 +78,16 @@ public void setup() throws IOException { Version.CURRENT ); - dataNode1 = new DiscoveryNode(dataNode1Id, buildNewFakeTransportAddress(), emptyMap(), BUILT_IN_ROLES, Version.CURRENT); + dataNode1 = new DiscoveryNode( + dataNode1Name, + dataNode1Id, + buildNewFakeTransportAddress(), + emptyMap(), + BUILT_IN_ROLES, + Version.CURRENT + ); dataNode2 = new DiscoveryNode( + dataNode1Name, dataNode2Id, buildNewFakeTransportAddress(), ImmutableMap.of(CommonValue.BOX_TYPE_KEY, CommonValue.HOT_BOX_TYPE), @@ -84,8 +102,22 @@ public void setup() throws IOException { Version.CURRENT ); - mlNode1 = new DiscoveryNode(mlNode1Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); - mlNode2 = new DiscoveryNode(mlNode2Id, buildNewFakeTransportAddress(), emptyMap(), ImmutableSet.of(ML_ROLE), Version.CURRENT); + mlNode1 = new DiscoveryNode( + mlNode1Name, + mlNode1Id, + buildNewFakeTransportAddress(), + emptyMap(), + ImmutableSet.of(ML_ROLE), + Version.CURRENT + ); + mlNode2 = new DiscoveryNode( + mlNode2Name, + mlNode2Id, + buildNewFakeTransportAddress(), + emptyMap(), + ImmutableSet.of(ML_ROLE), + Version.CURRENT + ); DiscoveryNodes nodes = DiscoveryNodes .builder() @@ -102,21 +134,27 @@ public void setup() throws IOException { discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); } - private void mockSettings(boolean onlyRunOnMLNode) { - settings = Settings.builder().put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), onlyRunOnMLNode).build(); - ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ONLY_RUN_ON_ML_NODE); + private void mockSettings(boolean onlyRunOnMLNode, String excludedNodeName) { + settings = Settings + .builder() + .put(ML_COMMONS_ONLY_RUN_ON_ML_NODE.getKey(), onlyRunOnMLNode) + .put(ML_COMMONS_EXCLUDE_NODE_NAMES.getKey(), excludedNodeName) + .build(); + ClusterSettings clusterSettings = clusterSetting(settings, ML_COMMONS_ONLY_RUN_ON_ML_NODE, ML_COMMONS_EXCLUDE_NODE_NAMES); when(clusterService.getClusterSettings()).thenReturn(clusterSettings); } public void testGetEligibleNodes_MLNode() { DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(); assertEquals(2, eligibleNodes.length); - assertEquals(mlNode1.getName(), eligibleNodes[0].getName()); - assertEquals(mlNode2.getName(), eligibleNodes[1].getName()); + Set nodeIds = new HashSet<>(); + nodeIds.addAll(Arrays.asList(eligibleNodes).stream().map(n -> n.getId()).collect(Collectors.toList())); + assertTrue(nodeIds.contains(mlNode1.getId())); + assertTrue(nodeIds.contains(mlNode2.getId())); } public void testGetEligibleNodes_DataNode() { - mockSettings(false); + mockSettings(false, nonExistingNodeName); DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); DiscoveryNodes nodes = DiscoveryNodes.builder().add(clusterManagerNode).add(dataNode1).add(dataNode2).add(warmDataNode1).build(); clusterState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false); @@ -124,8 +162,50 @@ public void testGetEligibleNodes_DataNode() { DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(); assertEquals(2, eligibleNodes.length); - assertEquals(mlNode1.getName(), eligibleNodes[0].getName()); - assertEquals(mlNode2.getName(), eligibleNodes[1].getName()); + assertEquals(dataNode1.getName(), eligibleNodes[0].getName()); + assertEquals(dataNode2.getName(), eligibleNodes[1].getName()); + } + + public void testGetEligibleNodes_MLNode_Excluded() { + mockSettings(false, mlNode1.getName() + "," + mlNode2.getName()); + DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); + DiscoveryNode[] eligibleNodes = discoveryNodeHelper.getEligibleNodes(); + assertEquals(2, eligibleNodes.length); + assertEquals(dataNode1.getName(), eligibleNodes[0].getName()); + assertEquals(dataNode1.getName(), eligibleNodes[1].getName()); + } + + public void testFilterEligibleNodes_Null() { + mockSettings(false, mlNode1.getName() + "," + mlNode2.getName()); + DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); + String[] eligibleNodes = discoveryNodeHelper.filterEligibleNodes(null); + assertNull(eligibleNodes); + } + + public void testFilterEligibleNodes_Empty() { + mockSettings(false, mlNode1.getName() + "," + mlNode2.getName()); + DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); + String[] eligibleNodes = discoveryNodeHelper.filterEligibleNodes(new String[] {}); + assertEquals(0, eligibleNodes.length); + } + + public void testFilterEligibleNodes() { + mockSettings(true, mlNode1.getName()); + DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); + String[] eligibleNodes = discoveryNodeHelper.filterEligibleNodes(new String[] { mlNode1Id, mlNode2Id, dataNode1Id }); + assertEquals(1, eligibleNodes.length); + assertEquals(mlNode2Id, eligibleNodes[0]); + } + + public void testFilterEligibleNodes_BothMLAndDataNodes() { + mockSettings(false, mlNode1.getName()); + DiscoveryNodeHelper discoveryNodeHelper = new DiscoveryNodeHelper(clusterService, settings); + String[] eligibleNodes = discoveryNodeHelper.filterEligibleNodes(new String[] { mlNode1Id, mlNode2Id, dataNode1Id }); + assertEquals(2, eligibleNodes.length); + Set nodeIds = new HashSet<>(); + nodeIds.addAll(Arrays.asList(eligibleNodes)); + assertTrue(nodeIds.contains(dataNode1Id)); + assertTrue(nodeIds.contains(mlNode2Id)); } public void testGetAllNodeIds() { diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java index 87a587479e..ecf6dc6f13 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLModelManagerTests.java @@ -73,6 +73,7 @@ import org.opensearch.common.xcontent.NamedXContentRegistry; import org.opensearch.ml.breaker.MLCircuitBreakerService; import org.opensearch.ml.breaker.ThresholdCircuitBreaker; +import org.opensearch.ml.cluster.DiscoveryNodeHelper; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.MLTask; @@ -148,6 +149,8 @@ public class MLModelManagerTests extends OpenSearchTestCase { private MLEngine mlEngine; @Mock ThresholdCircuitBreaker thresholdCircuitBreaker; + @Mock + DiscoveryNodeHelper nodeHelper; @Before public void setup() throws URISyntaxException { @@ -231,7 +234,8 @@ public void setup() throws URISyntaxException { mlIndicesHandler, mlTaskManager, modelCacheHelper, - mlEngine + mlEngine, + nodeHelper ) ); @@ -581,6 +585,55 @@ public void testClearRoutingTable() { verify(modelCacheHelper).clearWorkerNodes(); } + public void testGetWorkerNodes() { + String[] nodes = new String[] { "node1", "node2" }; + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(nodes); + String[] workerNodes = modelManager.getWorkerNodes(modelId); + assertArrayEquals(nodes, workerNodes); + } + + public void testGetWorkerNodes_Null() { + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(null); + String[] workerNodes = modelManager.getWorkerNodes(modelId); + assertNull(workerNodes); + } + + public void testGetWorkerNodes_EmptyNodes() { + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(new String[] {}); + String[] workerNodes = modelManager.getWorkerNodes(modelId); + assertEquals(0, workerNodes.length); + } + + public void testGetWorkerNodes_FilterEligibleNodes() { + String[] nodes = new String[] { "node1", "node2" }; + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(nodes); + + String[] eligibleNodes = new String[] { "node1" }; + when(nodeHelper.filterEligibleNodes(any())).thenReturn(eligibleNodes); + String[] workerNodes = modelManager.getWorkerNodes(modelId, true); + assertArrayEquals(eligibleNodes, workerNodes); + } + + public void testGetWorkerNodes_FilterEligibleNodes_Null() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("No eligible worker node found"); + String[] nodes = new String[] { "node1", "node2" }; + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(nodes); + + when(nodeHelper.filterEligibleNodes(any())).thenReturn(null); + modelManager.getWorkerNodes(modelId, true); + } + + public void testGetWorkerNodes_FilterEligibleNodes_Empty() { + expectedEx.expect(IllegalArgumentException.class); + expectedEx.expectMessage("No eligible worker node found"); + String[] nodes = new String[] { "node1", "node2" }; + when(modelCacheHelper.getWorkerNodes(anyString())).thenReturn(nodes); + + when(nodeHelper.filterEligibleNodes(any())).thenReturn(new String[] {}); + modelManager.getWorkerNodes(modelId, true); + } + private void testLoadModel_FailedToRetrieveModelChunks(boolean lastChunk) { when(modelCacheHelper.isModelLoaded(modelId)).thenReturn(false); when(modelCacheHelper.getLoadedModels()).thenReturn(new String[] {}); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLTrainAndPredictIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLTrainAndPredictIT.java index 895ef1a4b7..a69ab549e9 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLTrainAndPredictIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLTrainAndPredictIT.java @@ -12,6 +12,8 @@ import java.util.function.Consumer; import org.apache.http.HttpEntity; +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; import org.junit.After; import org.junit.Before; import org.opensearch.client.Response; @@ -51,6 +53,66 @@ public void testTrainAndPredictKmeans() throws IOException { validateStats(FunctionName.KMEANS, ActionName.TRAIN_PREDICT, 0, 0, 2, 2); } + public void testTrainAndPredictKmeans_ExcludeNodes() throws IOException { + Response nodeResponse = TestHelper.makeRequest(client(), "GET", "/_cat/nodes", ImmutableMap.of(), (HttpEntity) null, null); + String response = TestHelper.httpEntityToString(nodeResponse.getEntity()); + String[] nodes = response.split("\n"); + StringBuilder nodeNames = new StringBuilder(); + for (String nodeString : nodes) { + String[] items = nodeString.split(" "); + if (items.length > 0) { + String nodeName = items[items.length - 1]; + nodeNames.append(nodeName).append(","); + } + } + String excludedNames = nodeNames.substring(0, nodeNames.length() - 1); + + Response updateSettingResponse = TestHelper + .makeRequest( + client(), + "PUT", + "_cluster/settings", + null, + "{\"persistent\":{\"plugins.ml_commons.exclude_nodes._name\":\"" + excludedNames + "\"}}", + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "")) + ); + assertEquals(200, updateSettingResponse.getStatusLine().getStatusCode()); + + try { + trainAndPredictKmeans(); + + // The trainAndPredictKmeans method should throw exception, so should not run this line + fail("Exclude nodes setting doesn't work"); + } catch (Exception e) { + assertTrue(e.getMessage().contains("400 Bad Request")); + assertTrue(e.getMessage().contains("\"reason\":\"No eligible node found to execute this request")); + } + } + + private Response trainAndPredictKmeans() throws IOException { + KMeansParams params = KMeansParams.builder().centroids(3).build(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.query(new MatchAllQueryBuilder()); + sourceBuilder.size(1000); + sourceBuilder.fetchSource(new String[] { "petal_length_in_cm", "petal_width_in_cm" }, null); + MLInputDataset inputData = SearchQueryInputDataset + .builder() + .indices(ImmutableList.of(irisIndex)) + .searchSourceBuilder(sourceBuilder) + .build(); + MLInput kmeansInput = MLInput.builder().algorithm(FunctionName.KMEANS).parameters(params).inputDataset(inputData).build(); + Response kmeansResponse = TestHelper + .makeRequest( + client(), + "POST", + "/_plugins/_ml/_train_predict/kmeans", + ImmutableMap.of(), + TestHelper.toHttpEntity(kmeansInput), + null + ); + return kmeansResponse; + } + private void trainAndPredictKmeansWithCustomParam() throws IOException { KMeansParams params = KMeansParams.builder().centroids(3).build(); trainAndPredictKmeansWithParmas(params, clusterCount -> assertTrue(clusterCount.size() >= 2));