Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] add more methods to client #1782

Merged
merged 1 commit into from
Dec 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.support.PlainActionFuture;
import org.opensearch.common.Nullable;
import org.opensearch.common.action.ActionFuture;
import org.opensearch.core.action.ActionListener;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -27,6 +30,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

/**
* A client to provide interfaces for machine learning jobs. This will be used by other plugins.
Expand Down Expand Up @@ -254,7 +258,7 @@ default ActionFuture<MLRegisterModelResponse> register(MLRegisterModelInput mlIn

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
*/
default ActionFuture<MLDeployModelResponse> deploy(String modelId) {
Expand All @@ -265,12 +269,33 @@ default ActionFuture<MLDeployModelResponse> deploy(String modelId) {

/**
* Deploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#deploying-a-model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/deploy-model/
* @param modelId the model id
* @param listener a listener to be notified of the result
*/
void deploy(String modelId, ActionListener<MLDeployModelResponse> listener);

/**
* Undeploy models
* For additional info on undeploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param nodeIds the node ids. May be null for all nodes.
*/
default ActionFuture<MLUndeployModelsResponse> undeploy(String[] modelIds, @Nullable String[] nodeIds) {
PlainActionFuture<MLUndeployModelsResponse> actionFuture = PlainActionFuture.newFuture();
undeploy(modelIds, nodeIds, actionFuture);
return actionFuture;
}

/**
* Undeploy model
* For additional info on deploy, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/model-apis/undeploy-model/
* @param modelIds the model ids
* @param modelIds the node ids. May be null for all nodes.
* @param listener a listener to be notified of the result
*/
void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener);

/**
* Create connector for remote model
* @param mlCreateConnectorInput Create Connector Input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/extensibility/connectors/
Expand All @@ -284,6 +309,19 @@ default ActionFuture<MLCreateConnectorResponse> createConnector(MLCreateConnecto

void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener);

/**
* Delete connector for remote model
* @param connectorId The id of the connector to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteConnector(String connectorId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteConnector(connectorId, actionFuture);
return actionFuture;
}

void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener);

/**
* Register model group
* For additional info on model group, refer: https://opensearch.org/docs/latest/ml-commons-plugin/model-access-control#registering-a-model-group
Expand Down Expand Up @@ -321,4 +359,35 @@ default ActionFuture<MLExecuteTaskResponse> execute(FunctionName name, Input inp
* @param listener a listener to be notified of the result
*/
void execute(FunctionName name, Input input, ActionListener<MLExecuteTaskResponse> listener);

/**
* Registers new agent and returns ActionFuture.
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
* @return the result future
*/
default ActionFuture<MLRegisterAgentResponse> registerAgent(MLAgent mlAgent) {
PlainActionFuture<MLRegisterAgentResponse> actionFuture = PlainActionFuture.newFuture();
registerAgent(mlAgent, actionFuture);
return actionFuture;
}

/**
* Registers new agent and returns agent ID in response
* @param mlAgent Register agent input, refer: https://opensearch.org/docs/latest/ml-commons-plugin/api/#register-agent
*/
void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener);

/**
* Delete agent
* @param agentId The id of the agent to delete
* @return the result future
*/
default ActionFuture<DeleteResponse> deleteAgent(String agentId) {
PlainActionFuture<DeleteResponse> actionFuture = PlainActionFuture.newFuture();
deleteAgent(agentId, actionFuture);
return actionFuture;
}

void deleteAgent(String agentId, ActionListener<DeleteResponse> listener);

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,19 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteAction;
import org.opensearch.ml.common.transport.agent.MLAgentDeleteRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentAction;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentRequest;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteAction;
import org.opensearch.ml.common.transport.connector.MLConnectorDeleteRequest;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorAction;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorRequest;
Expand Down Expand Up @@ -66,6 +74,9 @@
import org.opensearch.ml.common.transport.training.MLTrainingTaskAction;
import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest;
import org.opensearch.ml.common.transport.trainpredict.MLTrainAndPredictionTaskAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsAction;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsRequest;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

import lombok.AccessLevel;
import lombok.RequiredArgsConstructor;
Expand Down Expand Up @@ -242,12 +253,50 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
client.execute(MLDeployModelAction.INSTANCE, deployModelRequest, getMlDeployModelResponseActionListener(listener));
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
MLUndeployModelsRequest undeployModelRequest = new MLUndeployModelsRequest(modelIds, nodeIds);
client.execute(MLUndeployModelsAction.INSTANCE, undeployModelRequest, getMlUndeployModelsResponseActionListener(listener));
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
MLCreateConnectorRequest createConnectorRequest = new MLCreateConnectorRequest(mlCreateConnectorInput);
client.execute(MLCreateConnectorAction.INSTANCE, createConnectorRequest, getMlCreateConnectorResponseActionListener(listener));
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
MLConnectorDeleteRequest connectorDeleteRequest = new MLConnectorDeleteRequest(connectorId);
client.execute(MLConnectorDeleteAction.INSTANCE, connectorDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
MLRegisterAgentRequest mlRegisterAgentRequest = MLRegisterAgentRequest.builder().mlAgent(mlAgent).build();
client.execute(MLRegisterAgentAction.INSTANCE, mlRegisterAgentRequest, getMLRegisterAgentResponseActionListener(listener));
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
MLAgentDeleteRequest agentDeleteRequest = new MLAgentDeleteRequest(agentId);
client.execute(MLAgentDeleteAction.INSTANCE, agentDeleteRequest, ActionListener.wrap(deleteResponse -> {
listener.onResponse(deleteResponse);
}, listener::onFailure));
}

private ActionListener<MLRegisterAgentResponse> getMLRegisterAgentResponseActionListener(
ActionListener<MLRegisterAgentResponse> listener
) {
ActionListener<MLRegisterAgentResponse> actionListener = wrapActionListener(listener, res -> {
MLRegisterAgentResponse mlRegisterAgentResponse = MLRegisterAgentResponse.fromActionResponse(res);
return mlRegisterAgentResponse;
});
return actionListener;
}

private ActionListener<MLTaskGetResponse> getMLTaskResponseActionListener(ActionListener<MLTask> listener) {
ActionListener<MLTaskGetResponse> internalListener = ActionListener
.wrap(getResponse -> { listener.onResponse(getResponse.getMlTask()); }, listener::onFailure);
Expand All @@ -266,6 +315,16 @@ private ActionListener<MLDeployModelResponse> getMlDeployModelResponseActionList
return actionListener;
}

private ActionListener<MLUndeployModelsResponse> getMlUndeployModelsResponseActionListener(
ActionListener<MLUndeployModelsResponse> listener
) {
ActionListener<MLUndeployModelsResponse> actionListener = wrapActionListener(listener, response -> {
MLUndeployModelsResponse deployModelResponse = MLUndeployModelsResponse.fromActionResponse(response);
return deployModelResponse;
});
return actionListener;
}

private ActionListener<MLCreateConnectorResponse> getMlCreateConnectorResponseActionListener(
ActionListener<MLCreateConnectorResponse> listener
) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.agent.MLAgent;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.input.Input;
Expand All @@ -42,6 +43,7 @@
import org.opensearch.ml.common.model.TextEmbeddingModelConfig;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.MLTrainingOutput;
import org.opensearch.ml.common.transport.agent.MLRegisterAgentResponse;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorInput;
import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse;
import org.opensearch.ml.common.transport.deploy.MLDeployModelResponse;
Expand All @@ -50,6 +52,7 @@
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupResponse;
import org.opensearch.ml.common.transport.register.MLRegisterModelInput;
import org.opensearch.ml.common.transport.register.MLRegisterModelResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

public class MachineLearningClientTest {

Expand Down Expand Up @@ -79,6 +82,9 @@ public class MachineLearningClientTest {
@Mock
MLDeployModelResponse deployModelResponse;

@Mock
MLUndeployModelsResponse undeployModelsResponse;

@Mock
MLCreateConnectorResponse createConnectorResponse;

Expand All @@ -88,6 +94,9 @@ public class MachineLearningClientTest {
@Mock
MLExecuteTaskResponse mlExecuteTaskResponse;

@Mock
MLRegisterAgentResponse registerAgentResponse;

private String modekId = "test_model_id";
private MLModel mlModel;
private MLTask mlTask;
Expand Down Expand Up @@ -163,6 +172,11 @@ public void deploy(String modelId, ActionListener<MLDeployModelResponse> listene
listener.onResponse(deployModelResponse);
}

@Override
public void undeploy(String[] modelIds, String[] nodeIds, ActionListener<MLUndeployModelsResponse> listener) {
listener.onResponse(undeployModelsResponse);
}

@Override
public void createConnector(MLCreateConnectorInput mlCreateConnectorInput, ActionListener<MLCreateConnectorResponse> listener) {
listener.onResponse(createConnectorResponse);
Expand All @@ -173,12 +187,27 @@ public void execute(FunctionName name, Input input, ActionListener<MLExecuteTask
listener.onResponse(mlExecuteTaskResponse);
}

@Override
public void deleteConnector(String connectorId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}

public void registerModelGroup(
MLRegisterModelGroupInput mlRegisterModelGroupInput,
ActionListener<MLRegisterModelGroupResponse> listener
) {
listener.onResponse(registerModelGroupResponse);
}

@Override
public void registerAgent(MLAgent mlAgent, ActionListener<MLRegisterAgentResponse> listener) {
listener.onResponse(registerAgentResponse);
}

@Override
public void deleteAgent(String agentId, ActionListener<DeleteResponse> listener) {
listener.onResponse(deleteResponse);
}
};
}

Expand Down Expand Up @@ -344,6 +373,11 @@ public void deploy() {
assertEquals(deployModelResponse, machineLearningClient.deploy("modelId").actionGet());
}

@Test
public void undeploy() {
assertEquals(undeployModelsResponse, machineLearningClient.undeploy(new String[] { "modelId" }, null).actionGet());
}

@Test
public void createConnector() {
Map<String, String> params = Map.ofEntries(Map.entry("endpoint", "endpoint"), Map.entry("temp", "7"));
Expand Down Expand Up @@ -420,4 +454,20 @@ public void executeMetricsCorrelation() {
machineLearningClient.execute(FunctionName.METRICS_CORRELATION, metricsCorrelationInput).actionGet()
);
}

@Test
public void deleteConnector() {
assertEquals(deleteResponse, machineLearningClient.deleteConnector("connectorId").actionGet());
}

@Test
public void testRegisterAgent() {
MLAgent mlAgent = MLAgent.builder().name("Agent name").build();
assertEquals(registerAgentResponse, machineLearningClient.registerAgent(mlAgent).actionGet());
}

@Test
public void deleteAgent() {
assertEquals(deleteResponse, machineLearningClient.deleteAgent("agentId").actionGet());
}
}
Loading
Loading