Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.19] addressing client changes due to adding tenantId in the apis #3480

Merged
merged 1 commit into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ default void deleteModel(String modelId, ActionListener<DeleteResponse> listener
*/
default ActionFuture<DeleteResponse> deleteTask(String taskId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteModel(taskId, actionFuture);
deleteTask(taskId, actionFuture);
return actionFuture;
}

Expand Down Expand Up @@ -361,7 +361,7 @@ default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Null
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
Expand All @@ -372,7 +372,7 @@ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUnde
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param nodeIds the node ids. May be null for all nodes.
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
Expand Down Expand Up @@ -480,8 +480,7 @@ default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
* @param listener a listener to be notified of the result
*/
default void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, null, actionFuture);
deleteAgent(agentId, null, listener);
}

/**
Expand Down Expand Up @@ -543,5 +542,15 @@ default ActionFuture<MLConfig> getConfig(String configId) {
* @param configId ML config id
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, ActionListener<MLConfig> listener);
default void getConfig(String configId, ActionListener<MLConfig> listener) {
getConfig(configId, null, listener);
}

/**
* Delete agent
* @param configId ML config id
* @param tenantId the tenant id. This is necessary for multi-tenancy.
* @param listener a listener to be notified of the result
*/
void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener);
}
Original file line number Diff line number Diff line change
Expand Up @@ -312,8 +312,8 @@ public void getTool(String toolName, ActionListener<ToolMetadata> listener) {
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build();
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).tenantId(tenantId).build();

client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,6 @@ public void setUp() {
.build();

machineLearningClient = new MachineLearningClient() {
@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
listener.onResponse(output);
}

@Override
public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener<MLOutput> listener) {
Expand All @@ -169,21 +165,11 @@ public void run(MLInput mlInput, Map<String, Object> args, ActionListener<MLOutp
listener.onResponse(output);
}

@Override
public void getModel(String modelId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void getModel(String modelId, String tenantId, ActionListener<MLModel> listener) {
listener.onResponse(mlModel);
}

@Override
public void deleteModel(String modelId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteModel(String modelId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -194,21 +180,11 @@ public void searchModel(SearchRequest searchRequest, ActionListener<SearchRespon
listener.onResponse(searchResponse);
}

@Override
public void getTask(String taskId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void getTask(String taskId, String tenantId, ActionListener<MLTask> listener) {
listener.onResponse(mlTask);
}

@Override
public void deleteTask(String taskId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteTask(String taskId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
Expand All @@ -224,21 +200,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener<MLRegisterMode
listener.onResponse(registerModelResponse);
}

@Override
public void deploy(String modelId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void deploy(String modelId, String tenantId, ActionListener<MLDeployModelResponse> listener) {
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
Expand All @@ -259,11 +225,6 @@ public void deleteConnector(String connectorId, String tenantId, ActionListener<
listener.onResponse(deleteResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void listTools(ActionListener<List<ToolMetadata>> listener) {
listener.onResponse(toolsList);
Expand All @@ -286,18 +247,13 @@ public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentRespons
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void deleteAgent(String agentId, String tenantId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

@Override
public void getConfig(String configId, ActionListener<MLConfig> listener) {
public void getConfig(String configId, String tenantId, ActionListener<MLConfig> listener) {
listener.onResponse(mlConfig);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ public void deleteTask() {
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<DeleteResponse> argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class);
machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);
machineLearningNodeClient.deleteTask(taskId, null, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
verify(deleteTaskActionListener).onResponse(argumentCaptor.capture());
Expand Down Expand Up @@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() {
assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage());
}

@Test
public void predict_withTenantId() {
String tenantId = "testTenant";
doAnswer(invocation -> {
ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
MLPredictionOutput predictionOutput = MLPredictionOutput
.builder()
.status("Success")
.predictionResult(output)
.taskId("taskId")
.build();
actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build());
return null;
}).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any());

ArgumentCaptor<MLPredictionTaskRequest> requestCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class);
MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build();
machineLearningNodeClient.predict("modelId", tenantId, mlInput, dataFrameActionListener);

verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
assertEquals("modelId", requestCaptor.getValue().getModelId());
}

@Test
public void getTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task not found";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new IllegalArgumentException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.getTask(taskId, new ActionListener<>() {
@Override
public void onResponse(MLTask mlTask) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskGetAction.INSTANCE), isA(MLTaskGetRequest.class), any());
}

@Test
public void deploy_withTenantId() {
String modelId = "testModel";
String tenantId = "testTenant";
String taskId = "taskId";
String status = MLTaskState.CREATED.name();

doAnswer(invocation -> {
ActionListener<MLDeployModelResponse> actionListener = invocation.getArgument(2);
MLDeployModelResponse output = new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, status);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any());

ArgumentCaptor<MLDeployModelRequest> requestCaptor = ArgumentCaptor.forClass(MLDeployModelRequest.class);
machineLearningNodeClient.deploy(modelId, tenantId, deployModelActionListener);

verify(client).execute(eq(MLDeployModelAction.INSTANCE), requestCaptor.capture(), any());
assertEquals(modelId, requestCaptor.getValue().getModelId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());
}

@Test
public void trainAndPredict_withNullInput() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("ML Input can't be null");

machineLearningNodeClient.trainAndPredict(null, trainingActionListener);
}

@Test
public void trainAndPredict_withNullDataSet() {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("input data set can't be null");

MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build();
machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener);
}

@Test
public void getTask_withTaskIdAndTenantId() {
String taskId = "taskId";
String tenantId = "testTenant";
String modelId = "modelId";

doAnswer(invocation -> {
ActionListener<MLTaskGetResponse> actionListener = invocation.getArgument(2);
MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build();
MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build();
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskGetRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class);
ArgumentCaptor<MLTask> taskCaptor = ArgumentCaptor.forClass(MLTask.class);

machineLearningNodeClient.getTask(taskId, tenantId, getTaskActionListener);

verify(client).execute(eq(MLTaskGetAction.INSTANCE), requestCaptor.capture(), any());
verify(getTaskActionListener).onResponse(taskCaptor.capture());

// Verify request parameters
assertEquals(taskId, requestCaptor.getValue().getTaskId());
assertEquals(tenantId, requestCaptor.getValue().getTenantId());

// Verify response
assertEquals(taskId, taskCaptor.getValue().getTaskId());
assertEquals(modelId, taskCaptor.getValue().getModelId());
assertEquals(FunctionName.KMEANS, taskCaptor.getValue().getFunctionName());
}

@Test
public void deleteTask_withTaskId() {
String taskId = "taskId";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1);
DeleteResponse output = new DeleteResponse(shardId, taskId, 1, 1, 1, true);
actionListener.onResponse(output);
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<MLTaskDeleteRequest> requestCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class);
ArgumentCaptor<DeleteResponse> responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class);

machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener);

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), requestCaptor.capture(), any());
verify(deleteTaskActionListener).onResponse(responseCaptor.capture());

// Verify request parameter
assertEquals(taskId, requestCaptor.getValue().getTaskId());

// Verify response
assertEquals(taskId, responseCaptor.getValue().getId());
assertEquals("DELETED", responseCaptor.getValue().getResult().toString());
}

@Test
public void deleteTask_withFailure() {
String taskId = "taskId";
String errorMessage = "Task deletion failed";

doAnswer(invocation -> {
ActionListener<DeleteResponse> actionListener = invocation.getArgument(2);
actionListener.onFailure(new RuntimeException(errorMessage));
return null;
}).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any());

ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

machineLearningNodeClient.deleteTask(taskId, new ActionListener<>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
fail("Expected failure but got success");
}

@Override
public void onFailure(Exception e) {
assertEquals(errorMessage, e.getMessage());
}
});

verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any());
}

private SearchResponse createSearchResponse(ToXContentObject o) throws IOException {
XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);

Expand Down
Loading