From 894f98345e27a0bceae3183492a3ccad17465be6 Mon Sep 17 00:00:00 2001 From: Yaliang Wu Date: Thu, 16 Mar 2023 12:11:22 -0700 Subject: [PATCH] set model state as partially loaded if unload model from partial nodes (#806) Signed-off-by: Yaliang Wu --- .../unload/UnloadModelNodeResponse.java | 15 +++++- .../unload/UnloadModelNodeResponseTest.java | 16 +++--- .../unload/UnloadModelNodesResponseTest.java | 17 ++++-- .../unload/TransportUnloadModelAction.java | 28 +++++++++- .../TransportUnloadModelActionTests.java | 52 ++++++++++++++----- 5 files changed, 98 insertions(+), 30 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponse.java b/common/src/main/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponse.java index d2aa22ddbe..c02e2bbd73 100644 --- a/common/src/main/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponse.java +++ b/common/src/main/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponse.java @@ -20,10 +20,14 @@ public class UnloadModelNodeResponse extends BaseNodeResponse implements ToXContentFragment { private Map modelUnloadStatus; + private Map modelWorkerNodeCounts; - public UnloadModelNodeResponse(DiscoveryNode node, Map modelUnloadStatus) { + public UnloadModelNodeResponse(DiscoveryNode node, + Map modelUnloadStatus, + Map modelWorkerNodeCounts) { super(node); this.modelUnloadStatus = modelUnloadStatus; + this.modelWorkerNodeCounts = modelWorkerNodeCounts; } public UnloadModelNodeResponse(StreamInput in) throws IOException { @@ -31,6 +35,9 @@ public UnloadModelNodeResponse(StreamInput in) throws IOException { if (in.readBoolean()) { this.modelUnloadStatus = in.readMap(s -> s.readString(), s-> s.readString()); } + if (in.readBoolean()) { + this.modelWorkerNodeCounts = in.readMap(s -> s.readString(), s-> s.readInt()); + } } public static UnloadModelNodeResponse readStats(StreamInput in) throws IOException { @@ -47,6 +54,12 @@ public void writeTo(StreamOutput out) throws IOException { } else { out.writeBoolean(false); } + if (modelWorkerNodeCounts != null) { + out.writeBoolean(true); + out.writeMap(modelWorkerNodeCounts, StreamOutput::writeString, StreamOutput::writeInt); + } else { + out.writeBoolean(false); + } } public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { diff --git a/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponseTest.java index 7dc612c346..9eba1414fb 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodeResponseTest.java @@ -8,12 +8,8 @@ import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.transport.TransportAddress; -import org.opensearch.common.xcontent.XContentBuilder; -import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.InetAddress; import java.util.Collections; @@ -29,6 +25,8 @@ public class UnloadModelNodeResponseTest { @Mock private DiscoveryNode localNode; + private Map modelWorkerNodeCounts; + @Before public void setUp() throws Exception { localNode = new DiscoveryNode( @@ -39,13 +37,15 @@ public void setUp() throws Exception { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); + modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", 1); } @Test public void testSerializationDeserialization() throws IOException { Map modelToLoadStatus = new HashMap<>(); - modelToLoadStatus.put("modelName:version", "response"); - UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + modelToLoadStatus.put("modelId1", "response"); + UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput()); @@ -54,7 +54,7 @@ public void testSerializationDeserialization() throws IOException { @Test public void testSerializationDeserialization_NullModelLoadStatus() throws IOException { - UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null); + UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null, null); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput()); @@ -63,7 +63,7 @@ public void testSerializationDeserialization_NullModelLoadStatus() throws IOExce @Test public void testReadProfile() throws IOException { - UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>()); + UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>(), new HashMap<>()); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); UnloadModelNodeResponse newResponse = UnloadModelNodeResponse.readStats(output.bytes().streamInput()); diff --git a/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodesResponseTest.java b/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodesResponseTest.java index 257a822808..42d73dd4b9 100644 --- a/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodesResponseTest.java +++ b/common/src/test/java/org/opensearch/ml/common/transport/unload/UnloadModelNodesResponseTest.java @@ -31,6 +31,7 @@ public class UnloadModelNodesResponseTest { private ClusterName clusterName; private DiscoveryNode node1; private DiscoveryNode node2; + private Map modelWorkerNodeCounts; @Before public void setUp() throws Exception { @@ -51,6 +52,8 @@ public void setUp() throws Exception { Collections.singleton(CLUSTER_MANAGER_ROLE), Version.CURRENT ); + modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", 1); } @Test @@ -69,12 +72,16 @@ public void testToXContent() throws IOException { List nodes = new ArrayList<>(); Map modelToUnloadStatus1 = new HashMap<>(); - modelToUnloadStatus1.put("modelName:version1", "response"); - nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1)); + modelToUnloadStatus1.put("modelId1", "response"); + Map modelWorkerNodeCounts1 = new HashMap<>(); + modelWorkerNodeCounts1.put("modelId1", 1); + nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1, modelWorkerNodeCounts1)); Map modelToUnloadStatus2 = new HashMap<>(); - modelToUnloadStatus2.put("modelName:version2", "response"); - nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2)); + modelToUnloadStatus2.put("modelId2", "response"); + Map modelWorkerNodeCounts2 = new HashMap<>(); + modelWorkerNodeCounts2.put("modelId2", 2); + nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2, modelWorkerNodeCounts2)); List failures = new ArrayList<>(); UnloadModelNodesResponse response = new UnloadModelNodesResponse(clusterName, nodes, failures); @@ -82,7 +89,7 @@ public void testToXContent() throws IOException { response.toXContent(builder, ToXContent.EMPTY_PARAMS); String jsonStr = Strings.toString(builder); assertEquals( - "{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}", + "{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}", jsonStr ); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java b/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java index d23cc133ac..40250680c0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/unload/TransportUnloadModelAction.java @@ -97,8 +97,19 @@ protected UnloadModelNodesResponse newResponse( ) { if (responses != null) { Map> removedNodeMap = new HashMap<>(); + Map modelWorkNodeCounts = new HashMap<>(); responses.stream().forEach(r -> { Set notFoundModels = new HashSet<>(); + Map nodeCounts = r.getModelWorkerNodeCounts(); + if (nodeCounts != null) { + for (Map.Entry entry : nodeCounts.entrySet()) { + if (!modelWorkNodeCounts.containsKey(entry.getKey()) + || modelWorkNodeCounts.get(entry.getKey()) < entry.getValue()) { + modelWorkNodeCounts.put(entry.getKey(), entry.getValue()); + } + } + } + Map modelUnloadStatus = r.getModelUnloadStatus(); for (Map.Entry entry : modelUnloadStatus.entrySet()) { String status = entry.getValue(); @@ -128,7 +139,11 @@ protected UnloadModelNodesResponse newResponse( BulkRequest bulkRequest = new BulkRequest(); for (String modelId : removedNodeMap.keySet()) { UpdateRequest updateRequest = new UpdateRequest(); - updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, MLModelState.UNLOADED)); + int removedNodeCount = removedNodeMap.get(modelId).size(); + MLModelState mlModelState = modelWorkNodeCounts.get(modelId) > removedNodeCount + ? MLModelState.PARTIALLY_LOADED + : MLModelState.UNLOADED; + updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, mlModelState)); bulkRequest.add(updateRequest); } ActionListener actionListenr = ActionListener @@ -181,8 +196,17 @@ private UnloadModelNodeResponse createUnloadModelNodeResponse(UnloadModelNodesRe mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment(); String[] modelIds = unloadModelNodesRequest.getModelIds(); + + Map modelWorkerNodeCounts = new HashMap<>(); + if (modelIds != null) { + for (String modelId : modelIds) { + String[] workerNodes = mlModelManager.getWorkerNodes(modelId); + modelWorkerNodeCounts.put(modelId, workerNodes == null ? 0 : workerNodes.length); + } + } + Map modelUnloadStatus = mlModelManager.unloadModel(modelIds); mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement(); - return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus); + return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus, modelWorkerNodeCounts); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java index eb94c2492f..b3996ba5a8 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/unload/TransportUnloadModelActionTests.java @@ -9,8 +9,11 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE; +import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX; import java.io.IOException; import java.net.InetAddress; @@ -18,14 +21,18 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.concurrent.ExecutorService; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.bulk.BulkRequest; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; @@ -35,6 +42,8 @@ import org.opensearch.common.transport.TransportAddress; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.ml.cluster.DiscoveryNodeHelper; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.model.MLModelState; import org.opensearch.ml.common.transport.unload.UnloadModelNodeRequest; import org.opensearch.ml.common.transport.unload.UnloadModelNodeResponse; import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest; @@ -130,9 +139,11 @@ public void testNewNodeRequest() { } public void testNewNodeStreamRequest() throws IOException { - java.util.Map modelToLoadStatus = new HashMap<>(); - modelToLoadStatus.put("modelName:version", "response"); - UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + Map modelToLoadStatus = new HashMap<>(); + Map modelWorkerNodeCounts = new HashMap<>(); + modelToLoadStatus.put("modelId1", "response"); + modelWorkerNodeCounts.put("modelId1", 1); + UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); final UnloadModelNodeResponse unLoadResponse = action.newNodeResponse(output.bytes().streamInput()); @@ -156,17 +167,23 @@ public void testNewResponseWithUnloadedModelStatus() { new String[] { "modelId1", "modelId2" } ); final List responses = new ArrayList<>(); - java.util.Map modelToLoadStatus = new HashMap<>(); - modelToLoadStatus.put("modelName:version", "unloaded"); - UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); - modelToLoadStatus.put("modelName:version", "unloaded"); - UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + Map modelToLoadStatus = new HashMap<>(); + modelToLoadStatus.put("modelId1", "unloaded"); + Map modelWorkerNodeCounts = new HashMap<>(); + modelWorkerNodeCounts.put("modelId1", 1); + UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); + UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); responses.add(response1); responses.add(response2); final List failures = new ArrayList<>(); final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); assertNotNull(response); - + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(argumentCaptor.capture(), any()); + UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + Map updateContent = updateRequest.doc().sourceAsMap(); + assertEquals(MLModelState.UNLOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD)); } public void testNewResponseWithNotFoundModelStatus() { @@ -175,15 +192,22 @@ public void testNewResponseWithNotFoundModelStatus() { new String[] { "modelId1", "modelId2" } ); final List responses = new ArrayList<>(); - java.util.Map modelToLoadStatus = new HashMap<>(); - modelToLoadStatus.put("modelName:version", "not_found"); - UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); - modelToLoadStatus.put("modelName:version", "not_found"); - UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus); + Map modelToLoadStatus = new HashMap<>(); + Map modelWorkerNodeCounts = new HashMap<>(); + modelToLoadStatus.put("modelId1", "not_found"); + modelWorkerNodeCounts.put("modelId1", 2); + UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); + UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts); responses.add(response1); responses.add(response2); final List failures = new ArrayList<>(); final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures); assertNotNull(response); + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class); + verify(client, times(1)).bulk(argumentCaptor.capture(), any()); + UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0); + assertEquals(ML_MODEL_INDEX, updateRequest.index()); + Map updateContent = updateRequest.doc().sourceAsMap(); + assertEquals(MLModelState.PARTIALLY_LOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD)); } }