Skip to content

Commit

Permalink
support train ML model in either sync or async way
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <ylwu@amazon.com>
  • Loading branch information
ylwu-amzn committed Jan 21, 2022
1 parent 3ced87b commit 1d5da1d
Show file tree
Hide file tree
Showing 16 changed files with 254 additions and 80 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@ public class MLTrainingOutput extends MLOutput{

private static final MLOutputType OUTPUT_TYPE = MLOutputType.TRAINING;
public static final String MODEL_ID_FIELD = "model_id";
public static final String TASK_ID_FIELD = "task_id";
public static final String STATUS_FIELD = "status";
private String modelId;
private String taskId;
private String status;

@Builder
public MLTrainingOutput(String modelId, String status) {
public MLTrainingOutput(String modelId, String taskId, String status) {
super(OUTPUT_TYPE);
this.modelId = modelId;
this.taskId = taskId;
this.status= status;
}

public MLTrainingOutput(StreamInput in) throws IOException {
super(OUTPUT_TYPE);
this.modelId = in.readOptionalString();
this.taskId = in.readOptionalString();
this.status = in.readOptionalString();
}

Expand All @@ -53,13 +57,19 @@ public MLOutputType getType() {
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeOptionalString(modelId);
out.writeOptionalString(taskId);
out.writeOptionalString(status);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.startObject();
builder.field(MODEL_ID_FIELD, modelId);
if (modelId != null) {
builder.field(MODEL_ID_FIELD, modelId);
}
if (taskId != null) {
builder.field(TASK_ID_FIELD, taskId);
}
builder.field(STATUS_FIELD, status);
builder.endObject();
return builder;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,18 @@ public class MLTrainingTaskRequest extends ActionRequest {
* the name of algorithm
*/
MLInput mlInput;
boolean async;

@Builder
public MLTrainingTaskRequest(MLInput mlInput) {
public MLTrainingTaskRequest(MLInput mlInput, boolean async) {
this.mlInput = mlInput;
this.async = async;
}

public MLTrainingTaskRequest(StreamInput in) throws IOException {
super(in);
this.mlInput = new MLInput(in);
this.async = in.readBoolean();
}

@Override
Expand All @@ -69,6 +72,7 @@ public ActionRequestValidationException validate() {
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
this.mlInput.writeTo(out);
out.writeBoolean(async);
}

public static MLTrainingTaskRequest fromActionRequest(ActionRequest actionRequest) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import lombok.Data;

//TODO: remove this class, use MLModel
@Data
public class Model {
String name;
Expand Down
4 changes: 3 additions & 1 deletion plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.task.MLPredictTaskRunner',
'org.opensearch.ml.rest.RestMLTrainingAction',
'org.opensearch.ml.rest.RestMLPredictionAction',
'org.opensearch.ml.utils.RestActionUtils'
'org.opensearch.ml.utils.RestActionUtils',
'org.opensearch.ml.task.MLTaskCache',
'org.opensearch.ml.task.MLTaskManager'
]

jacocoTestCoverageVerification {
Expand Down
12 changes: 10 additions & 2 deletions plugin/src/main/java/org/opensearch/ml/model/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ public class MLTask implements ToXContentObject, Writeable {
public static final String LAST_UPDATE_TIME_FIELD = "last_update_time";
public static final String ERROR_FIELD = "error";
public static final String USER_FIELD = "user";
public static final String IS_ASYNC_TASK_FIELD = "is_async";

@Setter
private String taskId;
private final String modelId;
@Setter
private String modelId;
private final MLTaskType taskType;
private final FunctionName functionName;
@Setter
Expand All @@ -63,6 +65,7 @@ public class MLTask implements ToXContentObject, Writeable {
@Setter
private String error;
private User user; // TODO: support document level access control later
private boolean async;

@Builder
public MLTask(
Expand All @@ -78,7 +81,8 @@ public MLTask(
Instant createTime,
Instant lastUpdateTime,
String error,
User user
User user,
boolean async
) {
this.taskId = taskId;
this.modelId = modelId;
Expand All @@ -93,6 +97,7 @@ public MLTask(
this.lastUpdateTime = lastUpdateTime;
this.error = error;
this.user = user;
this.async = async;
}

public MLTask(StreamInput input) throws IOException {
Expand All @@ -113,6 +118,7 @@ public MLTask(StreamInput input) throws IOException {
} else {
this.user = null;
}
this.async = input.readBoolean();
}

@Override
Expand All @@ -134,6 +140,7 @@ public void writeTo(StreamOutput out) throws IOException {
} else {
out.writeBoolean(false);
}
out.writeBoolean(async);
}

@Override
Expand Down Expand Up @@ -178,6 +185,7 @@ public XContentBuilder toXContent(XContentBuilder xContentBuilder, Params params
if (user != null) {
builder.field(USER_FIELD, user);
}
builder.field(IS_ASYNC_TASK_FIELD, async);
return builder.endObject();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_ALGORITHM;
import static org.opensearch.ml.utils.RestActionUtils.getAlgorithm;
import static org.opensearch.ml.utils.RestActionUtils.isAsync;

import java.io.IOException;
import java.util.List;
Expand Down Expand Up @@ -66,11 +67,12 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
@VisibleForTesting
MLTrainingTaskRequest getRequest(RestRequest request) throws IOException {
String algorithm = getAlgorithm(request);
boolean async = isAsync(request);

XContentParser parser = request.contentParser();
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
MLInput mlInput = MLInput.parse(parser, algorithm);

return new MLTrainingTaskRequest(mlInput);
return new MLTrainingTaskRequest(mlInput, async);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,11 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionListenerResponseHandler;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.get.GetRequest;
import org.opensearch.action.support.ThreadedActionListener;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.commons.authuser.User;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.ml.action.prediction.MLPredictionTaskExecutionAction;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
Expand All @@ -52,11 +50,11 @@
import org.opensearch.ml.engine.MLEngine;
import org.opensearch.ml.engine.Model;
import org.opensearch.ml.indices.MLInputDatasetHandler;
import org.opensearch.ml.model.MLModel;
import org.opensearch.ml.model.MLTask;
import org.opensearch.ml.model.MLTaskState;
import org.opensearch.ml.model.MLTaskType;
import org.opensearch.ml.stats.MLStats;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

Expand Down Expand Up @@ -126,21 +124,28 @@ public void execute(MLExecuteTaskRequest request, TransportService transportServ
* @param listener Action listener
*/
public void startPredictionTask(MLPredictionTaskRequest request, ActionListener<MLPredictionTaskResponse> listener) {
MLInputDataType inputDataType = request.getMlInput().getInputDataset().getInputDataType();
Instant now = Instant.now();
MLTask mlTask = MLTask
.builder()
.taskId(UUID.randomUUID().toString())
.modelId(request.getModelId())
.taskType(MLTaskType.PREDICTION)
.createTime(Instant.now())
.inputType(inputDataType)
.functionName(request.getMlInput().getFunctionName())
.state(MLTaskState.CREATED)
.workerNode(clusterService.localNode().getId())
.createTime(now)
.lastUpdateTime(now)
.async(false)
.build();
MLInput mlInput = request.getMlInput();
if (mlInput.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) {
ActionListener<DataFrame> dataFrameActionListener = ActionListener
.wrap(dataFrame -> { predict(mlTask, dataFrame, request, listener); }, e -> {
log.error("Failed to generate DataFrame from search query", e);
mlTaskManager.addIfAbsent(mlTask);
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED);
mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage());
handleMLTaskFailure(mlTask, e);
listener.onFailure(e);
});
mlInputDatasetHandler
Expand Down Expand Up @@ -168,26 +173,13 @@ private void predict(
// search model by model id.
Model model = new Model();
if (request.getModelId() != null) {
// Build search request to find the model by "taskId"
SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder();
QueryBuilder queryBuilder = QueryBuilders.termQuery(TASK_ID, request.getModelId());
searchSourceBuilder.query(queryBuilder);
SearchRequest searchRequest = new SearchRequest(new String[] { ML_MODEL_INDEX }, searchSourceBuilder);

// Search model.
client.search(searchRequest, ActionListener.wrap(searchResponse -> {
// No model found.
if (searchResponse.getHits().getTotalHits().value == 0
|| searchResponse.getHits().getAt(0).getSourceAsMap() == null
|| searchResponse.getHits().getAt(0).getSourceAsMap().isEmpty()) {
Exception e = new ResourceNotFoundException("No model found, please check the modelId.");
log.error(e);
handlePredictFailure(mlTask, listener, e);
GetRequest getRequest = new GetRequest(ML_MODEL_INDEX, mlTask.getModelId());
client.get(getRequest, ActionListener.wrap(r -> {
if (r == null || !r.isExists()) {
listener.onFailure(new ResourceNotFoundException("No model found, please check the modelId."));
return;
}

Map<String, Object> source = searchResponse.getHits().getAt(0).getSourceAsMap();

Map<String, Object> source = r.getSourceAsMap();
User requestUser = getUserContext(client);
User resourceUser = User.parse((String) source.get(USER));
if (!checkUserPermissions(requestUser, resourceUser, request.getModelId())) {
Expand All @@ -200,15 +192,15 @@ private void predict(
return;
}

model.setName((String) source.get(MODEL_NAME));
model.setVersion((Integer) source.get(MODEL_VERSION));
byte[] decoded = Base64.getDecoder().decode((String) source.get(MODEL_CONTENT));
model.setName((String) source.get(MLModel.MODEL_NAME));
model.setVersion((Integer) source.get(MLModel.MODEL_VERSION));
byte[] decoded = Base64.getDecoder().decode((String) source.get(MLModel.MODEL_CONTENT));
model.setContent(decoded);

// run predict
MLOutput output;
try {
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING);
mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.RUNNING, mlTask.isAsync());
output = MLEngine.predict(mlInput.toBuilder().inputDataset(new DataFrameInputDataset(inputDataFrame)).build(), model);
if (output instanceof MLPredictionOutput) {
((MLPredictionOutput) output).setTaskId(mlTask.getTaskId());
Expand All @@ -226,9 +218,9 @@ private void predict(

MLPredictionTaskResponse response = MLPredictionTaskResponse.builder().output(output).build();
listener.onResponse(response);
}, searchException -> {
log.error("Search model failed", searchException);
handlePredictFailure(mlTask, listener, searchException);
}, e -> {
log.error("Failed to predict model " + mlTask.getModelId(), e);
listener.onFailure(e);
}));
} else {
IllegalArgumentException e = new IllegalArgumentException("ModelId is invalid");
Expand Down
34 changes: 34 additions & 0 deletions plugin/src/main/java/org/opensearch/ml/task/MLTaskCache.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
* Modifications Copyright OpenSearch Contributors. See
* GitHub history for details.
*
*/

package org.opensearch.ml.task;

import java.util.concurrent.Semaphore;

import lombok.Builder;
import lombok.Getter;

import org.opensearch.ml.model.MLTask;

@Getter
public class MLTaskCache {
MLTask mlTask;
Semaphore updateTaskIndexSemaphore;

@Builder
public MLTaskCache(MLTask mlTask) {
this.mlTask = mlTask;
if (mlTask.isAsync()) {
updateTaskIndexSemaphore = new Semaphore(1);
}
}
}
Loading

0 comments on commit 1d5da1d

Please sign in to comment.