diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java index f79c68ccaa..457f1b7d1f 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningClient.java @@ -237,7 +237,7 @@ default void deleteModel(String modelId, ActionListener listener */ default ActionFuture deleteTask(String taskId) { PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - deleteModel(taskId, actionFuture); + deleteTask(taskId, actionFuture); return actionFuture; } @@ -361,7 +361,7 @@ default ActionFuture undeploy(String[] modelIds, @Null * Undeploy model * For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/ * @param modelIds the model ids - * @param modelIds the node ids. May be null for all nodes. + * @param nodeIds the node ids. May be null for all nodes. * @param listener a listener to be notified of the result */ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { @@ -372,7 +372,7 @@ default void undeploy(String[] modelIds, String[] nodeIds, ActionListener deleteAgent(String agentId) { * @param listener a listener to be notified of the result */ default void deleteAgent(String agentId, ActionListener listener) { - PlainActionFuture actionFuture = PlainActionFuture.newFuture(); - deleteAgent(agentId, null, actionFuture); + deleteAgent(agentId, null, listener); } /** @@ -543,5 +542,15 @@ default ActionFuture getConfig(String configId) { * @param configId ML config id * @param listener a listener to be notified of the result */ - void getConfig(String configId, ActionListener listener); + default void getConfig(String configId, ActionListener listener) { + getConfig(configId, null, listener); + } + + /** + * Delete agent + * @param configId ML config id + * @param tenantId the tenant id. This is necessary for multi-tenancy. + * @param listener a listener to be notified of the result + */ + void getConfig(String configId, String tenantId, ActionListener listener); } diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index 1d29802bda..695b9f0892 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -312,8 +312,8 @@ public void getTool(String toolName, ActionListener listener) { } @Override - public void getConfig(String configId, ActionListener listener) { - MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).build(); + public void getConfig(String configId, String tenantId, ActionListener listener) { + MLConfigGetRequest mlConfigGetRequest = MLConfigGetRequest.builder().configId(configId).tenantId(tenantId).build(); client.execute(MLConfigGetAction.INSTANCE, mlConfigGetRequest, getMlGetConfigResponseActionListener(listener)); } diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java index e6a202806b..776aefd2cf 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningClientTest.java @@ -144,10 +144,6 @@ public void setUp() { .build(); machineLearningClient = new MachineLearningClient() { - @Override - public void predict(String modelId, MLInput mlInput, ActionListener listener) { - listener.onResponse(output); - } @Override public void predict(String modelId, String tenantId, MLInput mlInput, ActionListener listener) { @@ -169,21 +165,11 @@ public void run(MLInput mlInput, Map args, ActionListener listener) { - listener.onResponse(mlModel); - } - @Override public void getModel(String modelId, String tenantId, ActionListener listener) { listener.onResponse(mlModel); } - @Override - public void deleteModel(String modelId, ActionListener listener) { - listener.onResponse(deleteResponse); - } - @Override public void deleteModel(String modelId, String tenantId, ActionListener listener) { listener.onResponse(deleteResponse); @@ -194,21 +180,11 @@ public void searchModel(SearchRequest searchRequest, ActionListener listener) { - listener.onResponse(mlTask); - } - @Override public void getTask(String taskId, String tenantId, ActionListener listener) { listener.onResponse(mlTask); } - @Override - public void deleteTask(String taskId, ActionListener listener) { - listener.onResponse(deleteResponse); - } - @Override public void deleteTask(String taskId, String tenantId, ActionListener listener) { listener.onResponse(deleteResponse); @@ -224,21 +200,11 @@ public void register(MLRegisterModelInput mlInput, ActionListener listener) { - listener.onResponse(deployModelResponse); - } - @Override public void deploy(String modelId, String tenantId, ActionListener listener) { listener.onResponse(deployModelResponse); } - @Override - public void undeploy(String[] modelIds, String[] nodeIds, ActionListener listener) { - listener.onResponse(undeployModelsResponse); - } - @Override public void undeploy(String[] modelIds, String[] nodeIds, String tenantId, ActionListener listener) { listener.onResponse(undeployModelsResponse); @@ -259,11 +225,6 @@ public void deleteConnector(String connectorId, String tenantId, ActionListener< listener.onResponse(deleteResponse); } - @Override - public void deleteConnector(String connectorId, ActionListener listener) { - listener.onResponse(deleteResponse); - } - @Override public void listTools(ActionListener> listener) { listener.onResponse(toolsList); @@ -286,18 +247,13 @@ public void registerAgent(MLAgent mlAgent, ActionListener listener) { - listener.onResponse(deleteResponse); - } - @Override public void deleteAgent(String agentId, String tenantId, ActionListener listener) { listener.onResponse(deleteResponse); } @Override - public void getConfig(String configId, ActionListener listener) { + public void getConfig(String configId, String tenantId, ActionListener listener) { listener.onResponse(mlConfig); } }; diff --git a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java index 1f54795acf..76df1d9a8c 100644 --- a/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java +++ b/client/src/test/java/org/opensearch/ml/client/MachineLearningNodeClientTest.java @@ -884,7 +884,7 @@ public void deleteTask() { }).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any()); ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(DeleteResponse.class); - machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener); + machineLearningNodeClient.deleteTask(taskId, null, deleteTaskActionListener); verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any()); verify(deleteTaskActionListener).onResponse(argumentCaptor.capture()); @@ -1276,6 +1276,185 @@ public void getConfigRejectedMasterKey() { assertEquals("You are not allowed to access this config doc", argumentCaptor.getValue().getLocalizedMessage()); } + @Test + public void predict_withTenantId() { + String tenantId = "testTenant"; + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLPredictionOutput predictionOutput = MLPredictionOutput + .builder() + .status("Success") + .predictionResult(output) + .taskId("taskId") + .build(); + actionListener.onResponse(MLTaskResponse.builder().output(predictionOutput).build()); + return null; + }).when(client).execute(eq(MLPredictionTaskAction.INSTANCE), any(), any()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).inputDataset(input).build(); + machineLearningNodeClient.predict("modelId", tenantId, mlInput, dataFrameActionListener); + + verify(client).execute(eq(MLPredictionTaskAction.INSTANCE), requestCaptor.capture(), any()); + assertEquals(tenantId, requestCaptor.getValue().getTenantId()); + assertEquals("modelId", requestCaptor.getValue().getModelId()); + } + + @Test + public void getTask_withFailure() { + String taskId = "taskId"; + String errorMessage = "Task not found"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new IllegalArgumentException(errorMessage)); + return null; + }).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any()); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + machineLearningNodeClient.getTask(taskId, new ActionListener<>() { + @Override + public void onResponse(MLTask mlTask) { + fail("Expected failure but got success"); + } + + @Override + public void onFailure(Exception e) { + assertEquals(errorMessage, e.getMessage()); + } + }); + + verify(client).execute(eq(MLTaskGetAction.INSTANCE), isA(MLTaskGetRequest.class), any()); + } + + @Test + public void deploy_withTenantId() { + String modelId = "testModel"; + String tenantId = "testTenant"; + String taskId = "taskId"; + String status = MLTaskState.CREATED.name(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLDeployModelResponse output = new MLDeployModelResponse(taskId, MLTaskType.DEPLOY_MODEL, status); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLDeployModelAction.INSTANCE), any(), any()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(MLDeployModelRequest.class); + machineLearningNodeClient.deploy(modelId, tenantId, deployModelActionListener); + + verify(client).execute(eq(MLDeployModelAction.INSTANCE), requestCaptor.capture(), any()); + assertEquals(modelId, requestCaptor.getValue().getModelId()); + assertEquals(tenantId, requestCaptor.getValue().getTenantId()); + } + + @Test + public void trainAndPredict_withNullInput() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("ML Input can't be null"); + + machineLearningNodeClient.trainAndPredict(null, trainingActionListener); + } + + @Test + public void trainAndPredict_withNullDataSet() { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage("input data set can't be null"); + + MLInput mlInput = MLInput.builder().algorithm(FunctionName.KMEANS).build(); + machineLearningNodeClient.trainAndPredict(mlInput, trainingActionListener); + } + + @Test + public void getTask_withTaskIdAndTenantId() { + String taskId = "taskId"; + String tenantId = "testTenant"; + String modelId = "modelId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + MLTask mlTask = MLTask.builder().taskId(taskId).modelId(modelId).functionName(FunctionName.KMEANS).build(); + MLTaskGetResponse output = MLTaskGetResponse.builder().mlTask(mlTask).build(); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLTaskGetAction.INSTANCE), any(), any()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(MLTaskGetRequest.class); + ArgumentCaptor taskCaptor = ArgumentCaptor.forClass(MLTask.class); + + machineLearningNodeClient.getTask(taskId, tenantId, getTaskActionListener); + + verify(client).execute(eq(MLTaskGetAction.INSTANCE), requestCaptor.capture(), any()); + verify(getTaskActionListener).onResponse(taskCaptor.capture()); + + // Verify request parameters + assertEquals(taskId, requestCaptor.getValue().getTaskId()); + assertEquals(tenantId, requestCaptor.getValue().getTenantId()); + + // Verify response + assertEquals(taskId, taskCaptor.getValue().getTaskId()); + assertEquals(modelId, taskCaptor.getValue().getModelId()); + assertEquals(FunctionName.KMEANS, taskCaptor.getValue().getFunctionName()); + } + + @Test + public void deleteTask_withTaskId() { + String taskId = "taskId"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, taskId, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(MLTaskDeleteRequest.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(DeleteResponse.class); + + machineLearningNodeClient.deleteTask(taskId, deleteTaskActionListener); + + verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), requestCaptor.capture(), any()); + verify(deleteTaskActionListener).onResponse(responseCaptor.capture()); + + // Verify request parameter + assertEquals(taskId, requestCaptor.getValue().getTaskId()); + + // Verify response + assertEquals(taskId, responseCaptor.getValue().getId()); + assertEquals("DELETED", responseCaptor.getValue().getResult().toString()); + } + + @Test + public void deleteTask_withFailure() { + String taskId = "taskId"; + String errorMessage = "Task deletion failed"; + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(new RuntimeException(errorMessage)); + return null; + }).when(client).execute(eq(MLTaskDeleteAction.INSTANCE), any(), any()); + + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + machineLearningNodeClient.deleteTask(taskId, new ActionListener<>() { + @Override + public void onResponse(DeleteResponse deleteResponse) { + fail("Expected failure but got success"); + } + + @Override + public void onFailure(Exception e) { + assertEquals(errorMessage, e.getMessage()); + } + }); + + verify(client).execute(eq(MLTaskDeleteAction.INSTANCE), isA(MLTaskDeleteRequest.class), any()); + } + private SearchResponse createSearchResponse(ToXContentObject o) throws IOException { XContentBuilder content = o.toXContent(XContentFactory.jsonBuilder(), ToXContent.EMPTY_PARAMS);