Skip to content

Commit

Permalink
addressing client changes due to adding tenantId in the apis (#3474)
Browse files Browse the repository at this point in the history
Signed-off-by: Dhrubo Saha <dhrubo@amazon.com>
  • Loading branch information
dhrubo-os authored Jan 31, 2025
1 parent da7d5b9 commit 17b4d74
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ default void deleteModel(String modelId, ActionListener<DeleteResponse> listener
*/
default ActionFuture<DeleteResponse> deleteTask(String taskId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteModel(taskId, actionFuture);
deleteTask(taskId, actionFuture);
return actionFuture;
}

Expand Down Expand Up @@ -361,7 +361,7 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
Expand All @@ -372,7 +372,7 @@ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUnde
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
Expand Down Expand Up @@ -480,8 +480,7 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
deleteAgent(agentId, null, listener);
}

/**
Expand Down Expand Up @@ -543,5 +542,15 @@ default ActionFuture<MLConfig> getConfig(String configId) {
* @param configId ML config id
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, ActionListener<MLConfig> listener);
default void getConfig(String configId, ActionListener<MLConfig> listener) {
getConfig(configId, null, listener);
}

/**
* Delete agent
* @param configId ML config id
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).tenantId(tenantId).build();

client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ public void setUp() {
.build();

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
}

@Override
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand All @@ -169,21 +165,11 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
listener.onResponse(output);
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -194,21 +180,11 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
listener.onResponse(searchResponse);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -224,21 +200,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
Expand All @@ -259,11 +225,6 @@ public void deleteConnector(String connectorId, String tenantId, ActionListener<
listener.onResponse(deleteResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void listTools(ActionListener<List<ToolMetadata>> listener) {
listener.onResponse(toolsList);
Expand All @@ -286,18 +247,13 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ public void deleteTask() {
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);
machineLearningNodeClient.deleteTask(taskId, null, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
verify(deleteTaskActionListener).onResponse(argumentCaptor.capture());
Expand Down Expand Up @@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
}

@Test
public void predict_withTenantId() {
String tenantId = "testTenant";
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLPredictionOutput predictionOutput = MLPredictionOutput
.builder()
.status("Success")
.predictionResult(output)
.taskId("taskId")
.build();
actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLPredictionTaskRequest> requestCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
machineLearningNodeClient.predict("modelId", tenantId, mlInput, dataFrameActionListener);

verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
assertEquals("modelId", requestCaptor.getValue().getModelId());
}

@Test
public void getTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task not found";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.getTask(taskId, new ActionListener<>() {
@Override
public void onResponse(MLTask mlTask) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskGetAction.INSTANCE), isA(MLTaskGetRequest.class), any());
}

@Test
public void deploy_withTenantId() {
String modelId = "testModel";
String tenantId = "testTenant";
String taskId = "taskId";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelRequest> requestCaptor = ArgumentCaptor.forClass(MLDeployModelRequest.class);
machineLearningNodeClient.deploy(modelId, tenantId, deployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(modelId, requestCaptor.getValue().getModelId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
}

@Test
public void trainAndPredict_withNullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("ML Input can't be null");

machineLearningNodeClient.trainAndPredict(null, trainingActionListener);
}

@Test
public void trainAndPredict_withNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data set can't be null");

MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build();
machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener);
}

@Test
public void getTask_withTaskIdAndTenantId() {
String taskId = "taskId";
String tenantId = "testTenant";
String modelId = "modelId";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build();
MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskGetRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class);
ArgumentCaptor<MLTask> taskCaptor = ArgumentCaptor.forClass(MLTask.class);

machineLearningNodeClient.getTask(taskId, tenantId, getTaskActionListener);

verify(client).execute(eq(MLTaskGetAction.INSTANCE), requestCaptor.capture(), any());
verify(getTaskActionListener).onResponse(taskCaptor.capture());

// Verify request parameters
assertEquals(taskId, requestCaptor.getValue().getTaskId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());

// Verify response
assertEquals(taskId, taskCaptor.getValue().getTaskId());
assertEquals(modelId, taskCaptor.getValue().getModelId());
assertEquals(FunctionName.KMEANS, taskCaptor.getValue().getFunctionName());
}

@Test
public void deleteTask_withTaskId() {
String taskId = "taskId";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, taskId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskDeleteRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class);
ArgumentCaptor<DeleteResponse> responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), requestCaptor.capture(), any());
verify(deleteTaskActionListener).onResponse(responseCaptor.capture());

// Verify request parameter
assertEquals(taskId, requestCaptor.getValue().getTaskId());

// Verify response
assertEquals(taskId, responseCaptor.getValue().getId());
assertEquals("DELETED", responseCaptor.getValue().getResult().toString());
}

@Test
public void deleteTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task deletion failed";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.deleteTask(taskId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down

0 comments on commit 17b4d74

Please sign in to comment.