Skip to content

Commit

Permalink
dispatch ML task to ML node first (#346) (#347)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit 6cbb626)

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Jun 17, 2022
1 parent 0573bb7 commit b04caff
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
import org.opensearch.action.ActionResponse;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.io.stream.NamedWriteableRegistry;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.IndexScopedSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.settings.SettingsFilter;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -115,14 +113,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
private ClusterService clusterService;
private ThreadPool threadPool;

public static final Setting<Boolean> IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope);

public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "l") {
@Override
public Setting<Boolean> legacySetting() {
return IS_ML_NODE_SETTING;
}
};
public static final String ML_ROLE_NAME = "ml";

@Override
public List<ActionHandler<? extends ActionRequest, ? extends ActionResponse>> getActions() {
Expand Down
28 changes: 23 additions & 5 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.task;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
Expand All @@ -23,6 +24,7 @@
import org.opensearch.ml.action.stats.MLStatsNodesAction;
import org.opensearch.ml.action.stats.MLStatsNodesRequest;
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.ml.utils.MLNodeUtils;

import com.google.common.collect.ImmutableSet;

Expand All @@ -49,9 +51,7 @@ public MLTaskDispatcher(ClusterService clusterService, Client client) {
* @param listener Action listener
*/
public void dispatchTask(ActionListener<DiscoveryNode> listener) {
// todo: add ML node type setting check
// DiscoveryNode[] mlNodes = getEligibleMLNodes();
DiscoveryNode[] mlNodes = getEligibleDataNodes();
DiscoveryNode[] mlNodes = getEligibleNodes();
MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(mlNodes);
MLStatsNodesRequest
.addNodeLevelStats(ImmutableSet.of(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE));
Expand Down Expand Up @@ -107,14 +107,32 @@ public void dispatchTask(ActionListener<DiscoveryNode> listener) {
}));
}

private DiscoveryNode[] getEligibleDataNodes() {
/**
* Get eligible node to run ML task. If there are nodes with ml role, will return all these
* ml nodes; otherwise return all data nodes.
*
* @return array of discovery node
*/
protected DiscoveryNode[] getEligibleNodes() {
ClusterState state = this.clusterService.state();
final List<DiscoveryNode> eligibleMLNodes = new ArrayList<>();
final List<DiscoveryNode> eligibleDataNodes = new ArrayList<>();
for (DiscoveryNode node : state.nodes()) {
if (MLNodeUtils.isMLNode(node)) {
eligibleMLNodes.add(node);
}
if (node.isDataNode()) {
eligibleDataNodes.add(node);
}
}
return eligibleDataNodes.toArray(new DiscoveryNode[0]);
if (eligibleMLNodes.size() > 0) {
DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]);
log.debug("Find {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes));
return mlNodes;
} else {
DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]);
log.debug("Find no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes));
return dataNodes;
}
}
}
5 changes: 3 additions & 2 deletions plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,20 @@

package org.opensearch.ml.utils;

import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME;

import java.io.IOException;

import lombok.experimental.UtilityClass;

import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.*;
import org.opensearch.ml.plugin.MachineLearningPlugin;

@UtilityClass
public class MLNodeUtils {
public boolean isMLNode(DiscoveryNode node) {
return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(MachineLearningPlugin.ML_ROLE.roleName()));
return node.getRoles().stream().anyMatch(role -> role.roleName().equalsIgnoreCase(ML_ROLE_NAME));
}

public static XContentParser createXContentParserFromRegistry(NamedXContentRegistry xContentRegistry, BytesReference bytesReference)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@

import static org.mockito.Mockito.*;
import static org.opensearch.ml.common.breaker.MemoryCircuitBreaker.DEFAULT_JVM_HEAP_USAGE_THRESHOLD;
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_ROLE_NAME;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

Expand All @@ -35,6 +36,8 @@
import org.opensearch.ml.stats.MLNodeLevelStat;
import org.opensearch.test.OpenSearchTestCase;

import com.google.common.collect.ImmutableSet;

public class MLTaskDispatcherTests extends OpenSearchTestCase {

@Mock
Expand All @@ -48,8 +51,9 @@ public class MLTaskDispatcherTests extends OpenSearchTestCase {

MLTaskDispatcher taskDispatcher;
ClusterState testState;
DiscoveryNode node1;
DiscoveryNode node2;
DiscoveryNode dataNode1;
DiscoveryNode dataNode2;
DiscoveryNode mlNode;
MLStatsNodesResponse mlStatsNodesResponse;
String clusterName = "test cluster";

Expand All @@ -59,11 +63,12 @@ public void setup() {

taskDispatcher = spy(new MLTaskDispatcher(clusterService, client));

Set<DiscoveryNodeRole> roleSet = new HashSet<>();
roleSet.add(DiscoveryNodeRole.DATA_ROLE);
node1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT);
node2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), roleSet, Version.CURRENT);
DiscoveryNodes nodes = DiscoveryNodes.builder().add(node1).add(node2).build();
Set<DiscoveryNodeRole> dataRoleSet = ImmutableSet.of(DiscoveryNodeRole.DATA_ROLE);
dataNode1 = new DiscoveryNode("node1", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT);
dataNode2 = new DiscoveryNode("node2", buildNewFakeTransportAddress(), new HashMap<>(), dataRoleSet, Version.CURRENT);
Set<DiscoveryNodeRole> mlRoleSet = ImmutableSet.of(ML_ROLE);
mlNode = new DiscoveryNode("mlNode", buildNewFakeTransportAddress(), new HashMap<>(), mlRoleSet, Version.CURRENT);
DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).build();
testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false);
when(clusterService.state()).thenReturn(testState);

Expand Down Expand Up @@ -111,12 +116,34 @@ public void testDispatchTask_TaskCountExceedLimit() {
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

public void testGetEligibleNodes_DataNodeOnly() {
DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes();
assertEquals(2, eligibleNodes.length);
for (DiscoveryNode node : eligibleNodes) {
assertTrue(node.isDataNode());
}
}

public void testGetEligibleNodes_MlAndDataNodes() {
DiscoveryNodes nodes = DiscoveryNodes.builder().add(dataNode1).add(dataNode2).add(mlNode).build();
testState = new ClusterState(new ClusterName(clusterName), 123l, "111111", null, null, nodes, null, null, 0, false);
when(clusterService.state()).thenReturn(testState);

DiscoveryNode[] eligibleNodes = taskDispatcher.getEligibleNodes();
assertEquals(1, eligibleNodes.length);
for (DiscoveryNode node : eligibleNodes) {
assertFalse(node.isDataNode());
DiscoveryNodeRole[] discoveryNodeRoles = node.getRoles().toArray(new DiscoveryNodeRole[0]);
assertEquals(ML_ROLE_NAME, discoveryNodeRoles[0].roleName());
}
}

private MLStatsNodesResponse getMlStatsNodesResponse() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -127,8 +154,8 @@ private MLStatsNodesResponse getMlStatsNodesResponse() {
private MLStatsNodesResponse getNodesResponse_NoTaskCounts() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -140,8 +167,8 @@ private MLStatsNodesResponse getNodesResponse_MemoryExceedLimits() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 90l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 5l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand All @@ -153,8 +180,8 @@ private MLStatsNodesResponse getNodesResponse_TaskCountExceedLimits() {
Map<MLNodeLevelStat, Object> nodeStats = new HashMap<>();
nodeStats.put(MLNodeLevelStat.ML_NODE_JVM_HEAP_USAGE, 50l);
nodeStats.put(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT, 15l);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(node1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse1 = new MLStatsNodeResponse(dataNode1, nodeStats);
MLStatsNodeResponse mlStatsNodeResponse2 = new MLStatsNodeResponse(dataNode1, nodeStats);
return new MLStatsNodesResponse(
new ClusterName(clusterName),
Arrays.asList(mlStatsNodeResponse1, mlStatsNodeResponse2),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.utils;

import static java.util.Collections.emptyMap;
import static org.opensearch.ml.utils.TestHelper.ML_ROLE;

import java.io.IOException;
import java.util.HashSet;
Expand All @@ -22,7 +23,6 @@
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.plugin.MachineLearningPlugin;
import org.opensearch.test.OpenSearchTestCase;

public class MLNodeUtilsTests extends OpenSearchTestCase {
Expand All @@ -34,7 +34,7 @@ public void testIsMLNode() {
DiscoveryNode normalNode = new DiscoveryNode("Normal node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT);
Assert.assertFalse(MLNodeUtils.isMLNode(normalNode));

roleSet.add(MachineLearningPlugin.ML_ROLE);
roleSet.add(ML_ROLE);
DiscoveryNode mlNode = new DiscoveryNode("ML node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT);
Assert.assertTrue(MLNodeUtils.isMLNode(mlNode));
}
Expand Down
12 changes: 12 additions & 0 deletions plugin/src/test/java/org/opensearch/ml/utils/TestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;
import org.opensearch.client.WarningsHandler;
import org.opensearch.cluster.node.DiscoveryNodeRole;
import org.opensearch.common.bytes.BytesArray;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.LoggingDeprecationHandler;
import org.opensearch.common.xcontent.NamedXContentRegistry;
Expand All @@ -56,6 +58,16 @@
import com.google.common.collect.ImmutableMap;

public class TestHelper {

public static final Setting<Boolean> IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope);

public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "ml") {
@Override
public Setting<Boolean> legacySetting() {
return IS_ML_NODE_SETTING;
}
};

public static XContentParser parser(String xc) throws IOException {
return parser(xc, true);
}
Expand Down

0 comments on commit b04caff

Please sign in to comment.