From 58e2c97b81d8d9ed3b67aa88afb4d6cda2f2a8d0 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Fri, 17 Jun 2022 01:42:20 +0000 Subject: [PATCH] dispatch ML task to ML node first (#346) Signed-off-by: Yaliang Wu (cherry picked from commit 6cbb626ea6ddf2d02fea406f0d9422a5a466117c) --- .../ml/plugin/MachineLearningPlugin.java | 11 +--- .../opensearch/ml/task/MLTaskDispatcher.java | 28 +++++++-- .../org/opensearch/ml/utils/MLNodeUtils.java | 5 +- .../ml/task/MLTaskDispatcherTests.java | 59 ++++++++++++++----- .../opensearch/ml/utils/MLNodeUtilsTests.java | 4 +- .../org/opensearch/ml/utils/TestHelper.java | 12 ++++ 6 files changed, 84 insertions(+), 35 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 9c5ae54b73..e87ea797ed 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -19,13 +19,11 @@ import org.opensearch.action.ActionResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; -import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.IndexScopedSettings; -import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.SettingsFilter; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -115,14 +113,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private ClusterService clusterService; private ThreadPool threadPool; - public static final Setting IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope); - - public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "l") { - @Override - public Setting legacySetting() { - return IS_ML_NODE_SETTING; - } - }; + public static final String ML_ROLE_NAME = "ml"; @Override public List> getActions() { diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index eb45428314..1ae94bcbb2 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -6,6 +6,7 @@ package org.opensearch.ml.task; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -23,6 +24,7 @@ import org.opensearch.ml.action.stats.MLStatsNodesAction; import org.opensearch.ml.action.stats.MLStatsNodesRequest; import org.opensearch.ml.stats.MLNodeLevelStat; +import org.opensearch.ml.utils.MLNodeUtils; import com.google.common.collect.ImmutableSet; @@ -49,9 +51,7 @@ public MLTaskDispatcher(ClusterService clusterService, Client client) { * @param listener Action listener */ public void dispatchTask(ActionListener listener) { - // todo: add ML node type setting check - // DiscoveryNode[] mlNodes = getEligibleMLNodes(); - DiscoveryNode[] mlNodes = getEligibleDataNodes(); + DiscoveryNode[] mlNodes = getEligibleNodes(); MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(mlNodes); MLStatsNodesRequest .addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE)); @@ -107,14 +107,32 @@ public void dispatchTask(ActionListener listener) { })); } - private DiscoveryNode[] getEligibleDataNodes() { + /** + * Get eligible node to run ML task. If there are nodes with ml role, will return all these + * ml nodes; otherwise return all data nodes. + * + * @return array of discovery node + */ + protected DiscoveryNode[] getEligibleNodes() { ClusterState state = this.clusterService.state(); + final List eligibleMLNodes = new ArrayList<>(); final List eligibleDataNodes = new ArrayList<>(); for (DiscoveryNode node : state.nodes()) { + if (MLNodeUtils.isMLNode(node)) { + eligibleMLNodes.add(node); + } if (node.isDataNode()) { eligibleDataNodes.add(node); } } - return eligibleDataNodes.toArray(new DiscoveryNode[0]); + if (eligibleMLNodes.size() > 0) { + DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]); + log.debug("Find {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes)); + return mlNodes; + } else { + DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]); + log.debug("Find no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes)); + return dataNodes; + } } } diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index 54c7fc5b7b..d5d23b480a 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -5,6 +5,8 @@ package org.opensearch.ml.utils; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME; + import java.io.IOException; import lombok.experimental.UtilityClass; @@ -12,12 +14,11 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.*; -import org.opensearch.ml.plugin.MachineLearningPlugin; @UtilityClass public class MLNodeUtils { public boolean isMLNode(DiscoveryNode node) { - return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(MachineLearningPlugin.ML_ROLE.roleName())); + return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(ML_ROLE_NAME)); } public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference) diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java index c8a3b1f82a..423369bc2e 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskDispatcherTests.java @@ -7,11 +7,12 @@ import static org.mockito.Mockito.*; import static org.opensearch.ml.common.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME; +import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.Set; @@ -35,6 +36,8 @@ import org.opensearch.ml.stats.MLNodeLevelStat; import org.opensearch.test.OpenSearchTestCase; +import com.google.common.collect.ImmutableSet; + public class MLTaskDispatcherTests extends OpenSearchTestCase { @Mock @@ -48,8 +51,9 @@ public class MLTaskDispatcherTests extends OpenSearchTestCase { MLTaskDispatcher taskDispatcher; ClusterState testState; - DiscoveryNode node1; - DiscoveryNode node2; + DiscoveryNode dataNode1; + DiscoveryNode dataNode2; + DiscoveryNode mlNode; MLStatsNodesResponse mlStatsNodesResponse; String clusterName = "test cluster"; @@ -59,11 +63,12 @@ public void setup() { taskDispatcher = spy(new MLTaskDispatcher(clusterService, client)); - Set roleSet = new HashSet<>(); - roleSet.add(DiscoveryNodeRole.DATA_ROLE); - node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT); - node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT); - DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build(); + Set dataRoleSet = ImmutableSet.of(DiscoveryNodeRole.DATA_ROLE); + dataNode1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT); + dataNode2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT); + Set mlRoleSet = ImmutableSet.of(ML_ROLE); + mlNode = new DiscoveryNode("mlNode", buildNewFakeTransportAddress(), new HashMap<>(), mlRoleSet, Version.CURRENT); + DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).build(); testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false); when(clusterService.state()).thenReturn(testState); @@ -111,12 +116,34 @@ public void testDispatchTask_TaskCountExceedLimit() { assertEquals(errorMessage, argumentCaptor.getValue().getMessage()); } + public void testGetEligibleNodes_DataNodeOnly() { + DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes(); + assertEquals(2, eligibleNodes.length); + for (DiscoveryNode node : eligibleNodes) { + assertTrue(node.isDataNode()); + } + } + + public void testGetEligibleNodes_MlAndDataNodes() { + DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).add(mlNode).build(); + testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false); + when(clusterService.state()).thenReturn(testState); + + DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes(); + assertEquals(1, eligibleNodes.length); + for (DiscoveryNode node : eligibleNodes) { + assertFalse(node.isDataNode()); + DiscoveryNodeRole[] discoveryNodeRoles = node.getRoles().toArray(new DiscoveryNodeRole[0]); + assertEquals(ML_ROLE_NAME, discoveryNodeRoles[0].roleName()); + } + } + private MLStatsNodesResponse getMlStatsNodesResponse() { Map nodeStats = new HashMap<>(); nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); - MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats); - MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( new ClusterName(clusterName), Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2), @@ -127,8 +154,8 @@ private MLStatsNodesResponse getMlStatsNodesResponse() { private MLStatsNodesResponse getNodesResponse_NoTaskCounts() { Map nodeStats = new HashMap<>(); nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); - MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats); - MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( new ClusterName(clusterName), Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2), @@ -140,8 +167,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() { Map nodeStats = new HashMap<>(); nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l); nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l); - MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats); - MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( new ClusterName(clusterName), Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2), @@ -153,8 +180,8 @@ private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() { Map nodeStats = new HashMap<>(); nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l); nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l); - MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats); - MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats); + MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats); return new MLStatsNodesResponse( new ClusterName(clusterName), Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2), diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index d5e35a4f6a..5cb3b76760 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -6,6 +6,7 @@ package org.opensearch.ml.utils; import static java.util.Collections.emptyMap; +import static org.opensearch.ml.utils.TestHelper.ML_ROLE; import java.io.IOException; import java.util.HashSet; @@ -22,7 +23,6 @@ import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentParser; import org.opensearch.ml.common.MLTask; -import org.opensearch.ml.plugin.MachineLearningPlugin; import org.opensearch.test.OpenSearchTestCase; public class MLNodeUtilsTests extends OpenSearchTestCase { @@ -34,7 +34,7 @@ public void testIsMLNode() { DiscoveryNode normalNode = new DiscoveryNode("Normal node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT); Assert.assertFalse(MLNodeUtils.isMLNode(normalNode)); - roleSet.add(MachineLearningPlugin.ML_ROLE); + roleSet.add(ML_ROLE); DiscoveryNode mlNode = new DiscoveryNode("ML node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT); Assert.assertTrue(MLNodeUtils.isMLNode(mlNode)); } diff --git a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java index 54b7abd833..329cff997d 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java @@ -30,8 +30,10 @@ import org.opensearch.client.Response; import org.opensearch.client.RestClient; import org.opensearch.client.WarningsHandler; +import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.NamedXContentRegistry; @@ -56,6 +58,16 @@ import com.google.common.collect.ImmutableMap; public class TestHelper { + + public static final Setting IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope); + + public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "ml") { + @Override + public Setting legacySetting() { + return IS_ML_NODE_SETTING; + } + }; + public static XContentParser parser(String xc) throws IOException { return parser(xc, true); }