Skip to content

Commit

Permalink
set model state as partially loaded if unload model from partial nodes (
Browse files Browse the repository at this point in the history
#806)

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Mar 16, 2023
1 parent 2fae22d commit 894f983
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,24 @@
public class UnloadModelNodeResponse extends BaseNodeResponse implements ToXContentFragment {

private Map<String, String> modelUnloadStatus;
private Map<String, Integer> modelWorkerNodeCounts;

public UnloadModelNodeResponse(DiscoveryNode node, Map<String, String> modelUnloadStatus) {
public UnloadModelNodeResponse(DiscoveryNode node,
Map<String, String> modelUnloadStatus,
Map<String, Integer> modelWorkerNodeCounts) {
super(node);
this.modelUnloadStatus = modelUnloadStatus;
this.modelWorkerNodeCounts = modelWorkerNodeCounts;
}

public UnloadModelNodeResponse(StreamInput in) throws IOException {
super(in);
if (in.readBoolean()) {
this.modelUnloadStatus = in.readMap(s -> s.readString(), s-> s.readString());
}
if (in.readBoolean()) {
this.modelWorkerNodeCounts = in.readMap(s -> s.readString(), s-> s.readInt());
}
}

public static UnloadModelNodeResponse readStats(StreamInput in) throws IOException {
Expand All @@ -47,6 +54,12 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (modelWorkerNodeCounts != null) {
out.writeBoolean(true);
out.writeMap(modelWorkerNodeCounts, StreamOutput::writeString, StreamOutput::writeInt);
} else {
out.writeBoolean(false);
}
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.xcontent.XContentBuilder;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
Expand All @@ -29,6 +25,8 @@ public class UnloadModelNodeResponseTest {
@Mock
private DiscoveryNode localNode;

private Map<String, Integer> modelWorkerNodeCounts;

@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
Expand All @@ -39,13 +37,15 @@ public void setUp() throws Exception {
Collections.singleton(CLUSTER_MANAGER_ROLE),
Version.CURRENT
);
modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
}

@Test
public void testSerializationDeserialization() throws IOException {
Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelId1", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput());
Expand All @@ -54,7 +54,7 @@ public void testSerializationDeserialization() throws IOException {

@Test
public void testSerializationDeserialization_NullModelLoadStatus() throws IOException {
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null);
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null, null);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput());
Expand All @@ -63,7 +63,7 @@ public void testSerializationDeserialization_NullModelLoadStatus() throws IOExce

@Test
public void testReadProfile() throws IOException {
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>());
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>(), new HashMap<>());
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = UnloadModelNodeResponse.readStats(output.bytes().streamInput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class UnloadModelNodesResponseTest {
private ClusterName clusterName;
private DiscoveryNode node1;
private DiscoveryNode node2;
private Map<String, Integer> modelWorkerNodeCounts;

@Before
public void setUp() throws Exception {
Expand All @@ -51,6 +52,8 @@ public void setUp() throws Exception {
Collections.singleton(CLUSTER_MANAGER_ROLE),
Version.CURRENT
);
modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
}

@Test
Expand All @@ -69,20 +72,24 @@ public void testToXContent() throws IOException {
List<UnloadModelNodeResponse> nodes = new ArrayList<>();

Map<String, String> modelToUnloadStatus1 = new HashMap<>();
modelToUnloadStatus1.put("modelName:version1", "response");
nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1));
modelToUnloadStatus1.put("modelId1", "response");
Map<String, Integer> modelWorkerNodeCounts1 = new HashMap<>();
modelWorkerNodeCounts1.put("modelId1", 1);
nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1, modelWorkerNodeCounts1));

Map<String, String> modelToUnloadStatus2 = new HashMap<>();
modelToUnloadStatus2.put("modelName:version2", "response");
nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2));
modelToUnloadStatus2.put("modelId2", "response");
Map<String, Integer> modelWorkerNodeCounts2 = new HashMap<>();
modelWorkerNodeCounts2.put("modelId2", 2);
nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2, modelWorkerNodeCounts2));

List<FailedNodeException> failures = new ArrayList<>();
UnloadModelNodesResponse response = new UnloadModelNodesResponse(clusterName, nodes, failures);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = Strings.toString(builder);
assertEquals(
"{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}",
"{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}",
jsonStr
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,19 @@ protected UnloadModelNodesResponse newResponse(
) {
if (responses != null) {
Map<String, List<String>> removedNodeMap = new HashMap<>();
Map<String, Integer> modelWorkNodeCounts = new HashMap<>();
responses.stream().forEach(r -> {
Set<String> notFoundModels = new HashSet<>();
Map<String, Integer> nodeCounts = r.getModelWorkerNodeCounts();
if (nodeCounts != null) {
for (Map.Entry<String, Integer> entry : nodeCounts.entrySet()) {
if (!modelWorkNodeCounts.containsKey(entry.getKey())
|| modelWorkNodeCounts.get(entry.getKey()) < entry.getValue()) {
modelWorkNodeCounts.put(entry.getKey(), entry.getValue());
}
}
}

Map<String, String> modelUnloadStatus = r.getModelUnloadStatus();
for (Map.Entry<String, String> entry : modelUnloadStatus.entrySet()) {
String status = entry.getValue();
Expand Down Expand Up @@ -128,7 +139,11 @@ protected UnloadModelNodesResponse newResponse(
BulkRequest bulkRequest = new BulkRequest();
for (String modelId : removedNodeMap.keySet()) {
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, MLModelState.UNLOADED));
int removedNodeCount = removedNodeMap.get(modelId).size();
MLModelState mlModelState = modelWorkNodeCounts.get(modelId) > removedNodeCount
? MLModelState.PARTIALLY_LOADED
: MLModelState.UNLOADED;
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, mlModelState));
bulkRequest.add(updateRequest);
}
ActionListener<BulkResponse> actionListenr = ActionListener
Expand Down Expand Up @@ -181,8 +196,17 @@ private UnloadModelNodeResponse createUnloadModelNodeResponse(UnloadModelNodesRe
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

String[] modelIds = unloadModelNodesRequest.getModelIds();

Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
if (modelIds != null) {
for (String modelId : modelIds) {
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
modelWorkerNodeCounts.put(modelId, workerNodes == null ? 0 : workerNodes.length);
}
}

Map<String, String> modelUnloadStatus = mlModelManager.unloadModel(modelIds);
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus);
return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus, modelWorkerNodeCounts);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
Expand All @@ -35,6 +42,8 @@
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeRequest;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeResponse;
import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest;
Expand Down Expand Up @@ -130,9 +139,11 @@ public void testNewNodeRequest() {
}

public void testNewNodeStreamRequest() throws IOException {
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelToLoadStatus.put("modelId1", "response");
modelWorkerNodeCounts.put("modelId1", 1);
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
final UnloadModelNodeResponse unLoadResponse = action.newNodeResponse(output.bytes().streamInput());
Expand All @@ -156,17 +167,23 @@ public void testNewResponseWithUnloadedModelStatus() {
new String[] { "modelId1", "modelId2" }
);
final List<UnloadModelNodeResponse> responses = new ArrayList<>();
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "unloaded");
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelName:version", "unloaded");
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelId1", "unloaded");
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
responses.add(response1);
responses.add(response2);
final List<FailedNodeException> failures = new ArrayList<>();
final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures);
assertNotNull(response);

ArgumentCaptor<BulkRequest> argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class);
verify(client, times(1)).bulk(argumentCaptor.capture(), any());
UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0);
assertEquals(ML_MODEL_INDEX, updateRequest.index());
Map<String, Object> updateContent = updateRequest.doc().sourceAsMap();
assertEquals(MLModelState.UNLOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD));
}

public void testNewResponseWithNotFoundModelStatus() {
Expand All @@ -175,15 +192,22 @@ public void testNewResponseWithNotFoundModelStatus() {
new String[] { "modelId1", "modelId2" }
);
final List<UnloadModelNodeResponse> responses = new ArrayList<>();
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "not_found");
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelName:version", "not_found");
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelToLoadStatus.put("modelId1", "not_found");
modelWorkerNodeCounts.put("modelId1", 2);
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
responses.add(response1);
responses.add(response2);
final List<FailedNodeException> failures = new ArrayList<>();
final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures);
assertNotNull(response);
ArgumentCaptor<BulkRequest> argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class);
verify(client, times(1)).bulk(argumentCaptor.capture(), any());
UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0);
assertEquals(ML_MODEL_INDEX, updateRequest.index());
Map<String, Object> updateContent = updateRequest.doc().sourceAsMap();
assertEquals(MLModelState.PARTIALLY_LOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD));
}
}

0 comments on commit 894f983

Please sign in to comment.