Skip to content

Commit

Permalink
new delete model group API
Browse files Browse the repository at this point in the history
Signed-off-by: Bhavana Ramaram <rbhavna@amazon.com>
  • Loading branch information
rbhavna committed May 22, 2023
1 parent 29a0241 commit e54e399
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 31 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import org.opensearch.action.ActionType;
import org.opensearch.action.delete.DeleteResponse;

public class MLModelGroupDeleteAction extends ActionType<DeleteResponse> {
public static final MLModelGroupDeleteAction INSTANCE = new MLModelGroupDeleteAction();
public static final String NAME = "cluster:admin/opensearch/ml/model_groups/delete";

private MLModelGroupDeleteAction() { super(NAME, DeleteResponse::new);}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.transport.model_group;

import lombok.Builder;
import lombok.Getter;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.InputStreamStreamInput;
import org.opensearch.common.io.stream.OutputStreamStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UncheckedIOException;

import static org.opensearch.action.ValidateActions.addValidationError;

public class MLModelGroupDeleteRequest extends ActionRequest {
@Getter
String modelGroupId;

@Builder
public MLModelGroupDeleteRequest(String modelGroupId) {
this.modelGroupId = modelGroupId;
}

public MLModelGroupDeleteRequest(StreamInput input) throws IOException {
super(input);
this.modelGroupId = input.readString();
}

@Override
public void writeTo(StreamOutput output) throws IOException {
super.writeTo(output);
output.writeString(modelGroupId);
}

@Override
public ActionRequestValidationException validate() {
ActionRequestValidationException exception = null;

if (this.modelGroupId == null) {
exception = addValidationError("ML model group id can't be null", exception);
}

return exception;
}

public static MLModelGroupDeleteRequest fromActionRequest(ActionRequest actionRequest) {
if (actionRequest instanceof MLModelGroupDeleteRequest) {
return (MLModelGroupDeleteRequest)actionRequest;
}

try (ByteArrayOutputStream baos = new ByteArrayOutputStream();
OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) {
actionRequest.writeTo(osso);
try (StreamInput input = new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) {
return new MLModelGroupDeleteRequest(input);
}
} catch (IOException e) {
throw new UncheckedIOException("failed to parse ActionRequest into MLModelGroupDeleteRequest", e);
}
}
}
1 change: 1 addition & 0 deletions plugin/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ List<String> jacocoExclusions = [
'org.opensearch.ml.action.model_group.TransportUpdateModelGroupAction',
'org.opensearch.ml.action.model_group.DeleteModelGroupTransportAction',
'org.opensearch.ml.action.model_group.SearchModelGroupTransportAction',
'org.opensearch.ml.action.model_group.DeleteModelGroupTransportAction.1',
'org.opensearch.ml.rest.RestMLRegisterModelGroupAction',
'org.opensearch.ml.rest.RestMLUpdateModelGroupAction',
'org.opensearch.ml.rest.RestMLRegisterModelAction',
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.action.model_group;

import lombok.AccessLevel;
import lombok.experimental.FieldDefaults;
import lombok.extern.log4j.Log4j2;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.delete.DeleteRequest;
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.commons.authuser.User;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.ml.common.exception.MLValidationException;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteRequest;
import org.opensearch.ml.utils.RestActionUtils;
import org.opensearch.ml.utils.SecurityUtils;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_VALIDATE_BACKEND_ROLES;
import static org.opensearch.ml.utils.RestActionUtils.PARAMETER_MODEL_GROUP_ID;

@Log4j2
@FieldDefaults(level = AccessLevel.PRIVATE)
public class DeleteModelGroupTransportAction extends HandledTransportAction<ActionRequest, DeleteResponse> {

Client client;
NamedXContentRegistry xContentRegistry;
ClusterService clusterService;

private volatile boolean filterByEnabled;

@Inject
public DeleteModelGroupTransportAction(
TransportService transportService,
ActionFilters actionFilters,
Client client,
NamedXContentRegistry xContentRegistry,
Settings settings,
ClusterService clusterService
) {
super(MLModelGroupDeleteAction.NAME, transportService, actionFilters, MLModelGroupDeleteRequest::new);
this.client = client;
this.xContentRegistry = xContentRegistry;
this.clusterService = clusterService;
filterByEnabled = ML_COMMONS_VALIDATE_BACKEND_ROLES.get(settings);
clusterService.getClusterSettings().addSettingsUpdateConsumer(ML_COMMONS_VALIDATE_BACKEND_ROLES, it -> filterByEnabled = it);
}

@Override
protected void doExecute(Task task, ActionRequest request, ActionListener<DeleteResponse> actionListener) {
MLModelGroupDeleteRequest mlModelGroupDeleteRequest = MLModelGroupDeleteRequest.fromActionRequest(request);
String modelGroupId = mlModelGroupDeleteRequest.getModelGroupId();
DeleteRequest deleteRequest = new DeleteRequest(ML_MODEL_GROUP_INDEX, modelGroupId);
User user = RestActionUtils.getUserContext(client);

try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
SecurityUtils.validateModelGroupAccess(user, modelGroupId, client, ActionListener.wrap(access -> {
if ((filterByEnabled) && (!access)) {
actionListener
.onFailure(new MLValidationException("User Doesn't have previlege to perform this operation"));
} else {
BoolQueryBuilder query = new BoolQueryBuilder();
query.filter(new TermQueryBuilder(PARAMETER_MODEL_GROUP_ID, modelGroupId));
log.info(query.toString());

SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(query);
SearchRequest searchRequest = new SearchRequest(ML_MODEL_INDEX).source(searchSourceBuilder);
client.search(searchRequest, ActionListener.wrap(mlModels -> {
if (mlModels == null || mlModels.getHits().getTotalHits() == null || mlModels.getHits().getTotalHits().value == 0) {
client.delete(deleteRequest, new ActionListener<DeleteResponse>() {
@Override
public void onResponse(DeleteResponse deleteResponse) {
log.debug("Completed Delete Model Group Request, task id:{} deleted", modelGroupId);
actionListener.onResponse(deleteResponse);
}

@Override
public void onFailure(Exception e) {
log.error("Failed to delete ML Model Group " + modelGroupId, e);
actionListener.onFailure(e);
}
});
} else {
throw new MLValidationException("Cannot delete the model group when it has associated model versions");
}

}, e -> {
log.error("Failed to search models with the specified Model Group Id " + modelGroupId, e);
actionListener.onFailure(e);
}));
}
}, e -> {
log.error("Failed to validate Access for Model Group " + modelGroupId, e);
actionListener.onFailure(e);
}));
} catch (Exception e) {
log.error("Failed to delete ml model group" + modelGroupId, e);
actionListener.onFailure(e);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,7 @@

package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_VALIDATE_BACKEND_ROLES;

import java.time.Instant;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.index.IndexRequest;
Expand Down Expand Up @@ -44,6 +37,12 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;

import java.time.Instant;
import java.util.stream.Collectors;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_VALIDATE_BACKEND_ROLES;

@Log4j2
public class TransportRegisterModelGroupAction extends HandledTransportAction<ActionRequest, MLRegisterModelGroupResponse> {

Expand Down Expand Up @@ -101,6 +100,9 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
MLModelGroup mlModelGroup;
try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) {
if (filterByEnabled && user != null) {
if (isInvalidRequest(input)){
throw new IllegalArgumentException("User cannot specify backend roles to a public/private model grouo");
}
if (Boolean.TRUE.equals(input.getIsPublic())) {
builder = builder.access(MLModelGroup.PUBLIC);
} else if (Boolean.TRUE.equals(input.getIsAddAllBackendRoles())) {
Expand Down Expand Up @@ -171,4 +173,18 @@ public void createModelGroup(MLRegisterModelGroupInput input, ActionListener<Str
listener.onFailure(e);
}
}


public static boolean isInvalidRequest(MLRegisterModelGroupInput input) {
Boolean isPublic = input.getIsPublic() == null ? false : input.getIsPublic();
Boolean isAddAllBackendRoles = input.getIsAddAllBackendRoles() == null ? false : input.getIsAddAllBackendRoles();
Boolean isBackendRoles = !CollectionUtils.isEmpty(input.getBackendRoles());
if (isPublic) {
return isAddAllBackendRoles || isBackendRoles;
}
if (isAddAllBackendRoles) {
return isBackendRoles;
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,7 @@

package org.opensearch.ml.action.model_group;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_VALIDATE_BACKEND_ROLES;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

import java.util.Map;
import java.util.stream.Collectors;

import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.StringUtils;
import org.opensearch.action.ActionListener;
import org.opensearch.action.ActionRequest;
Expand Down Expand Up @@ -41,6 +33,13 @@
import org.opensearch.tasks.Task;
import org.opensearch.transport.TransportService;

import java.util.Map;
import java.util.stream.Collectors;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_GROUP_INDEX;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_VALIDATE_BACKEND_ROLES;
import static org.opensearch.ml.utils.MLExceptionUtils.logException;

@Log4j2
public class TransportUpdateModelGroupAction extends HandledTransportAction<ActionRequest, MLUpdateModelGroupResponse> {

Expand Down Expand Up @@ -164,4 +163,5 @@ private void updateModelGroup(
);

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,8 @@

package org.opensearch.ml.plugin;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import com.google.common.collect.ImmutableList;
import lombok.SneakyThrows;

import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionResponse;
import org.opensearch.client.Client;
Expand All @@ -38,6 +28,7 @@
import org.opensearch.ml.action.execute.TransportExecuteTaskAction;
import org.opensearch.ml.action.forward.TransportForwardAction;
import org.opensearch.ml.action.handler.MLSearchHandler;
import org.opensearch.ml.action.model_group.DeleteModelGroupTransportAction;
import org.opensearch.ml.action.model_group.SearchModelGroupTransportAction;
import org.opensearch.ml.action.model_group.TransportRegisterModelGroupAction;
import org.opensearch.ml.action.model_group.TransportUpdateModelGroupAction;
Expand Down Expand Up @@ -86,6 +77,7 @@
import org.opensearch.ml.common.transport.model.MLModelDeleteAction;
import org.opensearch.ml.common.transport.model.MLModelGetAction;
import org.opensearch.ml.common.transport.model.MLModelSearchAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupDeleteAction;
import org.opensearch.ml.common.transport.model_group.MLModelGroupSearchAction;
import org.opensearch.ml.common.transport.model_group.MLRegisterModelGroupAction;
import org.opensearch.ml.common.transport.model_group.MLUpdateModelGroupAction;
Expand All @@ -112,6 +104,7 @@
import org.opensearch.ml.model.MLModelCacheHelper;
import org.opensearch.ml.model.MLModelManager;
import org.opensearch.ml.rest.RestMLDeleteModelAction;
import org.opensearch.ml.rest.RestMLDeleteModelGroupAction;
import org.opensearch.ml.rest.RestMLDeleteTaskAction;
import org.opensearch.ml.rest.RestMLDeployModelAction;
import org.opensearch.ml.rest.RestMLExecuteAction;
Expand Down Expand Up @@ -158,7 +151,15 @@
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import com.google.common.collect.ImmutableList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

public class MachineLearningPlugin extends Plugin implements ActionPlugin {
public static final String ML_THREAD_POOL_PREFIX = "thread_pool.ml_commons.";
Expand Down Expand Up @@ -222,7 +223,8 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin {
new ActionHandler<>(MLSyncUpAction.INSTANCE, TransportSyncUpOnNodeAction.class),
new ActionHandler<>(MLRegisterModelGroupAction.INSTANCE, TransportRegisterModelGroupAction.class),
new ActionHandler<>(MLUpdateModelGroupAction.INSTANCE, TransportUpdateModelGroupAction.class),
new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class)
new ActionHandler<>(MLModelGroupSearchAction.INSTANCE, SearchModelGroupTransportAction.class),
new ActionHandler<>(MLModelGroupDeleteAction.INSTANCE, DeleteModelGroupTransportAction.class)
);
}

Expand Down Expand Up @@ -433,6 +435,7 @@ public List<RestHandler> getRestHandlers(
RestMLRegisterModelGroupAction restMLCreateModelGroupAction = new RestMLRegisterModelGroupAction();
RestMLUpdateModelGroupAction restMLUpdateModelGroupAction = new RestMLUpdateModelGroupAction();
RestMLSearchModelGroupAction restMLSearchModelGroupAction = new RestMLSearchModelGroupAction();
RestMLDeleteModelGroupAction restMLDeleteModelGroupAction = new RestMLDeleteModelGroupAction();
return ImmutableList
.of(
restMLStatsAction,
Expand All @@ -454,7 +457,8 @@ public List<RestHandler> getRestHandlers(
restMLUploadModelChunkAction,
restMLCreateModelGroupAction,
restMLUpdateModelGroupAction,
restMLSearchModelGroupAction
restMLSearchModelGroupAction,
restMLDeleteModelGroupAction
);
}

Expand Down
Loading

0 comments on commit e54e399

Please sign in to comment.