Skip to content

Commit

Permalink
add more UT for task manager/runner (#206)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Mar 11, 2022
1 parent a4305a5 commit 1d30a22
Show file tree
Hide file tree
Showing 7 changed files with 846 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public class MLTask implements ToXContentObject, Writeable {
private User user; // TODO: support document level access control later
private boolean async;

@Builder
@Builder(toBuilder = true)
public MLTask(
String taskId,
String modelId,
Expand Down
16 changes: 10 additions & 6 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,20 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.indices.MLInputDatasetHandler',
'org.opensearch.ml.plugin.*',
'org.opensearch.ml.task.MLTaskDispatcher',
'org.opensearch.ml.task.MLTaskRunner',
'org.opensearch.ml.task.MLTrainingTaskRunner',
'org.opensearch.ml.task.MLPredictTaskRunner',
'org.opensearch.ml.rest.RestMLTrainingAction',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.utils.RestActionUtils',
'org.opensearch.ml.task.MLTaskCache',
'org.opensearch.ml.task.MLTaskManager',
'org.opensearch.ml.task.MLTrainAndPredictTaskRunner',
'org.opensearch.ml.rest.AbstractMLSearchAction*'
'org.opensearch.ml.rest.AbstractMLSearchAction*',
'org.opensearch.ml.utils.MLNodeUtils', //0.5
'org.opensearch.ml.task.MLExecuteTaskRunner', //0.5
'org.opensearch.ml.rest.RestMLDeleteTaskAction', //0.5
'org.opensearch.ml.rest.RestMLGetModelAction', //0.5
'org.opensearch.ml.rest.RestMLExecuteAction', //0.3
'org.opensearch.ml.rest.RestMLDeleteModelAction', //0.5
'org.opensearch.ml.rest.RestMLTrainAndPredictAction', //0.3
'org.opensearch.ml.rest.RestMLGetTaskAction' //0.5
]

jacocoTestCoverageVerification {
Expand All @@ -237,7 +241,7 @@ jacocoTestCoverageVerification {
limit {
counter = 'LINE'
value = 'COVEREDRATIO'
minimum = 0.3 //TODO: add more test to meet the coverage bar 0.7
minimum = 0.7
}
}
}
Expand Down
137 changes: 133 additions & 4 deletions plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,29 @@
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.opensearch.action.ActionListener;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.action.update.UpdateResponse;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.index.Index;
import org.opensearch.index.shard.ShardId;
import org.opensearch.ml.common.parameter.MLTask;
import org.opensearch.ml.common.parameter.MLTaskState;
import org.opensearch.ml.common.parameter.MLTaskType;
import org.opensearch.ml.indices.MLIndicesHandler;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;

import com.google.common.collect.ImmutableMap;

public class MLTaskManagerTests extends OpenSearchTestCase {
MLTaskManager mlTaskManager;
MLTask mlTask;
Client client;
ThreadPool threadPool;
ThreadContext threadContext;
MLIndicesHandler mlIndicesHandler;

@Rule
Expand All @@ -37,6 +48,12 @@ public class MLTaskManagerTests extends OpenSearchTestCase {
@Before
public void setup() {
this.client = mock(Client.class);
this.threadPool = mock(ThreadPool.class);
Settings settings = Settings.builder().build();
threadContext = new ThreadContext(settings);
when(client.threadPool()).thenReturn(threadPool);
when(threadPool.getThreadContext()).thenReturn(threadContext);

this.mlIndicesHandler = mock(MLIndicesHandler.class);
this.mlTaskManager = spy(new MLTaskManager(client, mlIndicesHandler));
this.mlTask = MLTask
Expand Down Expand Up @@ -89,34 +106,126 @@ public void testUpdateTaskStateAndError() {
Assert.assertEquals(0, value.longValue());
}

public void testUpdateTaskStateAndError_SyncTask() {
mlTaskManager.add(mlTask);
mlTaskManager.updateTaskStateAndError(mlTask.getTaskId(), MLTaskState.FAILED, "test error", false);
verify(mlTaskManager, never()).updateMLTask(eq(mlTask.getTaskId()), any(), anyLong());
}

public void testUpdateMLTaskWithNullOrEmptyMap() {
mlTaskManager.add(mlTask);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
verify(client, never()).index(any());
verify(client, never()).update(any(), any());
verify(listener, times(1)).onFailure(any());

mlTaskManager.updateMLTask(mlTask.getTaskId(), new HashMap<>(), listener, 0);
verify(client, never()).index(any());
verify(client, never()).update(any(), any());
verify(listener, times(2)).onFailure(any());
}

public void testUpdateMLTask_NonExistingTask() {
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.updateMLTask(mlTask.getTaskId(), null, listener, 0);
verify(client, never()).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals("Can't find task", argumentCaptor.getValue().getMessage());
}

public void testUpdateMLTask_NoSemaphore() {
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
mlTaskManager.add(asyncMlTask);

doAnswer(invocation -> {
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
UpdateResponse output = new UpdateResponse(shardId, "_doc", "taskId", 1, 1, 1, DocWriteResponse.Result.CREATED);
actionListener.onResponse(output);
return null;
}).when(client).update(any(UpdateRequest.class), any());

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), ActionListener.wrap(r -> {
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), null, listener, 0);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals("Other updating request not finished yet", argumentCaptor.getValue().getMessage());
}, e -> { assertNull(e); }), 0);
}

public void testUpdateMLTask_FailedToUpdate() {
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
mlTaskManager.add(asyncMlTask);

String errorMessage = "test error message";
doAnswer(invocation -> {
ActionListener<UpdateResponse> actionListener = invocation.getArgument(1);
actionListener.onFailure(new RuntimeException(errorMessage));
return null;
}).when(client).update(any(UpdateRequest.class), any());

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

public void testUpdateMLTask_ThrowException() {
MLTask asyncMlTask = mlTask.toBuilder().async(true).build();
mlTaskManager.add(asyncMlTask);

String errorMessage = "test error message";
doThrow(new RuntimeException(errorMessage)).when(client).update(any(UpdateRequest.class), any());

ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
ActionListener<UpdateResponse> listener = mock(ActionListener.class);
mlTaskManager.updateMLTask(asyncMlTask.getTaskId(), ImmutableMap.of(MLTask.ERROR_FIELD, "test error"), listener, 0);
verify(client, times(1)).update(any(), any());
verify(listener, times(1)).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

public void testRemove() {
mlTaskManager.add(mlTask);
Assert.assertTrue(mlTaskManager.contains(mlTask.getTaskId()));
mlTaskManager.remove(mlTask.getTaskId());
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
}

public void testRemove_NonExistingTask() {
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
mlTaskManager.remove(mlTask.getTaskId());
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
}

public void testGetTask() {
mlTaskManager.add(mlTask);
Assert.assertTrue(mlTaskManager.contains(mlTask.getTaskId()));
MLTask task = mlTaskManager.get(this.mlTask.getTaskId());
Assert.assertEquals(mlTask, task);
}

public void testGetTask_NonExisting() {
Assert.assertFalse(mlTaskManager.contains(mlTask.getTaskId()));
MLTask task = mlTaskManager.get(this.mlTask.getTaskId());
Assert.assertNull(task);
}

public void testGetRunningTaskCount() {
MLTask task1 = MLTask.builder().taskId("1").state(MLTaskState.CREATED).build();
MLTask task2 = MLTask.builder().taskId("2").state(MLTaskState.RUNNING).build();
MLTask task3 = MLTask.builder().taskId("3").state(MLTaskState.FAILED).build();
MLTask task4 = MLTask.builder().taskId("4").state(MLTaskState.COMPLETED).build();
MLTask task5 = MLTask.builder().taskId("5").state(null).build();
mlTaskManager.add(task1);
mlTaskManager.add(task2);
mlTaskManager.add(task3);
mlTaskManager.add(task4);
mlTaskManager.add(task5);
Assert.assertEquals(mlTaskManager.getRunningTaskCount(), 1);
}

Expand Down Expand Up @@ -155,9 +264,29 @@ public void testCreateMlTask_IndexException() {
return null;
}).when(mlIndicesHandler).initMLTaskIndex(any(ActionListener.class));

doThrow(new RuntimeException("test")).when(client).index(any(), any());
String errorMessage = "test error message";
doThrow(new RuntimeException(errorMessage)).when(client).index(any(), any());
ActionListener listener = mock(ActionListener.class);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.createMLTask(mlTask, listener);
verify(listener).onFailure(any());
verify(listener).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

public void testCreateMlTask_FailToGetThreadPool() {
doAnswer(invocation -> {
ActionListener<Boolean> listener = invocation.getArgument(0);
listener.onResponse(true);
return null;
}).when(mlIndicesHandler).initMLTaskIndex(any(ActionListener.class));

String errorMessage = "test error message";
doThrow(new RuntimeException(errorMessage)).when(threadPool).getThreadContext();
ActionListener listener = mock(ActionListener.class);
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
mlTaskManager.createMLTask(mlTask, listener);
verify(listener).onFailure(argumentCaptor.capture());
assertEquals(errorMessage, argumentCaptor.getValue().getMessage());
}

}
Loading

0 comments on commit 1d30a22

Please sign in to comment.