Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject delete model request if model is in Training #424

148 changes: 119 additions & 29 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.StepListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.delete.DeleteAction;
Expand All @@ -42,11 +43,14 @@
import org.opensearch.common.settings.Settings;
import org.opensearch.index.IndexNotFoundException;
import org.opensearch.knn.common.KNNConstants;
import org.opensearch.knn.plugin.BlockedModelIds;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.knn.plugin.transport.GetModelResponse;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest;
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction;
import org.opensearch.knn.plugin.transport.UpdateBlockedModelRequest;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;

Expand Down Expand Up @@ -122,9 +126,8 @@ public interface ModelDao {
*
* @param modelId to retrieve
* @param listener handles get model response
* @throws IOException thrown on search
*/
void get(String modelId, ActionListener<GetModelResponse> listener) throws IOException;
void get(String modelId, ActionListener<GetModelResponse> listener);

/**
* searches model from the system index. Non-blocking.
Expand All @@ -151,6 +154,14 @@ public interface ModelDao {
*/
void delete(String modelId, ActionListener<DeleteModelResponse> listener);

/**
 * Check if modelId is in blocked modelIds list. Non-blocking.
 *
 * @param modelId id of the model to look up in the blocked modelIds list
 * @return true if modelId is in blocked list, otherwise return false
 */
boolean isModelBlocked(String modelId);
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

/**
* Implementation of ModelDao for k-NN model index
*/
Expand Down Expand Up @@ -356,10 +367,9 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
*
* @param modelId to retrieve
* @param actionListener handles get model response
* @throws IOException thrown on search
*/
@Override
public void get(String modelId, ActionListener<GetModelResponse> actionListener) throws IOException {
public void get(String modelId, ActionListener<GetModelResponse> actionListener) {
/*
GET /<model_index>/<modelId>?_local
*/
Expand Down Expand Up @@ -427,6 +437,16 @@ private String getMapping() throws IOException {
return Resources.toString(url, Charsets.UTF_8);
}

// Check if the modelId is added to blocked list or not
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
public boolean isModelBlocked(String modelId) {
BlockedModelIds blockedModelIds = clusterService.state().metadata().custom(BlockedModelIds.TYPE);
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
if (blockedModelIds == null) {
return false;
}

return blockedModelIds.contains(modelId);
}

@Override
public void delete(String modelId, ActionListener<DeleteModelResponse> listener) {
// If the index is not created, there is no need to delete the model
Expand All @@ -437,51 +457,121 @@ public void delete(String modelId, ActionListener<DeleteModelResponse> listener)
return;
}

StepListener<GetModelResponse> getModelStep = new StepListener<>();
StepListener<AcknowledgedResponse> blockModelIdStep = new StepListener<>();
StepListener<AcknowledgedResponse> clearModelMetadataStep = new StepListener<>();
StepListener<DeleteResponse> deleteModelFromIndexStep = new StepListener<>();
StepListener<RemoveModelFromCacheResponse> clearModelFromCacheStep = new StepListener<>();
StepListener<AcknowledgedResponse> unblockModelIdStep = new StepListener<>();

// Get Model to check if model is in TRAINING
get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> {
if (exception instanceof ResourceNotFoundException) {
listener.onResponse(new DeleteModelResponse(modelId, "failed", exception.getMessage()));
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
return;
}
listener.onFailure(exception);
}));

getModelStep.whenComplete(getModelResponse -> {
// If model is in Training state, fail delete model request
if (ModelState.TRAINING.equals(getModelResponse.getModel().getModelMetadata().getState())) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
logger.error("Cannot delete model \"" + modelId + "\". Model is still in training.");
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId);
listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage));
return;
}

// Add modelId to blocked list until delete model request is processed
client.execute(
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
UpdateBlockedModelAction.INSTANCE,
new UpdateBlockedModelRequest(modelId, false),
ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure)
);

}, listener::onFailure);

// Remove the metadata asynchronously
blockModelIdStep.whenComplete(acknowledgedResponse -> {
client.execute(
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(modelId, true, null),
ActionListener.wrap(
clearModelMetadataStep::onResponse,
exception -> unblockModelIdOnFailure(modelId, exception, clearModelMetadataStep)
)
);
}, listener::onFailure);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model deletion from the index, remove the model from all nodes' model cache
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteRequestBuilder.execute(
ActionListener.wrap(
deleteModelFromIndexStep::onResponse,
exception -> unblockModelIdOnFailure(modelId, exception, deleteModelFromIndexStep)
)
),
listener::onFailure
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
);

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage));
return;
}

// After model is deleted from the index, make sure the model is evicted from every cache in the
// cluster
// After model is deleted from the index, make sure the model is evicted from every cache in the cluster
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(modelId),
ActionListener.wrap(removeModelFromCacheResponse -> {

if (!removeModelFromCacheResponse.hasFailures()) {
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), null));
return;
}

String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse);

listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage));

}, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage())))
ActionListener.wrap(
clearModelFromCacheStep::onResponse,
exception -> unblockModelIdOnFailure(modelId, exception, clearModelFromCacheStep)
)
);
}, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage())));
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

// On model metadata removal, delete the model from the index
ActionListener<AcknowledgedResponse> onMetadataUpdateListener = ActionListener.wrap(
acknowledgedResponse -> deleteRequestBuilder.execute(onModelDeleteListener),
listener::onFailure
);
clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> {
// After clearing the cache, if there are no errors remove modelId from blocked list
if (!removeModelFromCacheResponse.hasFailures()) {
client.execute(
UpdateBlockedModelAction.INSTANCE,
new UpdateBlockedModelRequest(modelId, true),
ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure)
);
} else {
String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse);
listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage));
}
}, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage())));

// Remove the metadata asynchronously
// After unblocking modelId return the response
unblockModelIdStep.whenComplete(acknowledgedResponse -> {
listener.onResponse(new DeleteModelResponse(modelId, "deleted", null));
return;
}, listener::onFailure);
}

// This function helps to remove the model from blocked list when the delete request fails
// while executing after adding modelId to blocked list
private void unblockModelIdOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener<?> step) {
// If modelId is unblocked successfully, then we will just return the exception received from failed stepListener
// If unblocking modelId request fails, then we will return this exception along with the one received from failed stepListener
client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(modelId, true, null),
onMetadataUpdateListener
UpdateBlockedModelAction.INSTANCE,
new UpdateBlockedModelRequest(modelId, true),
ActionListener.wrap(acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep), unblockingFailedException -> {
String errorMsg = exceptionFromPreviousStep.getMessage();
errorMsg = errorMsg + "\n" + unblockingFailedException.getMessage();
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
step.onFailure(new Exception(errorMsg));
})
);
}

Expand Down
137 changes: 137 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.plugin;

import org.apache.logging.log4j.LogManager;
import org.opensearch.Version;
import org.opensearch.cluster.Diff;
import org.opensearch.cluster.NamedDiff;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentParser;
import org.apache.logging.log4j.Logger;

import java.io.IOException;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.List;
import java.util.stream.Collectors;

public class BlockedModelIds implements Metadata.Custom {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

public static Logger logger = LogManager.getLogger(BlockedModelIds.class);
public static final String TYPE = "opensearch-knn-blocked-models";

List<String> blockedModelIds;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved

public BlockedModelIds(List<String> blockedModelIds) {
this.blockedModelIds = blockedModelIds;
}

@Override
public EnumSet<Metadata.XContentContext> context() {
return Metadata.ALL_CONTEXTS;
}

@Override
public Diff<Metadata.Custom> diff(Metadata.Custom custom) {
return null;
}

@Override
public String getWriteableName() {
return TYPE;
}

@Override
public Version getMinimalSupportedVersion() {
return Version.CURRENT.minimumCompatibilityVersion();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(blockedModelIds.size());
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
for (String modelId : blockedModelIds) {
out.writeString(modelId);
}
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
return null;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
}

public BlockedModelIds(StreamInput in) throws IOException {
int size = in.readVInt();
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
List<String> modelIds = new ArrayList<>(size);
for (int i = 0; i < size; i++) {
String modelId = in.readString();
modelIds.add(modelId);
}
this.blockedModelIds = modelIds;
}

public List<String> getBlockedModelIds() {
return blockedModelIds;
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
}

public void remove(String modelId) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
if (blockedModelIds.contains(modelId)) {
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
blockedModelIds.remove(modelId);
}
}

public void add(String modelId) {
blockedModelIds.add(modelId);
}

public int size() {
return blockedModelIds.size();
}

public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException {
return new BlockedModelIdsDiff(streamInput);
}

public static BlockedModelIds fromXContent(XContentParser xContentParser) throws IOException {
List<String> modelIds = xContentParser.list().stream().map(obj -> obj.toString()).collect(Collectors.toList());
return new BlockedModelIds(modelIds);
}

public boolean contains(String modelId) {
return blockedModelIds.contains(modelId);
}

static class BlockedModelIdsDiff implements NamedDiff<Metadata.Custom> {
private List<String> added;
private int removedCount;

public BlockedModelIdsDiff(StreamInput inp) throws IOException {
added = inp.readList((streamInput -> streamInput.toString()));
removedCount = inp.readVInt();
}

@Override
public Metadata.Custom apply(Metadata.Custom custom) {
return null;
}

@Override
public String getWriteableName() {
return TYPE;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeVInt(removedCount);
for (String modelId : added) {
out.writeString(modelId);
}
}
}
}
Loading