Skip to content

Commit

Permalink
add more methods to client (#1773) (#1782)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
(cherry picked from commit 13dbde1)

Co-authored-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
opensearch-trigger-bot[bot] and ylwu-amzn authored Dec 18, 2023
1 parent 3ab976d commit 5d33151
Show file tree
Hide file tree
Showing 4 changed files with 298 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Nullable;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -27,6 +30,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -254,7 +258,7 @@ default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlIn

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
Expand All @@ -265,12 +269,33 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);

/**
* Undeploy models
* For additional info on undeploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param nodeIds the node ids. May be null for all nodes.
*/
default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Nullable String[] nodeIds) {
PlainActionFuture<MLUndeployModelsResponse> actionFuture = PlainActionFuture.newFuture();
undeploy(modelIds, nodeIds, actionFuture);
return actionFuture;
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
Expand All @@ -284,6 +309,19 @@ default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnecto

void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, actionFuture);
return actionFuture;
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
Expand Down Expand Up @@ -321,4 +359,35 @@ default ActionFuture<MLExecuteTaskResponse> execute(FunctionName name, Input inp
* @param listener a listener to be notified of the result
*/
void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener);

/**
* Registers new agent and returns ActionFuture.
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlAgent, actionFuture);
return actionFuture;
}

/**
* Registers new agent and returns agent ID in response
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);

/**
* Delete agent
* @param agentId The id of the agent to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, actionFuture);
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -66,6 +74,9 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -242,12 +253,50 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build();
client.execute(MLRegisterAgentAction.INSTANCE, mlRegisterAgentRequest, getMLRegisterAgentResponseActionListener(listener));
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
}

private ActionListener<MLTaskGetResponse> getMLTaskResponseActionListener(ActionListener<MLTask> listener) {
ActionListener<MLTaskGetResponse> internalListener = ActionListener
.wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure);
Expand All @@ -266,6 +315,16 @@ private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionList
return actionListener;
}

private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
ActionListener<MLUndeployModelsResponse> listener
) {
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.Input;
Expand All @@ -42,6 +43,7 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -50,6 +52,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -79,6 +82,9 @@ public class MachineLearningClientTest {
@Mock
MLDeployModelResponse deployModelResponse;

@Mock
MLUndeployModelsResponse undeployModelsResponse;

@Mock
MLCreateConnectorResponse createConnectorResponse;

Expand All @@ -88,6 +94,9 @@ public class MachineLearningClientTest {
@Mock
MLExecuteTaskResponse mlExecuteTaskResponse;

@Mock
MLRegisterAgentResponse registerAgentResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -163,6 +172,11 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
Expand All @@ -173,12 +187,27 @@ public void execute(FunctionName name, Input input, ActionListener<MLExecuteTask
listener.onResponse(mlExecuteTaskResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
) {
listener.onResponse(registerModelGroupResponse);
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}
};
}

Expand Down Expand Up @@ -344,6 +373,11 @@ public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}

@Test
public void undeploy() {
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(new String[] { "modelId" }, null).actionGet());
}

@Test
public void createConnector() {
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Expand Down Expand Up @@ -420,4 +454,20 @@ public void executeMetricsCorrelation() {
machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet()
);
}

@Test
public void deleteConnector() {
assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet());
}

@Test
public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}

@Test
public void deleteAgent() {
assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet());
}
}
Loading

0 comments on commit 5d33151

Please sign in to comment.