Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set model state to partially loaded when a model is unloaded from only some of its nodes #806

Merged
merged 1 commit into from
Mar 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,24 @@
public class UnloadModelNodeResponse extends BaseNodeResponse implements ToXContentFragment {

private Map<String, String> modelUnloadStatus;
private Map<String, Integer> modelWorkerNodeCounts;

public UnloadModelNodeResponse(DiscoveryNode node, Map<String, String> modelUnloadStatus) {
public UnloadModelNodeResponse(DiscoveryNode node,
Map<String, String> modelUnloadStatus,
Map<String, Integer> modelWorkerNodeCounts) {
super(node);
this.modelUnloadStatus = modelUnloadStatus;
this.modelWorkerNodeCounts = modelWorkerNodeCounts;
}

/**
 * Deserializes a node-level unload response from a transport stream.
 *
 * Each map is preceded on the wire by a boolean presence flag; a map is read
 * only when its flag is true, otherwise the corresponding field keeps its
 * default (unassigned) value. The read order here must match the order in
 * which writeTo emits the flags and maps.
 *
 * @param in stream positioned at the start of a serialized response
 * @throws IOException if reading from the stream fails
 */
public UnloadModelNodeResponse(StreamInput in) throws IOException {
super(in);
// Presence flag, then the String -> String unload-status map.
if (in.readBoolean()) {
this.modelUnloadStatus = in.readMap(s -> s.readString(), s-> s.readString());
}
// Presence flag, then the String -> int worker-node-count map.
// NOTE(review): this flag is read unconditionally and no stream-version check is
// visible here — confirm that responses from nodes that do not write this field
// (e.g. during a rolling upgrade) are handled.
if (in.readBoolean()) {
this.modelWorkerNodeCounts = in.readMap(s -> s.readString(), s-> s.readInt());
}
}

public static UnloadModelNodeResponse readStats(StreamInput in) throws IOException {
Expand All @@ -47,6 +54,12 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
if (modelWorkerNodeCounts != null) {
out.writeBoolean(true);
out.writeMap(modelWorkerNodeCounts, StreamOutput::writeString, StreamOutput::writeInt);
} else {
out.writeBoolean(false);
}
}

public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.xcontent.XContentBuilder;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.util.Collections;
Expand All @@ -29,6 +25,8 @@ public class UnloadModelNodeResponseTest {
@Mock
private DiscoveryNode localNode;

private Map<String, Integer> modelWorkerNodeCounts;

@Before
public void setUp() throws Exception {
localNode = new DiscoveryNode(
Expand All @@ -39,13 +37,15 @@ public void setUp() throws Exception {
Collections.singleton(CLUSTER_MANAGER_ROLE),
Version.CURRENT
);
modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
}

@Test
public void testSerializationDeserialization() throws IOException {
Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelId1", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput());
Expand All @@ -54,7 +54,7 @@ public void testSerializationDeserialization() throws IOException {

@Test
public void testSerializationDeserialization_NullModelLoadStatus() throws IOException {
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null);
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, null, null);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = new UnloadModelNodeResponse(output.bytes().streamInput());
Expand All @@ -63,7 +63,7 @@ public void testSerializationDeserialization_NullModelLoadStatus() throws IOExce

@Test
public void testReadProfile() throws IOException {
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>());
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, new HashMap<>(), new HashMap<>());
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
UnloadModelNodeResponse newResponse = UnloadModelNodeResponse.readStats(output.bytes().streamInput());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ public class UnloadModelNodesResponseTest {
private ClusterName clusterName;
private DiscoveryNode node1;
private DiscoveryNode node2;
private Map<String, Integer> modelWorkerNodeCounts;

@Before
public void setUp() throws Exception {
Expand All @@ -51,6 +52,8 @@ public void setUp() throws Exception {
Collections.singleton(CLUSTER_MANAGER_ROLE),
Version.CURRENT
);
modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
}

@Test
Expand All @@ -69,20 +72,24 @@ public void testToXContent() throws IOException {
List<UnloadModelNodeResponse> nodes = new ArrayList<>();

Map<String, String> modelToUnloadStatus1 = new HashMap<>();
modelToUnloadStatus1.put("modelName:version1", "response");
nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1));
modelToUnloadStatus1.put("modelId1", "response");
Map<String, Integer> modelWorkerNodeCounts1 = new HashMap<>();
modelWorkerNodeCounts1.put("modelId1", 1);
nodes.add(new UnloadModelNodeResponse(node1, modelToUnloadStatus1, modelWorkerNodeCounts1));

Map<String, String> modelToUnloadStatus2 = new HashMap<>();
modelToUnloadStatus2.put("modelName:version2", "response");
nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2));
modelToUnloadStatus2.put("modelId2", "response");
Map<String, Integer> modelWorkerNodeCounts2 = new HashMap<>();
modelWorkerNodeCounts2.put("modelId2", 2);
nodes.add(new UnloadModelNodeResponse(node2, modelToUnloadStatus2, modelWorkerNodeCounts2));

List<FailedNodeException> failures = new ArrayList<>();
UnloadModelNodesResponse response = new UnloadModelNodesResponse(clusterName, nodes, failures);
XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON);
response.toXContent(builder, ToXContent.EMPTY_PARAMS);
String jsonStr = Strings.toString(builder);
assertEquals(
"{\"foo1\":{\"stats\":{\"modelName:version1\":\"response\"}},\"foo2\":{\"stats\":{\"modelName:version2\":\"response\"}}}",
"{\"foo1\":{\"stats\":{\"modelId1\":\"response\"}},\"foo2\":{\"stats\":{\"modelId2\":\"response\"}}}",
jsonStr
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,19 @@ protected UnloadModelNodesResponse newResponse(
) {
if (responses != null) {
Map<String, List<String>> removedNodeMap = new HashMap<>();
Map<String, Integer> modelWorkNodeCounts = new HashMap<>();
responses.stream().forEach(r -> {
Set<String> notFoundModels = new HashSet<>();
Map<String, Integer> nodeCounts = r.getModelWorkerNodeCounts();
if (nodeCounts != null) {
for (Map.Entry<String, Integer> entry : nodeCounts.entrySet()) {
if (!modelWorkNodeCounts.containsKey(entry.getKey())
|| modelWorkNodeCounts.get(entry.getKey()) < entry.getValue()) {
modelWorkNodeCounts.put(entry.getKey(), entry.getValue());
}
}
}

Map<String, String> modelUnloadStatus = r.getModelUnloadStatus();
for (Map.Entry<String, String> entry : modelUnloadStatus.entrySet()) {
String status = entry.getValue();
Expand Down Expand Up @@ -128,7 +139,11 @@ protected UnloadModelNodesResponse newResponse(
BulkRequest bulkRequest = new BulkRequest();
for (String modelId : removedNodeMap.keySet()) {
UpdateRequest updateRequest = new UpdateRequest();
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, MLModelState.UNLOADED));
int removedNodeCount = removedNodeMap.get(modelId).size();
MLModelState mlModelState = modelWorkNodeCounts.get(modelId) > removedNodeCount
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't we also need to check that removedNodeCount > 0 here? For example:
`modelWorkNodeCounts.get(modelId) > removedNodeCount && removedNodeCount > 0`

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removedNodeCount will always be > 0; see line 121.

? MLModelState.PARTIALLY_LOADED
: MLModelState.UNLOADED;
updateRequest.index(ML_MODEL_INDEX).id(modelId).doc(ImmutableMap.of(MODEL_STATE_FIELD, mlModelState));
bulkRequest.add(updateRequest);
}
ActionListener<BulkResponse> actionListenr = ActionListener
Expand Down Expand Up @@ -181,8 +196,17 @@ private UnloadModelNodeResponse createUnloadModelNodeResponse(UnloadModelNodesRe
mlStats.getStat(MLNodeLevelStat.ML_NODE_TOTAL_REQUEST_COUNT).increment();

String[] modelIds = unloadModelNodesRequest.getModelIds();

Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
if (modelIds != null) {
for (String modelId : modelIds) {
String[] workerNodes = mlModelManager.getWorkerNodes(modelId);
modelWorkerNodeCounts.put(modelId, workerNodes == null ? 0 : workerNodes.length);
}
}

Map<String, String> modelUnloadStatus = mlModelManager.unloadModel(modelIds);
mlStats.getStat(MLNodeLevelStat.ML_NODE_EXECUTING_TASK_COUNT).decrement();
return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus);
return new UnloadModelNodeResponse(clusterService.localNode(), modelUnloadStatus, modelWorkerNodeCounts);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,30 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.cluster.node.DiscoveryNodeRole.CLUSTER_MANAGER_ROLE;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;

import java.io.IOException;
import java.net.InetAddress;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.Version;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.ClusterName;
import org.opensearch.cluster.node.DiscoveryNode;
Expand All @@ -35,6 +42,8 @@
import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.ml.cluster.DiscoveryNodeHelper;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.model.MLModelState;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeRequest;
import org.opensearch.ml.common.transport.unload.UnloadModelNodeResponse;
import org.opensearch.ml.common.transport.unload.UnloadModelNodesRequest;
Expand Down Expand Up @@ -130,9 +139,11 @@ public void testNewNodeRequest() {
}

public void testNewNodeStreamRequest() throws IOException {
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "response");
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelToLoadStatus.put("modelId1", "response");
modelWorkerNodeCounts.put("modelId1", 1);
UnloadModelNodeResponse response = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
BytesStreamOutput output = new BytesStreamOutput();
response.writeTo(output);
final UnloadModelNodeResponse unLoadResponse = action.newNodeResponse(output.bytes().streamInput());
Expand All @@ -156,17 +167,23 @@ public void testNewResponseWithUnloadedModelStatus() {
new String[] { "modelId1", "modelId2" }
);
final List<UnloadModelNodeResponse> responses = new ArrayList<>();
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "unloaded");
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelName:version", "unloaded");
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelId1", "unloaded");
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelWorkerNodeCounts.put("modelId1", 1);
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
responses.add(response1);
responses.add(response2);
final List<FailedNodeException> failures = new ArrayList<>();
final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures);
assertNotNull(response);

ArgumentCaptor<BulkRequest> argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class);
verify(client, times(1)).bulk(argumentCaptor.capture(), any());
UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0);
assertEquals(ML_MODEL_INDEX, updateRequest.index());
Map<String, Object> updateContent = updateRequest.doc().sourceAsMap();
assertEquals(MLModelState.UNLOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD));
}

public void testNewResponseWithNotFoundModelStatus() {
Expand All @@ -175,15 +192,22 @@ public void testNewResponseWithNotFoundModelStatus() {
new String[] { "modelId1", "modelId2" }
);
final List<UnloadModelNodeResponse> responses = new ArrayList<>();
java.util.Map<String, String> modelToLoadStatus = new HashMap<>();
modelToLoadStatus.put("modelName:version", "not_found");
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
modelToLoadStatus.put("modelName:version", "not_found");
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus);
Map<String, String> modelToLoadStatus = new HashMap<>();
Map<String, Integer> modelWorkerNodeCounts = new HashMap<>();
modelToLoadStatus.put("modelId1", "not_found");
modelWorkerNodeCounts.put("modelId1", 2);
UnloadModelNodeResponse response1 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
UnloadModelNodeResponse response2 = new UnloadModelNodeResponse(localNode, modelToLoadStatus, modelWorkerNodeCounts);
responses.add(response1);
responses.add(response2);
final List<FailedNodeException> failures = new ArrayList<>();
final UnloadModelNodesResponse response = action.newResponse(nodesRequest, responses, failures);
assertNotNull(response);
ArgumentCaptor<BulkRequest> argumentCaptor = ArgumentCaptor.forClass(BulkRequest.class);
verify(client, times(1)).bulk(argumentCaptor.capture(), any());
UpdateRequest updateRequest = (UpdateRequest) argumentCaptor.getValue().requests().get(0);
assertEquals(ML_MODEL_INDEX, updateRequest.index());
Map<String, Object> updateContent = updateRequest.doc().sourceAsMap();
assertEquals(MLModelState.PARTIALLY_LOADED.name(), updateContent.get(MLModel.MODEL_STATE_FIELD));
}
}