diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java index 69047231a2..20f5764334 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskDispatcher.java @@ -9,6 +9,7 @@ import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -54,7 +55,7 @@ public MLTaskDispatcher(ClusterService clusterService, Client client) { public void dispatchTask(ActionListener listener) { // todo: add ML node type setting check // DiscoveryNode[] mlNodes = getEligibleMLNodes(); - DiscoveryNode[] mlNodes = getEligibleDataNodes(); + DiscoveryNode[] mlNodes = getEligibleNodes(); MLStatsNodesRequest MLStatsNodesRequest = new MLStatsNodesRequest(mlNodes); MLStatsNodesRequest.addAll(ImmutableSet.of(ML_EXECUTING_TASK_COUNT, JVM_HEAP_USAGE.getName())); @@ -120,14 +121,26 @@ private DiscoveryNode[] getEligibleMLNodes() { return eligibleNodes.toArray(new DiscoveryNode[0]); } - private DiscoveryNode[] getEligibleDataNodes() { + private DiscoveryNode[] getEligibleNodes() { ClusterState state = this.clusterService.state(); + final List eligibleMLNodes = new ArrayList<>(); final List eligibleDataNodes = new ArrayList<>(); for (DiscoveryNode node : state.nodes()) { + if (MLNodeUtils.isMLNode(node)) { + eligibleMLNodes.add(node); + } if (node.isDataNode()) { eligibleDataNodes.add(node); } } - return eligibleDataNodes.toArray(new DiscoveryNode[0]); + if (eligibleMLNodes.size() > 0) { + DiscoveryNode[] mlNodes = eligibleMLNodes.toArray(new DiscoveryNode[0]); + log.info("We have {} dedicated ML nodes: {}", eligibleMLNodes.size(), Arrays.toString(mlNodes)); + return mlNodes; + } else { + DiscoveryNode[] dataNodes = eligibleDataNodes.toArray(new DiscoveryNode[0]); + log.info("We have no dedicated ML nodes. But have {} data nodes: {}", eligibleDataNodes.size(), Arrays.toString(dataNodes)); + return dataNodes; + } } }