Skip to content

Commit

Permalink
fix parameter name in preprocess function; fix remote model function … (
Browse files Browse the repository at this point in the history
#1362)

* fix parameter name in preprocess function; fix remote model function name

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

* fix failed unit test

Signed-off-by: Yaliang Wu <ylwu@amazon.com>

---------

Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn authored Sep 26, 2023
1 parent 2c8cc02 commit 8de0431
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public class MLPreProcessFunction {
" }\n" +
" }\n" +
" builder.append(\"]\");\n" +
" def parameters = \"{\" +\"\\\"prompt\\\":\" + builder + \"}\";\n" +
" def parameters = \"{\" +\"\\\"texts\\\":\" + builder + \"}\";\n" +
" return \"{\" +\"\\\"parameters\\\":\" + parameters + \"}\";");

PRE_PROCESS_FUNCTIONS.put(TEXT_DOCS_TO_OPENAI_EMBEDDING_INPUT, "\n StringBuilder builder = new StringBuilder();\n" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ public void dispatchTask(
ActionListener<MLTaskResponse> listener
) {
String modelId = request.getModelId();
MLInput input = request.getMlInput();
FunctionName algorithm = input.getAlgorithm();
try {
ActionListener<DiscoveryNode> actionListener = ActionListener.wrap(node -> {
if (clusterService.localNode().getId().equals(node.getId())) {
Expand All @@ -133,9 +131,9 @@ public void dispatchTask(
transportService.sendRequest(node, getTransportActionName(), request, getResponseHandler(listener));
}
}, e -> { listener.onFailure(e); });
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, algorithm, true);
String[] workerNodes = mlModelManager.getWorkerNodes(modelId, functionName, true);
if (workerNodes == null || workerNodes.length == 0) {
if (algorithm == FunctionName.TEXT_EMBEDDING || algorithm == FunctionName.REMOTE) {
if (functionName == FunctionName.TEXT_EMBEDDING || functionName == FunctionName.REMOTE) {
listener
.onFailure(
new IllegalArgumentException(
Expand All @@ -144,7 +142,7 @@ public void dispatchTask(
);
return;
} else {
workerNodes = nodeHelper.getEligibleNodeIds(algorithm);
workerNodes = nodeHelper.getEligibleNodeIds(functionName);
}
}
mlTaskDispatcher.dispatchPredictTask(workerNodes, actionListener);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,18 +212,30 @@ public void setup() throws IOException {
public void testExecuteTask_OnLocalNode() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
verify(client).get(any(), any());
verify(mlTaskManager).remove(anyString());
}

public void testExecuteTask_OnLocalNode_RemoteModel() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
ArgumentCaptor<Exception> argumentCaptor = ArgumentCaptor.forClass(Exception.class);
verify(listener).onFailure(argumentCaptor.capture());
assertTrue(argumentCaptor.getValue().getMessage().contains("Model not ready yet."));
verify(mlTaskManager, never()).add(any(MLTask.class));
verify(client, never()).get(any(), any());
}

public void testExecuteTask_OnLocalNode_QueryInput() {
setupMocks(true, false, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener);
verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -234,7 +246,7 @@ public void testExecuteTask_OnLocalNode_QueryInput() {
public void testExecuteTask_OnLocalNode_QueryInput_Failure() {
setupMocks(true, true, false, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithQuery, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithQuery, transportService, listener);
verify(mlInputDatasetHandler).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler, never()).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager, never()).add(any(MLTask.class));
Expand All @@ -245,7 +257,7 @@ public void testExecuteTask_NoPermission() {
setupMocks(true, true, false, false);
threadContext.stashContext();
threadContext.putTransient(ConfigConstants.OPENSEARCH_SECURITY_USER_INFO_THREAD_CONTEXT, "test_user|test_role|test_tenant");
taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlTaskManager).add(any(MLTask.class));
verify(mlTaskManager).remove(anyString());
verify(client).get(any(), any());
Expand All @@ -256,14 +268,14 @@ public void testExecuteTask_NoPermission() {

public void testExecuteTask_OnRemoteNode() {
setupMocks(false, false, false, false);
taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(transportService).sendRequest(eq(remoteNode), eq(MLPredictionTaskAction.NAME), eq(requestWithDataFrame), any());
}

public void testExecuteTask_OnLocalNode_GetModelFail() {
setupMocks(true, false, true, false);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -277,7 +289,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
setupMocks(true, false, false, false);
requestWithDataFrame = MLPredictionTaskRequest.builder().mlInput(mlInputWithDataFrame).build();

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand All @@ -291,7 +303,7 @@ public void testExecuteTask_OnLocalNode_NullModelIdException() {
public void testExecuteTask_OnLocalNode_NullGetResponse() {
setupMocks(true, false, false, true);

taskRunner.dispatchTask(FunctionName.REMOTE, requestWithDataFrame, transportService, listener);
taskRunner.dispatchTask(FunctionName.BATCH_RCF, requestWithDataFrame, transportService, listener);
verify(mlInputDatasetHandler, never()).parseSearchQueryInput(any(), any());
// verify(mlInputDatasetHandler).parseDataFrameInput(requestWithDataFrame.getMlInput().getInputDataset());
verify(mlTaskManager).add(any(MLTask.class));
Expand Down

0 comments on commit 8de0431

Please sign in to comment.