Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reject delete model request if model is in Training #424

Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import org.opensearch.common.xcontent.XContentParser;
import org.opensearch.common.xcontent.XContentType;
import org.opensearch.knn.index.SpaceType;
import org.opensearch.knn.index.util.KNNEngine;
import org.opensearch.knn.indices.ModelMetadata;
import org.opensearch.knn.indices.ModelState;
import org.opensearch.knn.plugin.KNNPlugin;
import org.opensearch.knn.plugin.transport.DeleteModelResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.search.SearchHit;

Expand Down Expand Up @@ -52,7 +56,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase {
private static int DOC_ID_TEST_MODEL_INDEX = 0;
private static int DOC_ID_TEST_MODEL_INDEX_DEFAULT = 0;
private static final int DELAY_MILLI_SEC = 1000;
private static final int EXP_NUM_OF_MODELS = 2;
private static final int EXP_NUM_OF_MODELS = 3;
private static final int K = 5;
private static final int NUM_DOCS = 10;
private static final int NUM_DOCS_TEST_MODEL_INDEX = 100;
Expand All @@ -63,6 +67,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase {
private static int QUERY_COUNT_TEST_MODEL_INDEX_DEFAULT = 0;
private static final String TEST_MODEL_ID = "test-model-id";
private static final String TEST_MODEL_ID_DEFAULT = "test-model-id-default";
private static final String TEST_MODEL_ID_TRAINING = "test-model-id-training";
private static final String MODEL_DESCRIPTION = "Description for train model test";

// KNN model test
Expand Down Expand Up @@ -135,12 +140,42 @@ public void testKNNModelDefault() throws Exception {
}
}

/**
 * Verifies that a delete request is rejected for a model that is still in the TRAINING state.
 * <p>
 * Against the old cluster, a model whose metadata state is set to TRAINING is added to the model
 * system index. Against the upgraded cluster, a DELETE request for that model id is expected to
 * return HTTP 200 with a "failed" result and an explanatory error message, while the number of
 * documents in the model index stays at the expected count.
 *
 * @throws IOException if a REST request or response parsing fails
 * @throws InterruptedException declared for the helper that writes the model to the system index
 */
public void testDeleteTrainingModel() throws IOException, InterruptedException {
    if (isRunningAgainstOldCluster()) {
        // The fixture is only needed on the old cluster, so build it here instead of on every run
        byte[] testModelBlob = "hello".getBytes(java.nio.charset.StandardCharsets.UTF_8);
        ModelMetadata testModelMetadata = getModelMetadata();
        testModelMetadata.setState(ModelState.TRAINING);
        addModelToSystemIndex(TEST_MODEL_ID_TRAINING, testModelMetadata, testModelBlob);
        return;
    }

    String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, TEST_MODEL_ID_TRAINING);
    Request request = new Request("DELETE", restURI);

    Response response = client().performRequest(request);
    assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode()));

    // The rejected delete must not have removed any document from the model index
    assertEquals(EXP_NUM_OF_MODELS, getDocCount(MODEL_INDEX_NAME));

    String responseBody = EntityUtils.toString(response.getEntity());
    assertNotNull(responseBody);

    Map<String, Object> responseMap = createParser(XContentType.JSON.xContent(), responseBody).map();

    assertEquals(TEST_MODEL_ID_TRAINING, responseMap.get(MODEL_ID));
    assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT));

    String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", TEST_MODEL_ID_TRAINING);
    assertEquals(errorMessage, responseMap.get(DeleteModelResponse.ERROR_MSG));
}

// Delete Models and ".opensearch-knn-models" index to clear cluster metadata
@AfterClass
public static void wipeAllModels() throws IOException {
if (!isRunningAgainstOldCluster()) {
deleteKNNModel(TEST_MODEL_ID);
deleteKNNModel(TEST_MODEL_ID_DEFAULT);
deleteKNNModel(TEST_MODEL_ID_TRAINING);

Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME);

Expand Down Expand Up @@ -241,4 +276,8 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep
.endObject()
);
}

/**
 * Builds the model metadata fixture used by the tests in this class.
 *
 * @return a new ModelMetadata using the default engine and space type, in the CREATED state
 */
private ModelMetadata getModelMetadata() {
    final ModelMetadata fixture = new ModelMetadata(
        KNNEngine.DEFAULT,
        SpaceType.DEFAULT,
        4,
        ModelState.CREATED,
        "2021-03-27",
        "test model",
        ""
    );
    return fixture;
}
}
200 changes: 171 additions & 29 deletions src/main/java/org/opensearch/knn/indices/ModelDao.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import com.google.common.io.Resources;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionListener;
import org.opensearch.action.DocWriteRequest;
import org.opensearch.action.DocWriteResponse;
import org.opensearch.action.FailedNodeException;
import org.opensearch.action.StepListener;
import org.opensearch.action.admin.indices.create.CreateIndexRequest;
import org.opensearch.action.admin.indices.create.CreateIndexResponse;
import org.opensearch.action.delete.DeleteAction;
Expand Down Expand Up @@ -49,14 +51,18 @@
import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction;
import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction;
import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest;

import java.io.IOException;
import java.net.URL;
import java.util.Base64;
import java.util.HashMap;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;

import static java.util.Objects.isNull;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH;
import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME;
import static org.opensearch.knn.common.KNNConstants.MODEL_METADATA_FIELD;
Expand Down Expand Up @@ -122,9 +128,8 @@ public interface ModelDao {
*
* @param modelId to retrieve
* @param listener handles get model response
* @throws IOException thrown on search
*/
void get(String modelId, ActionListener<GetModelResponse> listener) throws IOException;
void get(String modelId, ActionListener<GetModelResponse> listener);

/**
* searches model from the system index. Non-blocking.
Expand All @@ -151,12 +156,24 @@ public interface ModelDao {
*/
void delete(String modelId, ActionListener<DeleteModelResponse> listener);

/**
 * Checks whether a model id is currently in the model graveyard. Non-blocking.
 * <p>
 * A model id is added to the model graveyard before that model is deleted and is removed from the
 * graveyard once the delete request has been processed.
 *
 * @param modelId id of the model to look up in the model graveyard
 * @return true if modelId is in the model graveyard, otherwise false
 */
boolean isModelInGraveyard(String modelId);

/**
* Implementation of ModelDao for k-NN model index
*/
final class OpenSearchKNNModelDao implements ModelDao {

public static Logger logger = LogManager.getLogger(ModelDao.class);
private static final String DELETED = "deleted";
private static final String FAILED = "failed";

private int numberOfShards;
private int numberOfReplicas;
Expand Down Expand Up @@ -356,10 +373,9 @@ public Model get(String modelId) throws ExecutionException, InterruptedException
*
* @param modelId to retrieve
* @param actionListener handles get model response
* @throws IOException thrown on search
*/
@Override
public void get(String modelId, ActionListener<GetModelResponse> actionListener) throws IOException {
public void get(String modelId, ActionListener<GetModelResponse> actionListener) {
/*
GET /<model_index>/<modelId>?_local
*/
Expand Down Expand Up @@ -427,61 +443,187 @@ private String getMapping() throws IOException {
return Resources.toString(url, Charsets.UTF_8);
}

@Override
public boolean isModelInGraveyard(String modelId) {
// Check if the objects are not null and throw a customized NullPointerException
Objects.requireNonNull(clusterService.state(), "Cluster state must not be null");
Objects.requireNonNull(clusterService.state().metadata(), "Cluster metadata must not be null");
ModelGraveyard modelGraveyard = clusterService.state().metadata().custom(ModelGraveyard.TYPE);

if (isNull(modelGraveyard)) {
jmazanec15 marked this conversation as resolved.
Show resolved Hide resolved
return false;
}

return modelGraveyard.contains(modelId);
}

@Override
public void delete(String modelId, ActionListener<DeleteModelResponse> listener) {
// If the index is not created, there is no need to delete the model
if (!isCreated()) {
logger.error("Cannot delete model \"" + modelId + "\". Model index " + MODEL_INDEX_NAME + "does not exist.");
String errorMessage = String.format("Cannot delete model \"%s\". Model index does not exist", modelId);
listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage));
listener.onResponse(new DeleteModelResponse(modelId, FAILED, errorMessage));
return;
}

StepListener<GetModelResponse> getModelStep = new StepListener<>();
StepListener<AcknowledgedResponse> blockModelIdStep = new StepListener<>();
StepListener<AcknowledgedResponse> clearModelMetadataStep = new StepListener<>();
StepListener<DeleteResponse> deleteModelFromIndexStep = new StepListener<>();
StepListener<RemoveModelFromCacheResponse> clearModelFromCacheStep = new StepListener<>();
StepListener<AcknowledgedResponse> unblockModelIdStep = new StepListener<>();

// Get Model to check if model is in TRAINING
get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> {
if (exception instanceof ResourceNotFoundException) {
String errorMessage = String.format("Unable to delete model \"%s\". Model does not exist", modelId);
listener.onFailure(new ResourceNotFoundException(errorMessage));
return;
}
listener.onFailure(exception);
}));

getModelStep.whenComplete(getModelResponse -> {
// If model is in Training state, fail delete model request
if (ModelState.TRAINING == getModelResponse.getModel().getModelMetadata().getState()) {
String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId);
listener.onResponse(new DeleteModelResponse(modelId, FAILED, errorMessage));
return;
}

// Add modelId to model graveyard until delete model request is processed
updateModelGraveyardToDelete(modelId, false, blockModelIdStep, null);
}, listener::onFailure);

// Remove the metadata asynchronously
blockModelIdStep.whenComplete(
acknowledgedResponse -> { clearModelMetadata(modelId, clearModelMetadataStep); },
listener::onFailure
);

// Setup delete model request
DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME);
deleteRequestBuilder.setId(modelId);
deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE);

// On model deletion from the index, remove the model from all nodes' model cache
ActionListener<DeleteResponse> onModelDeleteListener = ActionListener.wrap(deleteResponse -> {
// If model is not deleted, return with error message
// On model metadata removal, delete the model from the index
clearModelMetadataStep.whenComplete(
acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder),
listener::onFailure
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
);

deleteModelFromIndexStep.whenComplete(deleteResponse -> {
// If model is not deleted, remove modelId from model graveyard and return with error message
if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) {
updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, null);
String errorMessage = String.format("Model \" %s \" does not exist", modelId);
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage));
return;
}

// After model is deleted from the index, make sure the model is evicted from every cache in the
// cluster
client.execute(
RemoveModelFromCacheAction.INSTANCE,
new RemoveModelFromCacheRequest(modelId),
ActionListener.wrap(removeModelFromCacheResponse -> {
// After model is deleted from the index, make sure the model is evicted from every cache in the cluster
removeModelFromCache(modelId, clearModelFromCacheStep);
}, e -> listener.onFailure(new OpenSearchException(e)));

if (!removeModelFromCacheResponse.hasFailures()) {
listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), null));
return;
}
clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> {

String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse);
// If there are any failures while removing model from the cache build the error message
OpenSearchException exception = null;
if (removeModelFromCacheResponse.hasFailures()) {
String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse);
exception = new OpenSearchException(failureMessage);
}

listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage));
// Remove modelId from model graveyard
updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, exception);

}, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage())))
);
}, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage())));
}, e -> listener.onFailure(new OpenSearchException(e)));

// On model metadata removal, delete the model from the index
ActionListener<AcknowledgedResponse> onMetadataUpdateListener = ActionListener.wrap(
acknowledgedResponse -> deleteRequestBuilder.execute(onModelDeleteListener),
listener::onFailure
unblockModelIdStep.whenComplete(acknowledgedResponse -> {
// After clearing the cache, if there are no errors return the response
listener.onResponse(new DeleteModelResponse(modelId, DELETED, null));

}, listener::onFailure);

}

/**
 * Asks the cluster to remove the model from the model cache.
 * <p>
 * A successful response is forwarded to the given step. On failure, the model id is removed from
 * the model graveyard and the step is failed with the original exception.
 *
 * @param modelId id of the model to remove from the cache
 * @param clearModelFromCacheStep step notified with the outcome of the cache removal
 */
private void removeModelFromCache(String modelId, StepListener<RemoveModelFromCacheResponse> clearModelFromCacheStep) {
    final RemoveModelFromCacheRequest evictionRequest = new RemoveModelFromCacheRequest(modelId);
    final ActionListener<RemoveModelFromCacheResponse> evictionListener = ActionListener.wrap(
        clearModelFromCacheStep::onResponse,
        failure -> removeModelIdFromGraveyardOnFailure(modelId, failure, clearModelFromCacheStep)
    );
    client.execute(RemoveModelFromCacheAction.INSTANCE, evictionRequest, evictionListener);
}

/**
 * Executes the prepared delete request against the model system index.
 * <p>
 * A successful response is forwarded to the given step. On failure, the model id is removed from
 * the model graveyard and the step is failed with the original exception.
 *
 * @param modelId id of the model being deleted
 * @param deleteModelFromIndexStep step notified with the outcome of the delete request
 * @param deleteRequestBuilder fully configured delete request for the model document
 */
private void deleteModelFromIndex(
    String modelId,
    StepListener<DeleteResponse> deleteModelFromIndexStep,
    DeleteRequestBuilder deleteRequestBuilder
) {
    final ActionListener<DeleteResponse> deleteOutcomeListener = ActionListener.wrap(
        deleteModelFromIndexStep::onResponse,
        failure -> removeModelIdFromGraveyardOnFailure(modelId, failure, deleteModelFromIndexStep)
    );
    deleteRequestBuilder.execute(deleteOutcomeListener);
}

// Update model graveyard to add/remove modelId
private void updateModelGraveyardToDelete(
String modelId,
boolean isRemoveRequest,
StepListener<AcknowledgedResponse> step,
Exception exception
naveentatikonda marked this conversation as resolved.
Show resolved Hide resolved
) {

// Remove the metadata asynchronously
client.execute(
UpdateModelGraveyardAction.INSTANCE,
new UpdateModelGraveyardRequest(modelId, isRemoveRequest),
ActionListener.wrap(acknowledgedResponse -> {
if (exception == null) {
step.onResponse(acknowledgedResponse);
return;
}
throw exception;

}, e -> {
if (exception == null) {
step.onFailure(e);
return;
}
step.onFailure(exception);
})
);
}

// Clear the metadata of the model for a given modelId
private void clearModelMetadata(String modelId, StepListener<AcknowledgedResponse> clearModelMetadataStep) {
client.execute(
UpdateModelMetadataAction.INSTANCE,
new UpdateModelMetadataRequest(modelId, true, null),
onMetadataUpdateListener
ActionListener.wrap(
clearModelMetadataStep::onResponse,
exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, clearModelMetadataStep)
)
);
}

/**
 * Removes the model id from the model graveyard after a delete step has failed, then fails the
 * given step with the exception from that earlier step.
 * <p>
 * The step is failed with the earlier exception whether or not the graveyard removal succeeds.
 *
 * @param modelId id of the model to remove from the model graveyard
 * @param exceptionFromPreviousStep failure that interrupted the delete request
 * @param step step to fail once the graveyard removal has been attempted
 */
private void removeModelIdFromGraveyardOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener<?> step) {
    client.execute(
        UpdateModelGraveyardAction.INSTANCE,
        new UpdateModelGraveyardRequest(modelId, true),
        ActionListener.wrap(
            // Report the earlier failure directly instead of throwing to reach the failure handler
            acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep),
            unblockingFailedException -> {
                // Keep the earlier failure as the primary error without losing the unblocking failure
                if (exceptionFromPreviousStep != unblockingFailedException) {
                    exceptionFromPreviousStep.addSuppressed(unblockingFailedException);
                }
                step.onFailure(exceptionFromPreviousStep);
            }
        )
    );
}

Expand Down
Loading