From c788b5725c1dbc3e55b7b5394fb5c9fcd8a2a178 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 16 Jun 2022 17:06:31 -0500 Subject: [PATCH 1/9] Reject Delete Model Request if Model is in Training Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/indices/ModelDao.java | 147 ++++++++++++++---- .../transport/GetModelTransportAction.java | 10 +- .../action/RestDeleteModelHandlerIT.java | 38 ++++- 3 files changed, 156 insertions(+), 39 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 937e5f811..77c27034e 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -20,6 +20,7 @@ import org.opensearch.action.DocWriteRequest; import org.opensearch.action.DocWriteResponse; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.StepListener; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.delete.DeleteAction; @@ -42,11 +43,14 @@ import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.plugin.BlockedModelIds; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelRequest; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; @@ -122,9 +126,8 @@ public interface ModelDao { * * @param modelId to retrieve * @param listener handles get model response - * @throws IOException thrown on search */ - void get(String modelId, ActionListener listener) throws IOException; + void get(String modelId, ActionListener listener); /** * searches model from the system index. Non-blocking. @@ -151,6 +154,14 @@ public interface ModelDao { */ void delete(String modelId, ActionListener listener); + /** + * Check if modelId is in blocked modelIds list. Non-blocking. + * + * @param modelId to retrieve + * @return true if modelId is in blocked list, otherwise return false + */ + boolean isModelBlocked(String modelId); + /** * Implementation of ModelDao for k-NN model index */ @@ -359,7 +370,7 @@ public Model get(String modelId) throws ExecutionException, InterruptedException * @throws IOException thrown on search */ @Override - public void get(String modelId, ActionListener actionListener) throws IOException { + public void get(String modelId, ActionListener actionListener) { /* GET //?_local */ @@ -427,6 +438,16 @@ private String getMapping() throws IOException { return Resources.toString(url, Charsets.UTF_8); } + // Check if the modelId is added to blocked list or not + public boolean isModelBlocked(String modelId) { + BlockedModelIds blockedModelIds = clusterService.state().metadata().custom(BlockedModelIds.TYPE); + if (blockedModelIds == null) { + return false; + } + + return blockedModelIds.contains(modelId); + } + @Override public void delete(String modelId, ActionListener listener) { // If the index is not created, there is no need to delete the model @@ -437,13 +458,69 @@ public void delete(String modelId, ActionListener listener) return; } + StepListener getModelStep = new StepListener<>(); + StepListener blockModelIdStep = new StepListener<>(); + StepListener clearModelMetadataStep = new StepListener<>(); + StepListener deleteModelFromIndexStep = new StepListener<>(); + StepListener clearModelFromCacheStep = new StepListener<>(); + StepListener unblockModelIdStep = new StepListener<>(); + + // Get Model to check if model is in TRAINING + get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> { + if (exception instanceof ResourceNotFoundException) { + listener.onResponse(new DeleteModelResponse(modelId, "failed", exception.getMessage())); + return; + } + listener.onFailure(exception); + })); + + getModelStep.whenComplete(getModelResponse -> { + // If model is in Training state, fail delete model request + if (ModelState.TRAINING.equals(getModelResponse.getModel().getModelMetadata().getState())) { + logger.error("Cannot delete model \"" + modelId + "\". Model is still in training."); + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); + listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); + return; + } + + // Add modelId to blocked list until delete model request is processed + client.execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, false), + ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) + ); + + }, listener::onFailure); + + // Remove the metadata asynchronously + blockModelIdStep.whenComplete(acknowledgedResponse -> { + client.execute( + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest(modelId, true, null), + ActionListener.wrap( + clearModelMetadataStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, clearModelMetadataStep) + ) + ); + }, listener::onFailure); + // Setup delete model request DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); deleteRequestBuilder.setId(modelId); deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); - // On model deletion from the index, remove the model from all nodes' model cache - ActionListener onModelDeleteListener = ActionListener.wrap(deleteResponse -> { + // On model metadata removal, delete the model from the index + clearModelMetadataStep.whenComplete( + acknowledgedResponse -> deleteRequestBuilder.execute( + ActionListener.wrap( + deleteModelFromIndexStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, deleteModelFromIndexStep) + ) + ), + listener::onFailure + ); + + deleteModelFromIndexStep.whenComplete(deleteResponse -> { // If model is not deleted, return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { String errorMessage = String.format("Model \" %s \" does not exist", modelId); @@ -451,37 +528,51 @@ public void delete(String modelId, ActionListener listener) return; } - // After model is deleted from the index, make sure the model is evicted from every cache in the - // cluster + // After model is deleted from the index, make sure the model is evicted from every cache in the cluster client.execute( RemoveModelFromCacheAction.INSTANCE, new RemoveModelFromCacheRequest(modelId), - ActionListener.wrap(removeModelFromCacheResponse -> { - - if (!removeModelFromCacheResponse.hasFailures()) { - listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), null)); - return; - } - - String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); - - listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage)); - - }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))) + ActionListener.wrap( + clearModelFromCacheStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, clearModelFromCacheStep) + ) ); }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); - // On model metadata removal, delete the model from the index - ActionListener onMetadataUpdateListener = ActionListener.wrap( - acknowledgedResponse -> deleteRequestBuilder.execute(onModelDeleteListener), - listener::onFailure - ); + clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> { + // After clearing the cache, if there are no errors remove modelId from blocked list + if (!removeModelFromCacheResponse.hasFailures()) { + client.execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, true), + ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure) + ); + } else { + String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); + listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage)); + } + }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); - // Remove the metadata asynchronously + // After unblocking modelId return the response + unblockModelIdStep.whenComplete(acknowledgedResponse -> { + listener.onResponse(new DeleteModelResponse(modelId, "deleted", null)); + return; + }, listener::onFailure); + } + + // This function helps to remove the model from blocked list when the delete request fails + // while executing after adding modelId to blocked list + private void unblockModelIdOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener step) { + // If modelId is unblocked successfully, then we will just return the exception received from failed stepListener + // If unblocking modelId request fails, then we will return this exception along with the one received from failed stepListener client.execute( - UpdateModelMetadataAction.INSTANCE, - new UpdateModelMetadataRequest(modelId, true, null), - onMetadataUpdateListener + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, true), + ActionListener.wrap(acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep), unblockingFailedException -> { + String errorMsg = exceptionFromPreviousStep.getMessage(); + errorMsg = errorMsg + "\n" + unblockingFailedException.getMessage(); + step.onFailure(new Exception(errorMsg)); + }) ); } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java index 05cb12742..e47a42d8d 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/GetModelTransportAction.java @@ -20,8 +20,6 @@ import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; -import java.io.IOException; - /** * Transport Action for {@link GetModelAction} */ @@ -38,10 +36,8 @@ public GetModelTransportAction(TransportService transportService, ActionFilters @Override protected void doExecute(Task task, GetModelRequest request, ActionListener actionListener) { String modelID = request.getModelID(); - try { - modelDao.get(modelID, actionListener); - } catch (IOException e) { - actionListener.onFailure(e); - } + + modelDao.get(modelID, actionListener); + } } diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index edd8d2106..3cb70ca86 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -12,7 +12,6 @@ package org.opensearch.knn.plugin.action; import org.apache.http.util.EntityUtils; -import org.opensearch.action.DocWriteResponse; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentType; @@ -28,8 +27,8 @@ import java.io.IOException; import java.util.Map; -import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; +import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; /** @@ -57,7 +56,37 @@ public void testDeleteModelExists() throws IOException { Response response = client().performRequest(request); assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - assertEquals(getDocCount(MODEL_INDEX_NAME), 0); + assertEquals(0, getDocCount(MODEL_INDEX_NAME)); + } + + public void testDeleteTrainingModel() throws IOException { + createModelSystemIndex(); + String testModelID = "test-model-id"; + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + testModelMetadata.setState(ModelState.TRAINING); + + addModelToSystemIndex(testModelID, testModelMetadata, testModelBlob); + assertEquals(1, getDocCount(MODEL_INDEX_NAME)); + + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, testModelID); + Request request = new Request("DELETE", restURI); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + assertEquals(1, getDocCount(MODEL_INDEX_NAME)); + + String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(testModelID, responseMap.get(MODEL_ID)); + assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); + + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", testModelID); + assertEquals(errorMessage, responseMap.get(DeleteModelResponse.ERROR_MSG)); } public void testDeleteModelFailsInvalid() throws IOException { @@ -73,7 +102,8 @@ public void testDeleteModelFailsInvalid() throws IOException { Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); assertEquals("invalid-model-id", responseMap.get(MODEL_ID)); - assertEquals(DocWriteResponse.Result.NOT_FOUND.getLowercase(), responseMap.get(DeleteModelResponse.RESULT)); + assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); assertNotNull(responseMap.get(DeleteModelResponse.ERROR_MSG)); } + } From 8830890e209cb324a9d128916ad81b16602d1ff7 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 16 Jun 2022 17:17:24 -0500 Subject: [PATCH 2/9] Locking mechanism to block modelId if model is in the process of Deletion Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/indices/ModelDao.java | 1 - .../knn/plugin/BlockedModelIds.java | 137 +++++++++ .../org/opensearch/knn/plugin/KNNPlugin.java | 48 ++- .../transport/UpdateBlockedModelAction.java | 29 ++ .../transport/UpdateBlockedModelRequest.java | 84 ++++++ .../UpdateBlockedModelTransportAction.java | 153 ++++++++++ .../opensearch/knn/indices/ModelDaoTests.java | 276 +++++++++++++++++- .../knn/plugin/BlockedModelIdsTests.java | 60 ++++ .../UpdateBlockedModelRequestTests.java | 58 ++++ ...pdateBlockedModelTransportActionTests.java | 111 +++++++ 10 files changed, 946 insertions(+), 11 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java create mode 100644 src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java create mode 100644 src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java create mode 100644 src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java create mode 100644 src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 77c27034e..19815d865 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -367,7 +367,6 @@ public Model get(String modelId) throws ExecutionException, InterruptedException * * @param modelId to retrieve * @param actionListener handles get model response - * @throws IOException thrown on search */ @Override public void get(String modelId, ActionListener actionListener) { diff --git a/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java b/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java new file mode 100644 index 000000000..8d093cb67 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin; + +import org.apache.logging.log4j.LogManager; +import org.opensearch.Version; +import org.opensearch.cluster.Diff; +import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; +import org.apache.logging.log4j.Logger; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.EnumSet; +import java.util.List; +import java.util.stream.Collectors; + +public class BlockedModelIds implements Metadata.Custom { + + public static Logger logger = LogManager.getLogger(BlockedModelIds.class); + public static final String TYPE = "opensearch-knn-blocked-models"; + + List blockedModelIds; + + public BlockedModelIds(List blockedModelIds) { + this.blockedModelIds = blockedModelIds; + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + @Override + public Diff diff(Metadata.Custom custom) { + return null; + } + + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.CURRENT.minimumCompatibilityVersion(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(blockedModelIds.size()); + for (String modelId : blockedModelIds) { + out.writeString(modelId); + } + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return null; + } + + public BlockedModelIds(StreamInput in) throws IOException { + int size = in.readVInt(); + List modelIds = new ArrayList<>(size); + for (int i = 0; i < size; i++) { + String modelId = in.readString(); + modelIds.add(modelId); + } + this.blockedModelIds = modelIds; + } + + public List getBlockedModelIds() { + return blockedModelIds; + } + + public void remove(String modelId) { + if (blockedModelIds.contains(modelId)) { + blockedModelIds.remove(modelId); + } + } + + public void add(String modelId) { + blockedModelIds.add(modelId); + } + + public int size() { + return blockedModelIds.size(); + } + + public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException { + return new BlockedModelIdsDiff(streamInput); + } + + public static BlockedModelIds fromXContent(XContentParser xContentParser) throws IOException { + List modelIds = xContentParser.list().stream().map(obj -> obj.toString()).collect(Collectors.toList()); + return new BlockedModelIds(modelIds); + } + + public boolean contains(String modelId) { + return blockedModelIds.contains(modelId); + } + + static class BlockedModelIdsDiff implements NamedDiff { + private List added; + private int removedCount; + + public BlockedModelIdsDiff(StreamInput inp) throws IOException { + added = inp.readList((streamInput -> streamInput.toString())); + removedCount = inp.readVInt(); + } + + @Override + public Metadata.Custom apply(Metadata.Custom custom) { + return null; + } + + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeVInt(removedCount); + for (String modelId : added) { + out.writeString(modelId); + } + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 9cf0696f2..9a6e13933 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -5,6 +5,11 @@ package org.opensearch.knn.plugin; +import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.common.ParseField; +import org.opensearch.common.io.stream.NamedWriteable; +import org.opensearch.common.io.stream.Writeable; import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; @@ -65,6 +70,8 @@ import org.opensearch.knn.plugin.transport.TrainingModelAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; import org.opensearch.knn.plugin.transport.TrainingModelTransportAction; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; import org.opensearch.knn.training.TrainingJobRunner; @@ -88,6 +95,7 @@ import org.opensearch.watcher.ResourceWatcherService; import java.util.Arrays; +import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.List; @@ -235,7 +243,8 @@ public List getRestHandlers( new ActionHandler<>(TrainingJobRouterAction.INSTANCE, TrainingJobRouterTransportAction.class), new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), - new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class) + new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), + new ActionHandler<>(UpdateBlockedModelAction.INSTANCE, UpdateBlockedModelTransportAction.class) ); } @@ -293,4 +302,41 @@ public ScriptEngine getScriptEngine(Settings settings, Collection> getExecutorBuilders(Settings settings) { return ImmutableList.of(new FixedExecutorBuilder(settings, TRAIN_THREAD_POOL, 1, 1, KNN_THREAD_POOL_PREFIX, false)); } + + @Override + public List getNamedWriteables() { + List entries = new ArrayList<>(); + registerMetadataCustom(entries, BlockedModelIds.TYPE, BlockedModelIds::new, BlockedModelIds::readDiffFrom); + return entries; + } + + @Override + public List getNamedXContent() { + List entries = new ArrayList<>(); + + entries.add( + new NamedXContentRegistry.Entry(Metadata.Custom.class, new ParseField(BlockedModelIds.TYPE), BlockedModelIds::fromXContent) + ); + return entries; + } + + private static void registerMetadataCustom( + List entries, + String name, + Writeable.Reader reader, + Writeable.Reader diffReader + ) { + registerCustom(entries, Metadata.Custom.class, name, reader, diffReader); + } + + private static void registerCustom( + List entries, + Class category, + String name, + Writeable.Reader reader, + Writeable.Reader diffReader + ) { + entries.add(new NamedWriteableRegistry.Entry(category, name, reader)); + entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, name, diffReader)); + } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java new file mode 100644 index 000000000..1b19db3fd --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionType; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.common.io.stream.Writeable; + +/** + * Action to update blocked modelIds list + */ +public class UpdateBlockedModelAction extends ActionType { + + public static final String NAME = "cluster:admin/knn_update_blocked_model_action"; + public static final UpdateBlockedModelAction INSTANCE = new UpdateBlockedModelAction(NAME, AcknowledgedResponse::new); + + /** + * Constructor. + * + * @param name name of action + * @param acknowledgedResponseReader reader for acknowledged response + */ + public UpdateBlockedModelAction(String name, Writeable.Reader acknowledgedResponseReader) { + super(name, acknowledgedResponseReader); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java new file mode 100644 index 000000000..300231a8b --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.support.master.AcknowledgedRequest; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; + +import java.io.IOException; + +import static org.opensearch.action.ValidateActions.addValidationError; + +/** + * Request for updating blocked modelIds list while processing delete model request + */ +public class UpdateBlockedModelRequest extends AcknowledgedRequest { + + private String modelId; + private boolean isRemoveRequest; + + /** + * Constructor + * + * @param in input stream + * @throws IOException if read from stream fails + */ + public UpdateBlockedModelRequest(StreamInput in) throws IOException { + super(in); + this.modelId = in.readString(); + this.isRemoveRequest = in.readBoolean(); + } + + /** + * Constructor + * + * @param modelId Id of model + * @param isRemoveRequest should this model id be removed + */ + public UpdateBlockedModelRequest(String modelId, boolean isRemoveRequest) { + super(); + this.modelId = modelId; + this.isRemoveRequest = isRemoveRequest; + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + + if (modelId.isEmpty()) { + validationException = addValidationError("Missing model ID", validationException); + } + + return validationException; + } + + /** + * Getter for modelId + * + * @return modelId + */ + public String getModelId() { + return modelId; + } + + /** + * Getter for isRemoveRequest + * + * @return isRemoveRequest + */ + public boolean isRemoveRequest() { + return isRemoveRequest; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(modelId); + out.writeBoolean(isRemoveRequest); + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java new file mode 100644 index 000000000..24df82471 --- /dev/null +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java @@ -0,0 +1,153 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.action.support.master.TransportMasterNodeAction; +import org.opensearch.cluster.ClusterState; +import org.opensearch.cluster.ClusterStateTaskConfig; +import org.opensearch.cluster.ClusterStateTaskExecutor; +import org.opensearch.cluster.ClusterStateTaskListener; +import org.opensearch.cluster.block.ClusterBlockException; +import org.opensearch.cluster.metadata.IndexNameExpressionResolver; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.Priority; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.knn.plugin.BlockedModelIds; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.opensearch.knn.common.KNNConstants.PLUGIN_NAME; + +/** + * Transport action used to update blocked modelIds list on the cluster manager node. + */ +public class UpdateBlockedModelTransportAction extends TransportMasterNodeAction { + + public static Logger logger = LogManager.getLogger(UpdateBlockedModelTransportAction.class); + + private UpdateBlockedModelExecutor updateBlockedModelExecutor; + + @Inject + public UpdateBlockedModelTransportAction( + TransportService transportService, + ClusterService clusterService, + ThreadPool threadPool, + ActionFilters actionFilters, + IndexNameExpressionResolver indexNameExpressionResolver + ) { + super( + UpdateBlockedModelAction.NAME, + transportService, + clusterService, + threadPool, + actionFilters, + UpdateBlockedModelRequest::new, + indexNameExpressionResolver + ); + this.updateBlockedModelExecutor = new UpdateBlockedModelExecutor(); + } + + @Override + protected String executor() { + return ThreadPool.Names.SAME; + } + + @Override + protected AcknowledgedResponse read(StreamInput streamInput) throws IOException { + return new AcknowledgedResponse(streamInput); + } + + @Override + protected void masterOperation( + UpdateBlockedModelRequest request, + ClusterState clusterState, + ActionListener actionListener + ) { + // ClusterManager updates blocked modelIds list based on request parameters + clusterService.submitStateUpdateTask( + PLUGIN_NAME, + new UpdateBlockedModelTask(request.getModelId(), request.isRemoveRequest()), + ClusterStateTaskConfig.build(Priority.NORMAL), + updateBlockedModelExecutor, + new ClusterStateTaskListener() { + @Override + public void onFailure(String s, Exception e) { + actionListener.onFailure(e); + } + + @Override + public void clusterStateProcessed(String source, ClusterState oldState, ClusterState newState) { + actionListener.onResponse(new AcknowledgedResponse(true)); + } + } + ); + } + + @Override + protected ClusterBlockException checkBlock(UpdateBlockedModelRequest request, ClusterState clusterState) { + return null; + } + + /** + * UpdateBlockedModelTask is used to provide the executor with the information it needs to perform its task + */ + private static class UpdateBlockedModelTask { + + private String modelId; + private boolean isRemoveRequest; + + /** + * Constructor + * + * @param modelId id of model + * @param isRemoveRequest should this modelId be removed + */ + UpdateBlockedModelTask(String modelId, boolean isRemoveRequest) { + this.modelId = modelId; + this.isRemoveRequest = isRemoveRequest; + } + } + + private static class UpdateBlockedModelExecutor implements ClusterStateTaskExecutor { + @Override + public ClusterTasksResult execute(ClusterState clusterState, List list) { + + BlockedModelIds immutableBlockedModelIds = clusterState.metadata().custom(BlockedModelIds.TYPE); + BlockedModelIds blockedModelIds; + + if (immutableBlockedModelIds == null) { + blockedModelIds = new BlockedModelIds(new ArrayList<>()); + } else { + blockedModelIds = immutableBlockedModelIds; + } + + for (UpdateBlockedModelTask task : list) { + if (task.isRemoveRequest) { + blockedModelIds.remove(task.modelId); + } else { + blockedModelIds.add(task.modelId); + } + } + + Metadata.Builder metaDataBuilder = Metadata.builder(clusterState.metadata()); + metaDataBuilder.putCustom(BlockedModelIds.TYPE, blockedModelIds); + + ClusterState updatedClusterState = ClusterState.builder(clusterState).metadata(metaDataBuilder).build(); + return new ClusterTasksResult.Builder().successes(list).build(updatedClusterState); + } + } +} diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 51f8240f3..e3ee5603a 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -17,10 +17,15 @@ import org.opensearch.ResourceAlreadyExistsException; import org.opensearch.action.ActionListener; import org.opensearch.action.DocWriteResponse; +import org.opensearch.action.StepListener; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.delete.DeleteAction; +import org.opensearch.action.delete.DeleteRequestBuilder; +import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.support.WriteRequest; +import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.index.IndexNotFoundException; @@ -29,6 +34,14 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.util.KNNEngine; import org.opensearch.knn.plugin.transport.DeleteModelResponse; +import org.opensearch.knn.plugin.transport.GetModelResponse; +import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; +import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; +import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; +import org.opensearch.knn.plugin.transport.UpdateBlockedModelRequest; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; +import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; import org.opensearch.rest.RestStatus; import java.io.IOException; @@ -479,6 +492,7 @@ public void testGetMetadata() throws IOException, InterruptedException { public void testDelete() throws IOException, InterruptedException { ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); String modelId = "testDeleteModelID"; + String modelId1 = "testDeleteModelID1"; byte[] modelBlob = "hello".getBytes(); int dimension = 2; @@ -495,28 +509,31 @@ public void testDelete() throws IOException, InterruptedException { final CountDownLatch inProgressLatch1 = new CountDownLatch(1); ActionListener deleteModelDoesNotExistListener = ActionListener.wrap(response -> { - assertEquals(DocWriteResponse.Result.NOT_FOUND.getLowercase(), response.getResult()); + assertEquals(modelId, response.getModelID()); + assertEquals("failed", response.getResult()); + assertNotNull(response.getErrorMessage()); inProgressLatch1.countDown(); - }, exception -> fail("Unable to delete the model: " + exception)); + }, exception -> fail(exception.getMessage())); modelDao.delete(modelId, deleteModelDoesNotExistListener); assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); - ActionListener deleteModelExistsListener = ActionListener.wrap(response -> { + ActionListener deleteModelTrainingListener = ActionListener.wrap(response -> { assertEquals(modelId, response.getModelID()); - assertEquals(DocWriteResponse.Result.DELETED.getLowercase(), response.getResult()); - assertNull(response.getErrorMessage()); + assertEquals("failed", response.getResult()); + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); + assertEquals(errorMessage, response.getErrorMessage()); inProgressLatch2.countDown(); }, exception -> fail("Unable to delete model: " + exception)); - // model id exists + // model id exists and model is still in Training Model model = new Model( new ModelMetadata( KNNEngine.DEFAULT, SpaceType.DEFAULT, dimension, - ModelState.CREATED, + ModelState.TRAINING, ZonedDateTime.now(ZoneOffset.UTC).toString(), "", "" @@ -527,13 +544,253 @@ public void testDelete() throws IOException, InterruptedException { ActionListener docCreationListener = ActionListener.wrap(response -> { assertEquals(modelId, response.getId()); - modelDao.delete(modelId, deleteModelExistsListener); + modelDao.delete(modelId, deleteModelTrainingListener); }, exception -> fail("Unable to put the model: " + exception)); - // We use put so that we can confirm cluster metadata gets added modelDao.put(model, docCreationListener); assertTrue(inProgressLatch2.await(100, TimeUnit.SECONDS)); + + final CountDownLatch inProgressLatch3 = new CountDownLatch(1); + ActionListener deleteModelExistsListener = ActionListener.wrap(response -> { + assertEquals(modelId1, response.getModelID()); + assertEquals(DocWriteResponse.Result.DELETED.getLowercase(), response.getResult()); + assertNull(response.getErrorMessage()); + inProgressLatch3.countDown(); + }, exception -> fail("Unable to delete model: " + exception)); + + // model id exists + Model model1 = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId1 + ); + + ActionListener docCreationListener1 = ActionListener.wrap(response -> { + assertEquals(modelId1, response.getId()); + modelDao.delete(modelId1, deleteModelExistsListener); + }, exception -> fail("Unable to put the model: " + exception)); + + // We use put so that we can confirm cluster metadata gets added + modelDao.put(model1, docCreationListener1); + + assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS)); + } + + public void testDeleteWithStepListeners() throws IOException, InterruptedException, ExecutionException { + String modelId = "test-model-id-delete"; + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + byte[] modelBlob = "deleteModel".getBytes(); + int dimension = 2; + createIndex(MODEL_INDEX_NAME); + + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); + + // created model and added it to index + addDoc(model); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + StepListener getModelStep = new StepListener<>(); + StepListener blockModelIdStep = new StepListener<>(); + StepListener clearModelMetadataStep = new StepListener<>(); + StepListener deleteModelFromIndexStep = new StepListener<>(); + StepListener clearModelFromCacheStep = new StepListener<>(); + StepListener unblockModelIdStep = new StepListener<>(); + + modelDao.get(modelId, ActionListener.wrap(getModelStep::onResponse, getModelStep::onFailure)); + + // Asserting that model is in CREATED state + getModelStep.whenComplete(getModelResponse -> { + assertEquals(model.getModelMetadata().getState(), getModelResponse.getModel().getModelMetadata().getState()); + assertNotEquals(ModelState.TRAINING.getName(), getModelResponse.getModel().getModelMetadata().getState().toString()); + + client().execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, false), + ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) + ); + }, exception -> fail(exception.getMessage())); + + blockModelIdStep.whenComplete(acknowledgedResponse -> { + // Asserting that modelId is in blocked list + assertTrue(modelDao.isModelBlocked(modelId)); + + client().execute( + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest(modelId, true, null), + ActionListener.wrap(clearModelMetadataStep::onResponse, clearModelMetadataStep::onFailure) + ); + + }, exception -> fail(exception.getMessage())); + + DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client(), DeleteAction.INSTANCE, MODEL_INDEX_NAME); + deleteRequestBuilder.setId(modelId); + deleteRequestBuilder.setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); + + clearModelMetadataStep.whenComplete(acknowledgedResponse -> { + // Asserting that metadata is cleared + assertNull(modelDao.getMetadata(modelId)); + + deleteRequestBuilder.execute(ActionListener.wrap(deleteModelFromIndexStep::onResponse, deleteModelFromIndexStep::onFailure)); + + }, exception -> fail(exception.getMessage())); + + deleteModelFromIndexStep.whenComplete(deleteResponse -> { + // Asserting that model is deleted from index + assertEquals(DocWriteResponse.Result.DELETED, deleteResponse.getResult()); + client().execute( + RemoveModelFromCacheAction.INSTANCE, + new RemoveModelFromCacheRequest(modelId), + ActionListener.wrap(clearModelFromCacheStep::onResponse, clearModelFromCacheStep::onFailure) + ); + + }, exception -> fail(exception.getMessage())); + + clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> { + client().execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, true), + ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure) + ); + }, exception -> fail(exception.getMessage())); + + unblockModelIdStep.whenComplete(acknowledgedResponse -> { + // Asserting that model is unblocked + assertFalse(modelDao.isModelBlocked(modelId)); + inProgressLatch.countDown(); + }, exception -> fail(exception.getMessage())); + + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + + public void testDeleteWithStepListenersOnFailure() throws IOException, InterruptedException, ExecutionException { + String modelId = "test-model-id-delete1"; + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + byte[] modelBlob = "deleteModel".getBytes(); + int dimension = 2; + createIndex(MODEL_INDEX_NAME); + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.CREATED, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); + + addDoc(model); + + // We will validate if the modelId gets unblocked when some exception occurs + // during the process of deletion after adding that modelId to blocked list + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + StepListener blockModelIdStep = new StepListener<>(); + StepListener clearModelMetadataStep = new StepListener<>(); + + // Add modelId to blocked list + client().execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, false), + ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) + ); + + // Asserting that the modelId is blocked + blockModelIdStep.whenComplete(acknowledgedResponse -> { + assertTrue(modelDao.isModelBlocked(modelId)); + + // Sending empty string for modelId to fail the clear model metadata request + client().execute( + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest("", true, null), + ActionListener.wrap(clearModelMetadataStep::onResponse, exp -> { + // Asserting that modelId is still blocked and clearModelMetadata throws an exception + assertNotNull(exp.getMessage()); + assertTrue(modelDao.isModelBlocked(modelId)); + client().execute( + // OnFailure sending request to unblock modelId + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, true), + ActionListener.wrap(ackResponse -> { + // Asserting that model is unblocked + assertFalse(modelDao.isModelBlocked(modelId)); + assertNotNull(exp.getMessage()); + }, exception -> fail(exception.getMessage())) + ); + }) + ); + inProgressLatch.countDown(); + }, exception -> fail(exception.getMessage())); + + assertTrue(inProgressLatch.await(50, TimeUnit.SECONDS)); + + // Some exception occurs during the process of deletion and unblocking model request also fails + final CountDownLatch inProgressLatch1 = new CountDownLatch(1); + + StepListener blockModelIdStep1 = new StepListener<>(); + StepListener clearModelMetadataStep1 = new StepListener<>(); + + // Add modelId to blocked list + client().execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, false), + ActionListener.wrap(blockModelIdStep1::onResponse, blockModelIdStep1::onFailure) + ); + + // Asserting that the modelId is blocked + blockModelIdStep1.whenComplete(acknowledgedResponse -> { + assertTrue(modelDao.isModelBlocked(modelId)); + + // Sending empty string for modelId to fail the clear model metadata request + client().execute( + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest("", true, null), + ActionListener.wrap(clearModelMetadataStep1::onResponse, exp -> { + assertNotNull(exp.getMessage()); + assertTrue(modelDao.isModelBlocked(modelId)); + + // Failing unblock modelId request by sending modelId as an empty string + client().execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest("", true), + ActionListener.wrap(ackResponse -> {}, unblockingFailedException -> { + // Asserting that model is still blocked and returns both exceptions in response + assertTrue(modelDao.isModelBlocked(modelId)); + assertNotNull(exp.getMessage()); + assertNotNull(unblockingFailedException.getMessage()); + }) + ); + }) + ); + inProgressLatch1.countDown(); + }, exception -> fail(exception.getMessage())); + + assertTrue(inProgressLatch1.await(50, TimeUnit.SECONDS)); } public void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { @@ -564,4 +821,5 @@ public void addDoc(Model model) throws IOException, ExecutionException, Interrup IndexResponse response = client().index(indexRequest).get(); assertTrue(response.status() == RestStatus.CREATED || response.status() == RestStatus.OK); } + } diff --git a/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java b/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java new file mode 100644 index 000000000..0fa6b796d --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +public class BlockedModelIdsTests extends OpenSearchTestCase { + + public void testAdd() { + List modelIds = new ArrayList<>(); + BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); + String testModelId = "test-model-id"; + testBlockedModelIds.add(testModelId); + assertTrue(testBlockedModelIds.contains(testModelId)); + } + + public void testRemove() { + List modelIds = new ArrayList<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); + + assertTrue(testBlockedModelIds.contains(testModelId)); + testBlockedModelIds.remove(testModelId); + assertFalse(testBlockedModelIds.contains(testModelId)); + } + + public void testContains() { + List modelIds = new ArrayList<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + + BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); + assertTrue(testBlockedModelIds.contains(testModelId)); + } + + public void testStreams() throws IOException { + List modelIds = new ArrayList<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + testBlockedModelIds.writeTo(streamOutput); + + BlockedModelIds testBlockedModelIdsCopy = new BlockedModelIds(streamOutput.bytes().streamInput()); + + assertEquals(testBlockedModelIds.size(), testBlockedModelIdsCopy.size()); + assertEquals(testBlockedModelIds.getBlockedModelIds().get(0), testBlockedModelIdsCopy.getBlockedModelIds().get(0)); + } + +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java new file mode 100644 index 000000000..cc136234e --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.knn.KNNTestCase; +import java.io.IOException; + +public class UpdateBlockedModelRequestTests extends KNNTestCase { + + public void testStreams() throws IOException { + String modelId = "test-model-id"; + boolean isRemoveRequest = false; + + UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, isRemoveRequest); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + updateBlockedModelRequest.writeTo(streamOutput); + + UpdateBlockedModelRequest updateBlockedModelRequest1 = new UpdateBlockedModelRequest(streamOutput.bytes().streamInput()); + + assertEquals(updateBlockedModelRequest.getModelId(), updateBlockedModelRequest1.getModelId()); + assertEquals(updateBlockedModelRequest.isRemoveRequest(), updateBlockedModelRequest1.isRemoveRequest()); + } + + public void testValidate() { + String modelId = "test-model-id"; + UpdateBlockedModelRequest updateBlockedModelRequest1 = new UpdateBlockedModelRequest(modelId, false); + assertNull(updateBlockedModelRequest1.validate()); + + UpdateBlockedModelRequest updateBlockedModelRequest2 = new UpdateBlockedModelRequest(modelId, true); + assertNull(updateBlockedModelRequest2.validate()); + + UpdateBlockedModelRequest updateBlockedModelRequest3 = new UpdateBlockedModelRequest("", false); + assertNotNull(updateBlockedModelRequest3.validate()); + + UpdateBlockedModelRequest updateBlockedModelRequest4 = new UpdateBlockedModelRequest("", true); + assertNotNull(updateBlockedModelRequest4.validate()); + } + + public void testGetModelId() { + String modelId = "test-model-id"; + UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, false); + + assertEquals(modelId, updateBlockedModelRequest.getModelId()); + } + + public void testIsRemoveRequest() { + String modelId = "test-model-id"; + boolean isRemoveRequest = false; + UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, isRemoveRequest); + + assertEquals(isRemoveRequest, updateBlockedModelRequest.isRemoveRequest()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java new file mode 100644 index 000000000..82979f443 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.master.AcknowledgedResponse; +import org.opensearch.cluster.ClusterState; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.knn.KNNSingleNodeTestCase; +import org.opensearch.knn.plugin.BlockedModelIds; +import org.opensearch.threadpool.ThreadPool; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +public class UpdateBlockedModelTransportActionTests extends KNNSingleNodeTestCase { + + public void testExecutor() { + UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() + .getInstance(UpdateBlockedModelTransportAction.class); + assertEquals(ThreadPool.Names.SAME, updateBlockedModelTransportAction.executor()); + } + + public void testRead() throws IOException { + UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() + .getInstance(UpdateBlockedModelTransportAction.class); + AcknowledgedResponse acknowledgedResponse = new AcknowledgedResponse(true); + BytesStreamOutput streamOutput = new BytesStreamOutput(); + acknowledgedResponse.writeTo(streamOutput); + AcknowledgedResponse acknowledgedResponse1 = updateBlockedModelTransportAction.read(streamOutput.bytes().streamInput()); + + assertEquals(acknowledgedResponse, acknowledgedResponse1); + } + + public void testClusterManagerOperation() throws InterruptedException { + + String modelId = "test-model-id"; + + // Get update transport action + UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() + .getInstance(UpdateBlockedModelTransportAction.class); + + // Generate update request to add modelId to blocked list + UpdateBlockedModelRequest addBlockedModelRequest = new UpdateBlockedModelRequest(modelId, false); + + // Get cluster state, update metadata, check cluster state - all asynchronously + final CountDownLatch inProgressLatch1 = new CountDownLatch(1); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { + ClusterState clusterState1 = stateResponse1.getState(); + updateBlockedModelTransportAction.masterOperation( + addBlockedModelRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); + + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { + ClusterState updatedClusterState = stateResponse2.getState(); + BlockedModelIds blockedModelIds = updatedClusterState.metadata().custom(BlockedModelIds.TYPE); + + assertNotNull(blockedModelIds); + assertEquals(1, blockedModelIds.size()); + assertTrue(blockedModelIds.contains(modelId)); + + inProgressLatch1.countDown(); + + }, e -> fail("Update failed:" + e))); + }, e -> fail("Update failed: " + e)) + ); + }, e -> fail("Update failed: " + e))); + + assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); + + // Generate remove request to remove the modelId from blocked list + UpdateBlockedModelRequest removeBlockedModelRequest = new UpdateBlockedModelRequest(modelId, true); + + final CountDownLatch inProgressLatch2 = new CountDownLatch(1); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { + ClusterState clusterState1 = stateResponse1.getState(); + updateBlockedModelTransportAction.masterOperation( + removeBlockedModelRequest, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); + + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { + ClusterState updatedClusterState = stateResponse2.getState(); + BlockedModelIds blockedModelIds = updatedClusterState.metadata().custom(BlockedModelIds.TYPE); + + assertNotNull(blockedModelIds); + assertEquals(0, blockedModelIds.size()); + assertFalse(blockedModelIds.contains(modelId)); + + inProgressLatch2.countDown(); + }, e -> fail("Update failed"))); + }, e -> fail("Update failed")) + ); + }, e -> fail("Update failed"))); + + assertTrue(inProgressLatch2.await(60, TimeUnit.SECONDS)); + } + + public void testCheckBlock() { + UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() + .getInstance(UpdateBlockedModelTransportAction.class); + assertNull(updateBlockedModelTransportAction.checkBlock(null, null)); + } +} From 4b6b89f806ad791443a4deeef10ac117965f422d Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 16 Jun 2022 19:07:35 -0500 Subject: [PATCH 3/9] Add Blocked modelIds Validation for new Train Model Request Signed-off-by: Naveen Tatikonda --- .../transport/TrainingModelRequest.java | 13 ++++++ .../transport/TrainingModelRequestTests.java | 41 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index ce9397905..66cd484d5 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -260,6 +260,19 @@ public ActionRequestValidationException validate() { return exception; } + // Check if modelId is in blocked list + // ModelId is added to blocked list if that model is undergoing deletion + // and will be removed from blocked list after model is deleted + if (modelDao.isModelBlocked(modelId)) { + exception = new ActionRequestValidationException(); + String errorMessage = String.format( + "\"%s\" is in blocked list. Cannot create a model with same modelID until that model is deleted", + modelId + ); + exception.addValidationError(errorMessage); + return exception; + } + // Confirm that the passed in knnMethodContext is valid and requires training ValidationException validationException = this.knnMethodContext.validate(); if (validationException != null) { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index 06b6f3d01..a4bbffbaf 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -186,6 +186,47 @@ public void testValidation_invalid_modelIdAlreadyExists() { assertTrue(validationErrors.get(0).contains("already exists")); } + // Check that the validation produces an exception when we are + // training a model with modelId that is in blocked list + public void testValidation_blocked_modelId() { + + // Setup the training request + String modelId = "test-model-id"; + KNNMethodContext knnMethodContext = mock(KNNMethodContext.class); + when(knnMethodContext.validate()).thenReturn(null); + when(knnMethodContext.isTrainingRequired()).thenReturn(true); + int dimension = 10; + String trainingIndex = "test-training-index"; + String trainingField = "test-training-field"; + + TrainingModelRequest trainingModelRequest = new TrainingModelRequest( + modelId, + knnMethodContext, + dimension, + trainingIndex, + trainingField, + null, + null + ); + + // Mock the model dao to return true to recognize that the modelId is blocked + ModelDao modelDao = mock(ModelDao.class); + when(modelDao.isModelBlocked(modelId)).thenReturn(true); + + // This cluster service will result in no validation exceptions + ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); + + // Initialize static components with the mocks + TrainingModelRequest.initialize(modelDao, clusterService); + + // Test that validation produces an error message that modelId is in blocked list + ActionRequestValidationException exception = trainingModelRequest.validate(); + assertNotNull(exception); + List validationErrors = exception.validationErrors(); + assertEquals(1, validationErrors.size()); + assertTrue(validationErrors.get(0).contains("in blocked list")); + } + public void testValidation_invalid_invalidMethodContext() { // Check that validation produces exception when the method is invalid and does not require training From bf406ef903ecb024eb4a6b48d5672fade585eb94 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 16 Jun 2022 19:34:06 -0500 Subject: [PATCH 4/9] spotless fix and other changes Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/plugin/KNNPlugin.java | 18 +++++++++--------- .../opensearch/knn/indices/ModelDaoTests.java | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 9a6e13933..d309a06ab 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -321,20 +321,20 @@ public List getNamedXContent() { } private static void registerMetadataCustom( - List entries, - String name, - Writeable.Reader reader, - Writeable.Reader diffReader + List entries, + String name, + Writeable.Reader reader, + Writeable.Reader diffReader ) { registerCustom(entries, Metadata.Custom.class, name, reader, diffReader); } private static void registerCustom( - List entries, - Class category, - String name, - Writeable.Reader reader, - Writeable.Reader diffReader + List entries, + Class category, + String name, + Writeable.Reader reader, + Writeable.Reader diffReader ) { entries.add(new NamedWriteableRegistry.Entry(category, name, reader)); entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, name, diffReader)); diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index e3ee5603a..5f4527427 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -747,7 +747,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt inProgressLatch.countDown(); }, exception -> fail(exception.getMessage())); - assertTrue(inProgressLatch.await(50, TimeUnit.SECONDS)); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); // Some exception occurs during the process of deletion and unblocking model request also fails final CountDownLatch inProgressLatch1 = new CountDownLatch(1); @@ -790,7 +790,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt inProgressLatch1.countDown(); }, exception -> fail(exception.getMessage())); - assertTrue(inProgressLatch1.await(50, TimeUnit.SECONDS)); + assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); } public void addDoc(Model model) throws IOException, ExecutionException, InterruptedException { From e2061d6c590c166407487257e094b46148a1c700 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Mon, 20 Jun 2022 17:32:01 -0500 Subject: [PATCH 5/9] Address Review Comments Signed-off-by: Naveen Tatikonda --- .../java/org/opensearch/knn/bwc/ModelIT.java | 41 ++- .../org/opensearch/knn/indices/ModelDao.java | 155 +++++++----- .../knn/indices/ModelGraveyard.java | 236 ++++++++++++++++++ .../knn/plugin/BlockedModelIds.java | 137 ---------- .../org/opensearch/knn/plugin/KNNPlugin.java | 5 +- .../transport/TrainingModelRequest.java | 10 +- .../transport/UpdateBlockedModelRequest.java | 25 +- .../UpdateBlockedModelTransportAction.java | 61 +++-- .../opensearch/knn/indices/ModelDaoTests.java | 66 ++++- .../knn/indices/ModelGraveyardTests.java | 138 ++++++++++ .../knn/plugin/BlockedModelIdsTests.java | 60 ----- .../transport/TrainingModelRequestTests.java | 8 +- ...pdateBlockedModelTransportActionTests.java | 22 +- 13 files changed, 617 insertions(+), 347 deletions(-) create mode 100644 src/main/java/org/opensearch/knn/indices/ModelGraveyard.java delete mode 100644 src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java create mode 100644 src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java delete mode 100644 src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java index 829382e13..f6ff2b1c2 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/knn/bwc/ModelIT.java @@ -16,7 +16,11 @@ import org.opensearch.common.xcontent.XContentParser; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.index.SpaceType; +import org.opensearch.knn.index.util.KNNEngine; +import org.opensearch.knn.indices.ModelMetadata; +import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.plugin.KNNPlugin; +import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.rest.RestStatus; import org.opensearch.search.SearchHit; @@ -52,7 +56,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase { private static int DOC_ID_TEST_MODEL_INDEX = 0; private static int DOC_ID_TEST_MODEL_INDEX_DEFAULT = 0; private static final int DELAY_MILLI_SEC = 1000; - private static final int EXP_NUM_OF_MODELS = 2; + private static final int EXP_NUM_OF_MODELS = 3; private static final int K = 5; private static final int NUM_DOCS = 10; private static final int NUM_DOCS_TEST_MODEL_INDEX = 100; @@ -63,6 +67,7 @@ public class ModelIT extends AbstractRestartUpgradeTestCase { private static int QUERY_COUNT_TEST_MODEL_INDEX_DEFAULT = 0; private static final String TEST_MODEL_ID = "test-model-id"; private static final String TEST_MODEL_ID_DEFAULT = "test-model-id-default"; + private static final String TEST_MODEL_ID_TRAINING = "test-model-id-training"; private static final String MODEL_DESCRIPTION = "Description for train model test"; // KNN model test @@ -135,12 +140,42 @@ public void testKNNModelDefault() throws Exception { } } + // KNN Delete Model test for model in Training State + public void testDeleteTrainingModel() throws IOException, InterruptedException { + byte[] testModelBlob = "hello".getBytes(); + ModelMetadata testModelMetadata = getModelMetadata(); + testModelMetadata.setState(ModelState.TRAINING); + if (isRunningAgainstOldCluster()) { + addModelToSystemIndex(TEST_MODEL_ID_TRAINING, testModelMetadata, testModelBlob); + } else { + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, TEST_MODEL_ID_TRAINING); + Request request = new Request("DELETE", restURI); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + assertEquals(3, getDocCount(MODEL_INDEX_NAME)); + + String responseBody = EntityUtils.toString(response.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(TEST_MODEL_ID_TRAINING, responseMap.get(MODEL_ID)); + assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); + + String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", TEST_MODEL_ID_TRAINING); + assertEquals(errorMessage, responseMap.get(DeleteModelResponse.ERROR_MSG)); + } + } + // Delete Models and ".opensearch-knn-models" index to clear cluster metadata @AfterClass public static void wipeAllModels() throws IOException { if (!isRunningAgainstOldCluster()) { deleteKNNModel(TEST_MODEL_ID); deleteKNNModel(TEST_MODEL_ID_DEFAULT); + deleteKNNModel(TEST_MODEL_ID_TRAINING); Request request = new Request("DELETE", "/" + MODEL_INDEX_NAME); @@ -241,4 +276,8 @@ public String modelIndexMapping(String fieldName, String modelId) throws IOExcep .endObject() ); } + + private ModelMetadata getModelMetadata() { + return new ModelMetadata(KNNEngine.DEFAULT, SpaceType.DEFAULT, 4, ModelState.CREATED, "2021-03-27", "test model", ""); + } } diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 19815d865..5853b53d1 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -43,7 +43,6 @@ import org.opensearch.common.settings.Settings; import org.opensearch.index.IndexNotFoundException; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.plugin.BlockedModelIds; import org.opensearch.knn.plugin.transport.DeleteModelResponse; import org.opensearch.knn.plugin.transport.GetModelResponse; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; @@ -59,8 +58,10 @@ import java.util.Base64; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.concurrent.ExecutionException; +import static java.util.Objects.isNull; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_MAPPING_PATH; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; import static org.opensearch.knn.common.KNNConstants.MODEL_METADATA_FIELD; @@ -155,12 +156,14 @@ public interface ModelDao { void delete(String modelId, ActionListener listener); /** - * Check if modelId is in blocked modelIds list. Non-blocking. + * Check if modelId is in blocked model set (ModelGraveyard) or not. Non-blocking. + * A modelId is added to blocked model set before deleting that + * model and removed from the set after deleting the model * * @param modelId to retrieve - * @return true if modelId is in blocked list, otherwise return false + * @return true if modelId is in blocked model set, otherwise return false */ - boolean isModelBlocked(String modelId); + boolean isModelBlockedForDelete(String modelId); /** * Implementation of ModelDao for k-NN model index @@ -437,14 +440,18 @@ private String getMapping() throws IOException { return Resources.toString(url, Charsets.UTF_8); } - // Check if the modelId is added to blocked list or not - public boolean isModelBlocked(String modelId) { - BlockedModelIds blockedModelIds = clusterService.state().metadata().custom(BlockedModelIds.TYPE); - if (blockedModelIds == null) { + @Override + public boolean isModelBlockedForDelete(String modelId) { + // Check if the objects are not null and throw a customized NullPointerException + Objects.requireNonNull(clusterService.state(), "Cluster state must not be null"); + Objects.requireNonNull(clusterService.state().metadata(), "Cluster metadata must not be null"); + ModelGraveyard modelGraveyard = clusterService.state().metadata().custom(ModelGraveyard.TYPE); + + if (isNull(modelGraveyard)) { return false; } - return blockedModelIds.contains(modelId); + return modelGraveyard.contains(modelId); } @Override @@ -467,7 +474,8 @@ public void delete(String modelId, ActionListener listener) // Get Model to check if model is in TRAINING get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> { if (exception instanceof ResourceNotFoundException) { - listener.onResponse(new DeleteModelResponse(modelId, "failed", exception.getMessage())); + String errorMessage = String.format("Unable to delete model \"%s\". Model does not exist", modelId); + listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); return; } listener.onFailure(exception); @@ -476,32 +484,21 @@ public void delete(String modelId, ActionListener listener) getModelStep.whenComplete(getModelResponse -> { // If model is in Training state, fail delete model request if (ModelState.TRAINING.equals(getModelResponse.getModel().getModelMetadata().getState())) { - logger.error("Cannot delete model \"" + modelId + "\". Model is still in training."); String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); + logger.error(errorMessage); listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); return; } - // Add modelId to blocked list until delete model request is processed - client.execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, false), - ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) - ); - + // Add modelId to blocked set until delete model request is processed + updateBlockedModelToDelete(modelId, false, blockModelIdStep); }, listener::onFailure); // Remove the metadata asynchronously - blockModelIdStep.whenComplete(acknowledgedResponse -> { - client.execute( - UpdateModelMetadataAction.INSTANCE, - new UpdateModelMetadataRequest(modelId, true, null), - ActionListener.wrap( - clearModelMetadataStep::onResponse, - exception -> unblockModelIdOnFailure(modelId, exception, clearModelMetadataStep) - ) - ); - }, listener::onFailure); + blockModelIdStep.whenComplete( + acknowledgedResponse -> { clearModelMetadata(modelId, clearModelMetadataStep); }, + listener::onFailure + ); // Setup delete model request DeleteRequestBuilder deleteRequestBuilder = new DeleteRequestBuilder(client, DeleteAction.INSTANCE, MODEL_INDEX_NAME); @@ -510,53 +507,88 @@ public void delete(String modelId, ActionListener listener) // On model metadata removal, delete the model from the index clearModelMetadataStep.whenComplete( - acknowledgedResponse -> deleteRequestBuilder.execute( - ActionListener.wrap( - deleteModelFromIndexStep::onResponse, - exception -> unblockModelIdOnFailure(modelId, exception, deleteModelFromIndexStep) - ) - ), + acknowledgedResponse -> deleteModelFromIndex(modelId, deleteModelFromIndexStep, deleteRequestBuilder), listener::onFailure ); deleteModelFromIndexStep.whenComplete(deleteResponse -> { - // If model is not deleted, return with error message + // If model is not deleted, unblock modelId and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { + updateBlockedModelToDelete(modelId, true, unblockModelIdStep); String errorMessage = String.format("Model \" %s \" does not exist", modelId); listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage)); return; } // After model is deleted from the index, make sure the model is evicted from every cache in the cluster - client.execute( - RemoveModelFromCacheAction.INSTANCE, - new RemoveModelFromCacheRequest(modelId), - ActionListener.wrap( - clearModelFromCacheStep::onResponse, - exception -> unblockModelIdOnFailure(modelId, exception, clearModelFromCacheStep) - ) - ); + removeModelFromCache(modelId, clearModelFromCacheStep); }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> { - // After clearing the cache, if there are no errors remove modelId from blocked list - if (!removeModelFromCacheResponse.hasFailures()) { - client.execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, true), - ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure) - ); - } else { + // Remove modelId from blocked set + updateBlockedModelToDelete(modelId, true, unblockModelIdStep); + + unblockModelIdStep.whenComplete(acknowledgedResponse -> { + + // After clearing the cache, if there are no errors return the response + if (!removeModelFromCacheResponse.hasFailures()) { + listener.onResponse(new DeleteModelResponse(modelId, "deleted", null)); + return; + } + + // Build the error message if there are any failures in model cache response and return response String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage)); - } + return; + }, listener::onFailure); }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); + } - // After unblocking modelId return the response - unblockModelIdStep.whenComplete(acknowledgedResponse -> { - listener.onResponse(new DeleteModelResponse(modelId, "deleted", null)); - return; - }, listener::onFailure); + // Remove model from cache in the cluster + private void removeModelFromCache(String modelId, StepListener clearModelFromCacheStep) { + client.execute( + RemoveModelFromCacheAction.INSTANCE, + new RemoveModelFromCacheRequest(modelId), + ActionListener.wrap( + clearModelFromCacheStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, clearModelFromCacheStep) + ) + ); + } + + // Delete model from the system index + private void deleteModelFromIndex( + String modelId, + StepListener deleteModelFromIndexStep, + DeleteRequestBuilder deleteRequestBuilder + ) { + deleteRequestBuilder.execute( + ActionListener.wrap( + deleteModelFromIndexStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, deleteModelFromIndexStep) + ) + ); + } + + // Update blocked model set to add/remove modelId from that set + private void updateBlockedModelToDelete(String modelId, boolean isRemoveRequest, StepListener step) { + client.execute( + UpdateBlockedModelAction.INSTANCE, + new UpdateBlockedModelRequest(modelId, isRemoveRequest), + ActionListener.wrap(step::onResponse, step::onFailure) + ); + } + + // Clear the metadata of the model for a given modelId + private void clearModelMetadata(String modelId, StepListener clearModelMetadataStep) { + client.execute( + UpdateModelMetadataAction.INSTANCE, + new UpdateModelMetadataRequest(modelId, true, null), + ActionListener.wrap( + clearModelMetadataStep::onResponse, + exception -> unblockModelIdOnFailure(modelId, exception, clearModelMetadataStep) + ) + ); } // This function helps to remove the model from blocked list when the delete request fails @@ -567,11 +599,10 @@ private void unblockModelIdOnFailure(String modelId, Exception exceptionFromPrev client.execute( UpdateBlockedModelAction.INSTANCE, new UpdateBlockedModelRequest(modelId, true), - ActionListener.wrap(acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep), unblockingFailedException -> { - String errorMsg = exceptionFromPreviousStep.getMessage(); - errorMsg = errorMsg + "\n" + unblockingFailedException.getMessage(); - step.onFailure(new Exception(errorMsg)); - }) + ActionListener.wrap( + acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep), + unblockingFailedException -> step.onFailure(exceptionFromPreviousStep) + ) ); } diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java new file mode 100644 index 000000000..117046e25 --- /dev/null +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -0,0 +1,236 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.indices; + +import lombok.extern.log4j.Log4j2; +import org.opensearch.Version; +import org.opensearch.cluster.Diff; +import org.opensearch.cluster.NamedDiff; +import org.opensearch.cluster.metadata.Metadata; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Collections; +import java.util.EnumSet; +import java.util.HashSet; +import java.util.Set; +import java.util.stream.Collectors; +import com.google.common.collect.Sets; + +/** + * This class implements Metadata.Custom Interface to store a set of modelIds in the cluster metadata. + * The modelIds of the models that are under deletion are added to this set and later removed from this set after deletion. + * Also, this class implements the methods to perform operations on this set (like add, remove, contains) + */ +@Log4j2 +public class ModelGraveyard implements Metadata.Custom { + public static final String TYPE = "opensearch-knn-blocked-models"; + private final Set modelGraveyard; + + /** + * Constructor + * @param modelGraveyard Set which contains blocked model Ids + */ + public ModelGraveyard(Set modelGraveyard) { + this.modelGraveyard = modelGraveyard; + } + + /** + * Default Constructor to initialize object when it is null + */ + public ModelGraveyard() { + this.modelGraveyard = new HashSet<>(); + } + + /** + * @param in input stream + * @throws IOException if read from stream fails + */ + public ModelGraveyard(StreamInput in) throws IOException { + this.modelGraveyard = new HashSet<>(in.readStringList()); + } + + @Override + public EnumSet context() { + return Metadata.ALL_CONTEXTS; + } + + /** + * @return WriteableName for ModelGraveyard + */ + @Override + public String getWriteableName() { + return TYPE; + } + + @Override + public Version getMinimalSupportedVersion() { + return Version.CURRENT.minimumCompatibilityVersion(); + } + + /** + * @param out output stream + * @throws IOException if write to stream fails + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(modelGraveyard); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder; + } + + /** + * @param modelId id of the model that needs to be removed from modelGraveyard set + */ + public void remove(String modelId) { + modelGraveyard.remove(modelId); + } + + /** + * @param modelId id of the model that needs to be added to modelGraveyard set + */ + public void add(String modelId) { + modelGraveyard.add(modelId); + } + + /** + * @return number of modelIds in modelGraveyard set + */ + public int size() { + return modelGraveyard.size(); + } + + /** + * @param modelId to check if the id of given model is there in modelGraveyard set + * @return true if the modelId is in the modelGraveyard set, otherwise false + */ + public boolean contains(String modelId) { + return modelGraveyard.contains(modelId); + } + + /** + * @param before The previous custom metadata object + * @return the diff between the current updated object and the previous object + */ + @Override + public Diff diff(Metadata.Custom before) { + return new ModelGraveyardDiff((ModelGraveyard) before, this); + } + + /** + * @param streamInput input stream + * @return ModelGraveyardDiff + * @throws IOException if read from stream fails + */ + public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException { + return new ModelGraveyardDiff(streamInput); + } + + /** + * @param xContentParser + * @return ModelGraveyard + * @throws IOException + */ + public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException { + return new ModelGraveyard(xContentParser.list().stream().map(Object::toString).collect(Collectors.toSet())); + } + + /** + * The ModelGraveyardDiff class compares the previous modelGraveyard object with the current updated modelGraveyard object + * and returns only the diff of those 2 objects. So that, whenever there is a change in cluster state, master node only + * sends the diff to all the data nodes instead of the full cluster state + */ + static class ModelGraveyardDiff implements NamedDiff { + private final Set added; + private final Set removed; + + /** + * @param inp input stream + * @throws IOException if read from stream fails + */ + public ModelGraveyardDiff(StreamInput inp) throws IOException { + added = Set.copyOf(inp.readStringList()); + removed = Set.copyOf(inp.readStringList()); + } + + /** + * @param previous previous ModelGraveyard object + * @param current current updated ModelGraveyard object + * + * Constructor which compares both the objects to find the entries that are newly added in current object, + * entries that are deleted from previous object and the deleted entries count + */ + public ModelGraveyardDiff(ModelGraveyard previous, ModelGraveyard current) { + final Set previousModelGraveyard = previous.modelGraveyard; + final Set currentModelGraveyard = current.modelGraveyard; + final Set added, removed; + if (previousModelGraveyard.isEmpty()) { + // nothing will have been removed in previous object, and all entries in current object are new + added = new HashSet<>(currentModelGraveyard); + removed = new HashSet<>(); + } else if (currentModelGraveyard.isEmpty()) { + // nothing will have been added to current object, and all entries in previous object are removed + added = new HashSet<>(); + removed = new HashSet<>(previousModelGraveyard); + } else { + // some entries in previous object are removed and few entries are added to current object + removed = Sets.difference(previousModelGraveyard, currentModelGraveyard); + added = Sets.difference(currentModelGraveyard, previousModelGraveyard); + } + this.added = Collections.unmodifiableSet(added); + this.removed = Collections.unmodifiableSet(removed); + } + + /** + * @param previous Previous custom metadata object + * @return ModelGraveyard object after calculating the diff + */ + @Override + public ModelGraveyard apply(Metadata.Custom previous) { + final ModelGraveyard old = (ModelGraveyard) previous; + int removedCount = removed.size(); + if (removedCount > old.size()) { + throw new IllegalStateException( + "ModelGraveyardDiff cannot remove [" + removedCount + "] entries from [" + old.size() + "] modelIds." + ); + } + Set updatedOldGraveyardSet = Sets.difference(old.modelGraveyard, removed); + Set modelGraveyardDiffSet = new HashSet<>(); + modelGraveyardDiffSet.addAll(added); + modelGraveyardDiffSet.addAll(updatedOldGraveyardSet); + return new ModelGraveyard(modelGraveyardDiffSet); + } + + public Set getAdded() { + return added; + } + + public Set getRemoved() { + return removed; + } + + @Override + public String getWriteableName() { + return TYPE; + } + + /** + * @param out output stream + * @throws IOException if write to stream fails + */ + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(added); + out.writeStringCollection(removed); + } + } +} diff --git a/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java b/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java deleted file mode 100644 index 8d093cb67..000000000 --- a/src/main/java/org/opensearch/knn/plugin/BlockedModelIds.java +++ /dev/null @@ -1,137 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.apache.logging.log4j.LogManager; -import org.opensearch.Version; -import org.opensearch.cluster.Diff; -import org.opensearch.cluster.NamedDiff; -import org.opensearch.cluster.metadata.Metadata; -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.common.xcontent.XContentBuilder; -import org.opensearch.common.xcontent.XContentParser; -import org.apache.logging.log4j.Logger; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.EnumSet; -import java.util.List; -import java.util.stream.Collectors; - -public class BlockedModelIds implements Metadata.Custom { - - public static Logger logger = LogManager.getLogger(BlockedModelIds.class); - public static final String TYPE = "opensearch-knn-blocked-models"; - - List blockedModelIds; - - public BlockedModelIds(List blockedModelIds) { - this.blockedModelIds = blockedModelIds; - } - - @Override - public EnumSet context() { - return Metadata.ALL_CONTEXTS; - } - - @Override - public Diff diff(Metadata.Custom custom) { - return null; - } - - @Override - public String getWriteableName() { - return TYPE; - } - - @Override - public Version getMinimalSupportedVersion() { - return Version.CURRENT.minimumCompatibilityVersion(); - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(blockedModelIds.size()); - for (String modelId : blockedModelIds) { - out.writeString(modelId); - } - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - return null; - } - - public BlockedModelIds(StreamInput in) throws IOException { - int size = in.readVInt(); - List modelIds = new ArrayList<>(size); - for (int i = 0; i < size; i++) { - String modelId = in.readString(); - modelIds.add(modelId); - } - this.blockedModelIds = modelIds; - } - - public List getBlockedModelIds() { - return blockedModelIds; - } - - public void remove(String modelId) { - if (blockedModelIds.contains(modelId)) { - blockedModelIds.remove(modelId); - } - } - - public void add(String modelId) { - blockedModelIds.add(modelId); - } - - public int size() { - return blockedModelIds.size(); - } - - public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException { - return new BlockedModelIdsDiff(streamInput); - } - - public static BlockedModelIds fromXContent(XContentParser xContentParser) throws IOException { - List modelIds = xContentParser.list().stream().map(obj -> obj.toString()).collect(Collectors.toList()); - return new BlockedModelIds(modelIds); - } - - public boolean contains(String modelId) { - return blockedModelIds.contains(modelId); - } - - static class BlockedModelIdsDiff implements NamedDiff { - private List added; - private int removedCount; - - public BlockedModelIdsDiff(StreamInput inp) throws IOException { - added = inp.readList((streamInput -> streamInput.toString())); - removedCount = inp.readVInt(); - } - - @Override - public Metadata.Custom apply(Metadata.Custom custom) { - return null; - } - - @Override - public String getWriteableName() { - return TYPE; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeVInt(removedCount); - for (String modelId : added) { - out.writeString(modelId); - } - } - } -} diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index d309a06ab..675edfec9 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -20,6 +20,7 @@ import org.opensearch.knn.index.KNNWeight; import org.opensearch.knn.index.codec.KNNCodecService; import org.opensearch.knn.index.memory.NativeMemoryLoadStrategy; +import org.opensearch.knn.indices.ModelGraveyard; import org.opensearch.knn.indices.ModelCache; import org.opensearch.knn.indices.ModelDao; import org.opensearch.knn.plugin.rest.RestDeleteModelHandler; @@ -306,7 +307,7 @@ public List> getExecutorBuilders(Settings settings) { @Override public List getNamedWriteables() { List entries = new ArrayList<>(); - registerMetadataCustom(entries, BlockedModelIds.TYPE, BlockedModelIds::new, BlockedModelIds::readDiffFrom); + registerMetadataCustom(entries, ModelGraveyard.TYPE, ModelGraveyard::new, ModelGraveyard::readDiffFrom); return entries; } @@ -315,7 +316,7 @@ public List getNamedXContent() { List entries = new ArrayList<>(); entries.add( - new NamedXContentRegistry.Entry(Metadata.Custom.class, new ParseField(BlockedModelIds.TYPE), BlockedModelIds::fromXContent) + new NamedXContentRegistry.Entry(Metadata.Custom.class, new ParseField(ModelGraveyard.TYPE), ModelGraveyard::fromXContent) ); return entries; } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index 66cd484d5..fd1629afc 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -260,13 +260,13 @@ public ActionRequestValidationException validate() { return exception; } - // Check if modelId is in blocked list - // ModelId is added to blocked list if that model is undergoing deletion - // and will be removed from blocked list after model is deleted - if (modelDao.isModelBlocked(modelId)) { + // Check if modelId is in blocked model set + // ModelId is added to blocked set if that model is undergoing deletion + // and will be removed from blocked set after model is deleted + if (modelDao.isModelBlockedForDelete(modelId)) { exception = new ActionRequestValidationException(); String errorMessage = String.format( - "\"%s\" is in blocked list. Cannot create a model with same modelID until that model is deleted", + "\"%s\" is blocked. Cannot create a model with same modelID until that model is deleted", modelId ); exception.addValidationError(errorMessage); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java index 300231a8b..b80330908 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java @@ -5,6 +5,7 @@ package org.opensearch.knn.plugin.transport; +import lombok.Getter; import org.opensearch.action.ActionRequestValidationException; import org.opensearch.action.support.master.AcknowledgedRequest; import org.opensearch.common.io.stream.StreamInput; @@ -19,8 +20,10 @@ */ public class UpdateBlockedModelRequest extends AcknowledgedRequest { - private String modelId; - private boolean isRemoveRequest; + @Getter + private final String modelId; + @Getter + private final boolean isRemoveRequest; /** * Constructor @@ -57,24 +60,6 @@ public ActionRequestValidationException validate() { return validationException; } - /** - * Getter for modelId - * - * @return modelId - */ - public String getModelId() { - return modelId; - } - - /** - * Getter for isRemoveRequest - * - * @return isRemoveRequest - */ - public boolean isRemoveRequest() { - return isRemoveRequest; - } - @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java index 24df82471..1703e1596 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java @@ -5,8 +5,7 @@ package org.opensearch.knn.plugin.transport; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; +import lombok.extern.log4j.Log4j2; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.master.AcknowledgedResponse; @@ -22,23 +21,22 @@ import org.opensearch.common.Priority; import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.knn.plugin.BlockedModelIds; +import org.opensearch.knn.indices.ModelGraveyard; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; import java.io.IOException; -import java.util.ArrayList; import java.util.List; +import java.util.Objects; +import lombok.AllArgsConstructor; import static org.opensearch.knn.common.KNNConstants.PLUGIN_NAME; /** - * Transport action used to update blocked modelIds list on the cluster manager node. + * Transport action used to update blocked modelIds (ModelGraveyard) on the cluster manager node. */ +@Log4j2 public class UpdateBlockedModelTransportAction extends TransportMasterNodeAction { - - public static Logger logger = LogManager.getLogger(UpdateBlockedModelTransportAction.class); - private UpdateBlockedModelExecutor updateBlockedModelExecutor; @Inject @@ -105,49 +103,46 @@ protected ClusterBlockException checkBlock(UpdateBlockedModelRequest request, Cl /** * UpdateBlockedModelTask is used to provide the executor with the information it needs to perform its task */ + @AllArgsConstructor private static class UpdateBlockedModelTask { - private String modelId; private boolean isRemoveRequest; - - /** - * Constructor - * - * @param modelId id of model - * @param isRemoveRequest should this modelId be removed - */ - UpdateBlockedModelTask(String modelId, boolean isRemoveRequest) { - this.modelId = modelId; - this.isRemoveRequest = isRemoveRequest; - } } + /** + * Updates the cluster state based on the UpdateBlockedModelTask + */ private static class UpdateBlockedModelExecutor implements ClusterStateTaskExecutor { + /** + * @param clusterState ClusterState + * @param taskList contains the list of UpdateBlockedModelTask request parameters (modelId and isRemoveRequest) + * @return Represents the result of a batched execution of cluster state update tasks (UpdateBlockedModelTasks) + */ @Override - public ClusterTasksResult execute(ClusterState clusterState, List list) { + public ClusterTasksResult execute(ClusterState clusterState, List taskList) { - BlockedModelIds immutableBlockedModelIds = clusterState.metadata().custom(BlockedModelIds.TYPE); - BlockedModelIds blockedModelIds; + // Check if the objects are not null and throw a customized NullPointerException + Objects.requireNonNull(clusterState, "Cluster state must not be null"); + Objects.requireNonNull(clusterState.metadata(), "Cluster metadata must not be null"); + ModelGraveyard modelGraveyard = clusterState.metadata().custom(ModelGraveyard.TYPE); - if (immutableBlockedModelIds == null) { - blockedModelIds = new BlockedModelIds(new ArrayList<>()); - } else { - blockedModelIds = immutableBlockedModelIds; + if (modelGraveyard == null) { + modelGraveyard = new ModelGraveyard(); } - for (UpdateBlockedModelTask task : list) { + for (UpdateBlockedModelTask task : taskList) { if (task.isRemoveRequest) { - blockedModelIds.remove(task.modelId); - } else { - blockedModelIds.add(task.modelId); + modelGraveyard.remove(task.modelId); + continue; } + modelGraveyard.add(task.modelId); } Metadata.Builder metaDataBuilder = Metadata.builder(clusterState.metadata()); - metaDataBuilder.putCustom(BlockedModelIds.TYPE, blockedModelIds); + metaDataBuilder.putCustom(ModelGraveyard.TYPE, modelGraveyard); ClusterState updatedClusterState = ClusterState.builder(clusterState).metadata(metaDataBuilder).build(); - return new ClusterTasksResult.Builder().successes(list).build(updatedClusterState); + return new ClusterTasksResult.Builder().successes(taskList).build(updatedClusterState); } } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 5f4527427..9bcdd14c4 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -54,6 +54,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; +import static org.opensearch.cluster.metadata.Metadata.builder; import static org.opensearch.knn.common.KNNConstants.DIMENSION; import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; @@ -585,6 +586,46 @@ public void testDelete() throws IOException, InterruptedException { assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS)); } + public void testDeleteModelInTrainingWithStepListeners() throws IOException, ExecutionException, InterruptedException { + String modelId = "test-model-id-training"; + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + byte[] modelBlob = "deleteModel".getBytes(); + int dimension = 2; + createIndex(MODEL_INDEX_NAME); + + Model model = new Model( + new ModelMetadata( + KNNEngine.DEFAULT, + SpaceType.DEFAULT, + dimension, + ModelState.TRAINING, + ZonedDateTime.now(ZoneOffset.UTC).toString(), + "", + "" + ), + modelBlob, + modelId + ); + + // created model and added it to index + addDoc(model); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + + StepListener getModelStep = new StepListener<>(); + + modelDao.get(modelId, ActionListener.wrap(getModelStep::onResponse, getModelStep::onFailure)); + + // Asserting that model is in TRAINING state + getModelStep.whenComplete(getModelResponse -> { + assertEquals(model.getModelMetadata().getState(), getModelResponse.getModel().getModelMetadata().getState()); + assertEquals(ModelState.TRAINING, getModelResponse.getModel().getModelMetadata().getState()); + + inProgressLatch.countDown(); + }, exception -> fail(exception.getMessage())); + assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); + } + public void testDeleteWithStepListeners() throws IOException, InterruptedException, ExecutionException { String modelId = "test-model-id-delete"; ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); @@ -634,7 +675,7 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti blockModelIdStep.whenComplete(acknowledgedResponse -> { // Asserting that modelId is in blocked list - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); client().execute( UpdateModelMetadataAction.INSTANCE, @@ -673,12 +714,13 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti new UpdateBlockedModelRequest(modelId, true), ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure) ); - }, exception -> fail(exception.getMessage())); - unblockModelIdStep.whenComplete(acknowledgedResponse -> { - // Asserting that model is unblocked - assertFalse(modelDao.isModelBlocked(modelId)); - inProgressLatch.countDown(); + unblockModelIdStep.whenComplete(acknowledgedResponse -> { + // Asserting that model is unblocked + assertFalse(modelDao.isModelBlockedForDelete(modelId)); + assertFalse(removeModelFromCacheResponse.hasFailures()); + inProgressLatch.countDown(); + }, exception -> fail(exception.getMessage())); }, exception -> fail(exception.getMessage())); assertTrue(inProgressLatch.await(100, TimeUnit.SECONDS)); @@ -722,7 +764,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt // Asserting that the modelId is blocked blockModelIdStep.whenComplete(acknowledgedResponse -> { - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); // Sending empty string for modelId to fail the clear model metadata request client().execute( @@ -731,14 +773,14 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt ActionListener.wrap(clearModelMetadataStep::onResponse, exp -> { // Asserting that modelId is still blocked and clearModelMetadata throws an exception assertNotNull(exp.getMessage()); - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); client().execute( // OnFailure sending request to unblock modelId UpdateBlockedModelAction.INSTANCE, new UpdateBlockedModelRequest(modelId, true), ActionListener.wrap(ackResponse -> { // Asserting that model is unblocked - assertFalse(modelDao.isModelBlocked(modelId)); + assertFalse(modelDao.isModelBlockedForDelete(modelId)); assertNotNull(exp.getMessage()); }, exception -> fail(exception.getMessage())) ); @@ -764,7 +806,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt // Asserting that the modelId is blocked blockModelIdStep1.whenComplete(acknowledgedResponse -> { - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); // Sending empty string for modelId to fail the clear model metadata request client().execute( @@ -772,7 +814,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt new UpdateModelMetadataRequest("", true, null), ActionListener.wrap(clearModelMetadataStep1::onResponse, exp -> { assertNotNull(exp.getMessage()); - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); // Failing unblock modelId request by sending modelId as an empty string client().execute( @@ -780,7 +822,7 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt new UpdateBlockedModelRequest("", true), ActionListener.wrap(ackResponse -> {}, unblockingFailedException -> { // Asserting that model is still blocked and returns both exceptions in response - assertTrue(modelDao.isModelBlocked(modelId)); + assertTrue(modelDao.isModelBlockedForDelete(modelId)); assertNotNull(exp.getMessage()); assertNotNull(unblockingFailedException.getMessage()); }) diff --git a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java new file mode 100644 index 000000000..e50d7ca37 --- /dev/null +++ b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java @@ -0,0 +1,138 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.indices; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + +public class ModelGraveyardTests extends OpenSearchTestCase { + + public void testAdd() { + ModelGraveyard testModelGraveyard = new ModelGraveyard(); + String testModelId = "test-model-id"; + testModelGraveyard.add(testModelId); + assertTrue(testModelGraveyard.contains(testModelId)); + } + + public void testRemove() { + Set modelIds = new HashSet<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds); + + assertTrue(testModelGraveyard.contains(testModelId)); + testModelGraveyard.remove(testModelId); + assertFalse(testModelGraveyard.contains(testModelId)); + } + + public void testContains() { + Set modelIds = new HashSet<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + + ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds); + assertTrue(testModelGraveyard.contains(testModelId)); + } + + public void testStreams() throws IOException { + Set modelIds = new HashSet<>(); + String testModelId = "test-model-id"; + modelIds.add(testModelId); + ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + testModelGraveyard.writeTo(streamOutput); + + ModelGraveyard testModelGraveyardCopy = new ModelGraveyard(streamOutput.bytes().streamInput()); + + assertEquals(testModelGraveyard.size(), testModelGraveyardCopy.size()); + assertTrue(testModelGraveyard.contains(testModelId)); + assertTrue(testModelGraveyardCopy.contains(testModelId)); + } + + public void testDiffStreams() throws IOException { + Set added = new HashSet<>(); + Set removed = new HashSet<>(); + String testModelId = "test-model-id"; + String testModelId1 = "test-model-id-1"; + added.add(testModelId); + removed.add(testModelId1); + + ModelGraveyard modelGraveyardCurrent = new ModelGraveyard(added); + ModelGraveyard modelGraveyardPrevious = new ModelGraveyard(removed); + + ModelGraveyard.ModelGraveyardDiff modelGraveyardDiff = new ModelGraveyard.ModelGraveyardDiff( + modelGraveyardPrevious, + modelGraveyardCurrent + ); + assertEquals(added, modelGraveyardDiff.getAdded()); + assertEquals(removed, modelGraveyardDiff.getRemoved()); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + modelGraveyardDiff.writeTo(streamOutput); + + ModelGraveyard.ModelGraveyardDiff modelGraveyardDiffCopy = new ModelGraveyard.ModelGraveyardDiff( + streamOutput.bytes().streamInput() + ); + assertEquals(added, modelGraveyardDiffCopy.getAdded()); + assertEquals(removed, modelGraveyardDiffCopy.getRemoved()); + } + + public void testDiff() { + + // nothing will have been removed in previous object, and all entries in current object are new + ModelGraveyard modelGraveyard1 = new ModelGraveyard(); + + Set modelIds = new HashSet<>(); + modelIds.add("1"); + modelIds.add("2"); + ModelGraveyard modelGraveyard2 = new ModelGraveyard(modelIds); + + ModelGraveyard.ModelGraveyardDiff diff1 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard1, modelGraveyard2); + Set added1 = diff1.getAdded(); + assertEquals(0, diff1.getRemoved().size()); + assertEquals(2, added1.size()); + + ModelGraveyard updatedGraveyard1 = diff1.apply(modelGraveyard1); + assertEquals(2, updatedGraveyard1.size()); + assertTrue(updatedGraveyard1.contains("1")); + assertTrue(updatedGraveyard1.contains("2")); + + // nothing will have been added to current object, and all entries in previous object are removed + ModelGraveyard modelGraveyard3 = new ModelGraveyard(); + ModelGraveyard.ModelGraveyardDiff diff2 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard3); + Set added2 = diff2.getAdded(); + assertEquals(2, diff2.getRemoved().size()); + assertEquals(0, added2.size()); + + ModelGraveyard updatedGraveyard2 = diff2.apply(modelGraveyard2); + assertEquals(0, updatedGraveyard2.size()); + + // some entries in previous object are removed and few entries are added to current object + modelIds = new HashSet<>(); + modelIds.add("1"); + modelIds.add("3"); + modelIds.add("4"); + ModelGraveyard modelGraveyard4 = new ModelGraveyard(modelIds); + + ModelGraveyard.ModelGraveyardDiff diff3 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard4); + Set added3 = diff3.getAdded(); + assertEquals(1, diff3.getRemoved().size()); + assertEquals(2, added3.size()); + + ModelGraveyard updatedGraveyard3 = diff3.apply(modelGraveyard2); + assertEquals(3, updatedGraveyard3.size()); + assertTrue(updatedGraveyard3.contains("1")); + assertTrue(updatedGraveyard3.contains("3")); + assertTrue(updatedGraveyard3.contains("4")); + assertFalse(updatedGraveyard3.contains("2")); + } + +} diff --git a/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java b/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java deleted file mode 100644 index 0fa6b796d..000000000 --- a/src/test/java/org/opensearch/knn/plugin/BlockedModelIdsTests.java +++ /dev/null @@ -1,60 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin; - -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.test.OpenSearchTestCase; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - -public class BlockedModelIdsTests extends OpenSearchTestCase { - - public void testAdd() { - List modelIds = new ArrayList<>(); - BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); - String testModelId = "test-model-id"; - testBlockedModelIds.add(testModelId); - assertTrue(testBlockedModelIds.contains(testModelId)); - } - - public void testRemove() { - List modelIds = new ArrayList<>(); - String testModelId = "test-model-id"; - modelIds.add(testModelId); - BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); - - assertTrue(testBlockedModelIds.contains(testModelId)); - testBlockedModelIds.remove(testModelId); - assertFalse(testBlockedModelIds.contains(testModelId)); - } - - public void testContains() { - List modelIds = new ArrayList<>(); - String testModelId = "test-model-id"; - modelIds.add(testModelId); - - BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); - assertTrue(testBlockedModelIds.contains(testModelId)); - } - - public void testStreams() throws IOException { - List modelIds = new ArrayList<>(); - String testModelId = "test-model-id"; - modelIds.add(testModelId); - BlockedModelIds testBlockedModelIds = new BlockedModelIds(modelIds); - - BytesStreamOutput streamOutput = new BytesStreamOutput(); - testBlockedModelIds.writeTo(streamOutput); - - BlockedModelIds testBlockedModelIdsCopy = new BlockedModelIds(streamOutput.bytes().streamInput()); - - assertEquals(testBlockedModelIds.size(), testBlockedModelIdsCopy.size()); - assertEquals(testBlockedModelIds.getBlockedModelIds().get(0), testBlockedModelIdsCopy.getBlockedModelIds().get(0)); - } - -} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index a4bbffbaf..cdbb0c76f 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -187,7 +187,7 @@ public void testValidation_invalid_modelIdAlreadyExists() { } // Check that the validation produces an exception when we are - // training a model with modelId that is in blocked list + // training a model with modelId that is in blocked set public void testValidation_blocked_modelId() { // Setup the training request @@ -211,7 +211,7 @@ public void testValidation_blocked_modelId() { // Mock the model dao to return true to recognize that the modelId is blocked ModelDao modelDao = mock(ModelDao.class); - when(modelDao.isModelBlocked(modelId)).thenReturn(true); + when(modelDao.isModelBlockedForDelete(modelId)).thenReturn(true); // This cluster service will result in no validation exceptions ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); @@ -219,12 +219,12 @@ public void testValidation_blocked_modelId() { // Initialize static components with the mocks TrainingModelRequest.initialize(modelDao, clusterService); - // Test that validation produces an error message that modelId is in blocked list + // Test that validation produces an error message that modelId is blocked ActionRequestValidationException exception = trainingModelRequest.validate(); assertNotNull(exception); List validationErrors = exception.validationErrors(); assertEquals(1, validationErrors.size()); - assertTrue(validationErrors.get(0).contains("in blocked list")); + assertTrue(validationErrors.get(0).contains("is blocked")); } public void testValidation_invalid_invalidMethodContext() { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java index 82979f443..6d5708f01 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java @@ -10,7 +10,7 @@ import org.opensearch.cluster.ClusterState; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.knn.KNNSingleNodeTestCase; -import org.opensearch.knn.plugin.BlockedModelIds; +import org.opensearch.knn.indices.ModelGraveyard; import org.opensearch.threadpool.ThreadPool; import java.io.IOException; @@ -44,7 +44,7 @@ public void testClusterManagerOperation() throws InterruptedException { UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() .getInstance(UpdateBlockedModelTransportAction.class); - // Generate update request to add modelId to blocked list + // Generate update request to add modelId to blocked set (ModelGraveyard) UpdateBlockedModelRequest addBlockedModelRequest = new UpdateBlockedModelRequest(modelId, false); // Get cluster state, update metadata, check cluster state - all asynchronously @@ -59,11 +59,11 @@ public void testClusterManagerOperation() throws InterruptedException { client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { ClusterState updatedClusterState = stateResponse2.getState(); - BlockedModelIds blockedModelIds = updatedClusterState.metadata().custom(BlockedModelIds.TYPE); + ModelGraveyard modelGraveyard = updatedClusterState.metadata().custom(ModelGraveyard.TYPE); - assertNotNull(blockedModelIds); - assertEquals(1, blockedModelIds.size()); - assertTrue(blockedModelIds.contains(modelId)); + assertNotNull(modelGraveyard); + assertEquals(1, modelGraveyard.size()); + assertTrue(modelGraveyard.contains(modelId)); inProgressLatch1.countDown(); @@ -74,7 +74,7 @@ public void testClusterManagerOperation() throws InterruptedException { assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); - // Generate remove request to remove the modelId from blocked list + // Generate remove request to remove the modelId from blocked set (ModelGraveyard) UpdateBlockedModelRequest removeBlockedModelRequest = new UpdateBlockedModelRequest(modelId, true); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); @@ -88,11 +88,11 @@ public void testClusterManagerOperation() throws InterruptedException { client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { ClusterState updatedClusterState = stateResponse2.getState(); - BlockedModelIds blockedModelIds = updatedClusterState.metadata().custom(BlockedModelIds.TYPE); + ModelGraveyard modelGraveyard = updatedClusterState.metadata().custom(ModelGraveyard.TYPE); - assertNotNull(blockedModelIds); - assertEquals(0, blockedModelIds.size()); - assertFalse(blockedModelIds.contains(modelId)); + assertNotNull(modelGraveyard); + assertEquals(0, modelGraveyard.size()); + assertFalse(modelGraveyard.contains(modelId)); inProgressLatch2.countDown(); }, e -> fail("Update failed"))); From 72a5bba6ab0831d6ac4132d2d560e6ae1cb3b646 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 6 Jul 2022 21:50:21 -0500 Subject: [PATCH 6/9] Bug fix for copying ModelGraveyard reference Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/indices/ModelGraveyard.java | 9 ++++++++- .../UpdateBlockedModelTransportAction.java | 14 ++++++++++++-- .../knn/indices/ModelGraveyardTests.java | 9 +++------ 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java index 117046e25..f2d695ffc 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -102,6 +102,13 @@ public void add(String modelId) { modelGraveyard.add(modelId); } + /** + * @return Set of modelIds in modelGraveyard + */ + public Set getModelGraveyard() { + return modelGraveyard; + } + /** * @return number of modelIds in modelGraveyard set */ @@ -149,7 +156,7 @@ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws * and returns only the diff of those 2 objects. So that, whenever there is a change in cluster state, master node only * sends the diff to all the data nodes instead of the full cluster state */ - static class ModelGraveyardDiff implements NamedDiff { + public static class ModelGraveyardDiff implements NamedDiff { private final Set added; private final Set removed; diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java index 1703e1596..d3c6ba849 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportAction.java @@ -26,8 +26,11 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.HashSet; import java.util.List; import java.util.Objects; +import java.util.Set; + import lombok.AllArgsConstructor; import static org.opensearch.knn.common.KNNConstants.PLUGIN_NAME; @@ -124,10 +127,17 @@ public ClusterTasksResult execute(ClusterState clusterSt // Check if the objects are not null and throw a customized NullPointerException Objects.requireNonNull(clusterState, "Cluster state must not be null"); Objects.requireNonNull(clusterState.metadata(), "Cluster metadata must not be null"); - ModelGraveyard modelGraveyard = clusterState.metadata().custom(ModelGraveyard.TYPE); + ModelGraveyard immutableModelGraveyard = clusterState.metadata().custom(ModelGraveyard.TYPE); + ModelGraveyard modelGraveyard; + Set copySet; - if (modelGraveyard == null) { + if (immutableModelGraveyard == null) { modelGraveyard = new ModelGraveyard(); + } else { + // Deep Copy to copy all the modelIds in ModelGraveyard to local object + // to avoid copying the reference + copySet = new HashSet<>(immutableModelGraveyard.getModelGraveyard()); + modelGraveyard = new ModelGraveyard(copySet); } for (UpdateBlockedModelTask task : taskList) { diff --git a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java index e50d7ca37..28b3c7474 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java @@ -96,9 +96,8 @@ public void testDiff() { ModelGraveyard modelGraveyard2 = new ModelGraveyard(modelIds); ModelGraveyard.ModelGraveyardDiff diff1 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard1, modelGraveyard2); - Set added1 = diff1.getAdded(); assertEquals(0, diff1.getRemoved().size()); - assertEquals(2, added1.size()); + assertEquals(2, diff1.getAdded().size()); ModelGraveyard updatedGraveyard1 = diff1.apply(modelGraveyard1); assertEquals(2, updatedGraveyard1.size()); @@ -108,9 +107,8 @@ public void testDiff() { // nothing will have been added to current object, and all entries in previous object are removed ModelGraveyard modelGraveyard3 = new ModelGraveyard(); ModelGraveyard.ModelGraveyardDiff diff2 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard3); - Set added2 = diff2.getAdded(); assertEquals(2, diff2.getRemoved().size()); - assertEquals(0, added2.size()); + assertEquals(0, diff2.getAdded().size()); ModelGraveyard updatedGraveyard2 = diff2.apply(modelGraveyard2); assertEquals(0, updatedGraveyard2.size()); @@ -123,9 +121,8 @@ public void testDiff() { ModelGraveyard modelGraveyard4 = new ModelGraveyard(modelIds); ModelGraveyard.ModelGraveyardDiff diff3 = new ModelGraveyard.ModelGraveyardDiff(modelGraveyard2, modelGraveyard4); - Set added3 = diff3.getAdded(); assertEquals(1, diff3.getRemoved().size()); - assertEquals(2, added3.size()); + assertEquals(2, diff3.getAdded().size()); ModelGraveyard updatedGraveyard3 = diff3.apply(modelGraveyard2); assertEquals(3, updatedGraveyard3.size()); From 4633db84b67ebaec463ec6b87bc087777c48e3a0 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Wed, 6 Jul 2022 21:51:00 -0500 Subject: [PATCH 7/9] Add Integration Tests for ModelGraveyardDiff Signed-off-by: Naveen Tatikonda --- ...pdateBlockedModelTransportActionTests.java | 65 +++++++++++++++++-- 1 file changed, 61 insertions(+), 4 deletions(-) diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java index 6d5708f01..8139ef294 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java @@ -74,10 +74,53 @@ public void testClusterManagerOperation() throws InterruptedException { assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); + String modelId1 = "test-model-id-1"; + // Generate update request to add modelId1 to blocked set (ModelGraveyard) + UpdateBlockedModelRequest addBlockedModelRequest1 = new UpdateBlockedModelRequest(modelId1, false); + + final CountDownLatch inProgressLatch2 = new CountDownLatch(1); + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { + ClusterState clusterState1 = stateResponse1.getState(); + updateBlockedModelTransportAction.masterOperation( + addBlockedModelRequest1, + clusterState1, + ActionListener.wrap(acknowledgedResponse -> { + assertTrue(acknowledgedResponse.isAcknowledged()); + + client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse2 -> { + ClusterState updatedClusterState = stateResponse2.getState(); + ModelGraveyard modelGraveyard = updatedClusterState.metadata().custom(ModelGraveyard.TYPE); + + assertNotNull(modelGraveyard); + assertEquals(2, modelGraveyard.size()); + assertTrue(modelGraveyard.contains(modelId1)); + + ModelGraveyard modelGraveyardPrev = clusterState1.metadata().custom(ModelGraveyard.TYPE); + assertFalse(modelGraveyardPrev.contains(modelId1)); + + // Assertions to validate ModelGraveyard Diff + ModelGraveyard.ModelGraveyardDiff diff = new ModelGraveyard.ModelGraveyardDiff(modelGraveyardPrev, modelGraveyard); + assertEquals(0, diff.getRemoved().size()); + assertEquals(1, diff.getAdded().size()); + assertTrue(diff.getAdded().contains(modelId1)); + + ModelGraveyard updatedModelGraveyard = diff.apply(modelGraveyardPrev); + assertEquals(2, updatedModelGraveyard.size()); + assertTrue(updatedModelGraveyard.contains(modelId)); + assertTrue(updatedModelGraveyard.contains(modelId1)); + + inProgressLatch2.countDown(); + }, e -> fail("Update failed"))); + }, e -> fail("Update failed")) + ); + }, e -> fail("Update failed"))); + + assertTrue(inProgressLatch2.await(60, TimeUnit.SECONDS)); + // Generate remove request to remove the modelId from blocked set (ModelGraveyard) UpdateBlockedModelRequest removeBlockedModelRequest = new UpdateBlockedModelRequest(modelId, true); - final CountDownLatch inProgressLatch2 = new CountDownLatch(1); + final CountDownLatch inProgressLatch3 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); updateBlockedModelTransportAction.masterOperation( @@ -91,16 +134,30 @@ public void testClusterManagerOperation() throws InterruptedException { ModelGraveyard modelGraveyard = updatedClusterState.metadata().custom(ModelGraveyard.TYPE); assertNotNull(modelGraveyard); - assertEquals(0, modelGraveyard.size()); + assertEquals(1, modelGraveyard.size()); assertFalse(modelGraveyard.contains(modelId)); - inProgressLatch2.countDown(); + ModelGraveyard modelGraveyardPrev = clusterState1.metadata().custom(ModelGraveyard.TYPE); + assertTrue(modelGraveyardPrev.contains(modelId)); + + // Assertions to validate ModelGraveyard Diff + ModelGraveyard.ModelGraveyardDiff diff = new ModelGraveyard.ModelGraveyardDiff(modelGraveyardPrev, modelGraveyard); + assertEquals(1, diff.getRemoved().size()); + assertEquals(0, diff.getAdded().size()); + assertTrue(diff.getRemoved().contains(modelId)); + + ModelGraveyard updatedModelGraveyard = diff.apply(modelGraveyardPrev); + assertEquals(1, updatedModelGraveyard.size()); + assertFalse(updatedModelGraveyard.contains(modelId)); + assertTrue(updatedModelGraveyard.contains(modelId1)); + + inProgressLatch3.countDown(); }, e -> fail("Update failed"))); }, e -> fail("Update failed")) ); }, e -> fail("Update failed"))); - assertTrue(inProgressLatch2.await(60, TimeUnit.SECONDS)); + assertTrue(inProgressLatch3.await(60, TimeUnit.SECONDS)); } public void testCheckBlock() { From a03f0a3781b11b964f7c6a7000777264294c7389 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Fri, 8 Jul 2022 10:19:16 -0500 Subject: [PATCH 8/9] Refactoring and Addressing other review comments Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/indices/ModelDao.java | 117 +++++++++++------- .../knn/indices/ModelGraveyard.java | 59 ++++----- .../org/opensearch/knn/plugin/KNNPlugin.java | 31 +---- .../transport/TrainingModelRequest.java | 13 +- ...n.java => UpdateModelGraveyardAction.java} | 10 +- ....java => UpdateModelGraveyardRequest.java} | 8 +- ... UpdateModelGraveyardTransportAction.java} | 59 +++++---- .../opensearch/knn/indices/ModelDaoTests.java | 66 +++++----- .../action/RestDeleteModelHandlerIT.java | 16 +-- .../transport/TrainingModelRequestTests.java | 13 +- .../UpdateBlockedModelRequestTests.java | 58 --------- .../UpdateModelGraveyardRequestTests.java | 58 +++++++++ ...teModelGraveyardTransportActionTests.java} | 48 +++---- 13 files changed, 277 insertions(+), 279 deletions(-) rename src/main/java/org/opensearch/knn/plugin/transport/{UpdateBlockedModelAction.java => UpdateModelGraveyardAction.java} (59%) rename src/main/java/org/opensearch/knn/plugin/transport/{UpdateBlockedModelRequest.java => UpdateModelGraveyardRequest.java} (82%) rename src/main/java/org/opensearch/knn/plugin/transport/{UpdateBlockedModelTransportAction.java => UpdateModelGraveyardTransportAction.java} (67%) delete mode 100644 src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java create mode 100644 src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequestTests.java rename src/test/java/org/opensearch/knn/plugin/transport/{UpdateBlockedModelTransportActionTests.java => UpdateModelGraveyardTransportActionTests.java} (77%) diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 5853b53d1..8014e73ad 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -15,6 +15,7 @@ import com.google.common.io.Resources; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchException; import org.opensearch.ResourceNotFoundException; import org.opensearch.action.ActionListener; import org.opensearch.action.DocWriteRequest; @@ -48,10 +49,10 @@ import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelRequest; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; import java.io.IOException; import java.net.URL; @@ -156,14 +157,14 @@ public interface ModelDao { void delete(String modelId, ActionListener listener); /** - * Check if modelId is in blocked model set (ModelGraveyard) or not. Non-blocking. - * A modelId is added to blocked model set before deleting that - * model and removed from the set after deleting the model + * Check if modelId is in model graveyard or not. Non-blocking. + * A modelId is added to model graveyard before deleting that + * model and removed from it after deleting the model * * @param modelId to retrieve - * @return true if modelId is in blocked model set, otherwise return false + * @return true if modelId is in model graveyard, otherwise return false */ - boolean isModelBlockedForDelete(String modelId); + boolean isModelInGraveyard(String modelId); /** * Implementation of ModelDao for k-NN model index @@ -171,6 +172,8 @@ public interface ModelDao { final class OpenSearchKNNModelDao implements ModelDao { public static Logger logger = LogManager.getLogger(ModelDao.class); + private static final String DELETED = "deleted"; + private static final String FAILED = "failed"; private int numberOfShards; private int numberOfReplicas; @@ -441,7 +444,7 @@ private String getMapping() throws IOException { } @Override - public boolean isModelBlockedForDelete(String modelId) { + public boolean isModelInGraveyard(String modelId) { // Check if the objects are not null and throw a customized NullPointerException Objects.requireNonNull(clusterService.state(), "Cluster state must not be null"); Objects.requireNonNull(clusterService.state().metadata(), "Cluster metadata must not be null"); @@ -460,7 +463,7 @@ public void delete(String modelId, ActionListener listener) if (!isCreated()) { logger.error("Cannot delete model \"" + modelId + "\". Model index " + MODEL_INDEX_NAME + "does not exist."); String errorMessage = String.format("Cannot delete model \"%s\". Model index does not exist", modelId); - listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); + listener.onResponse(new DeleteModelResponse(modelId, FAILED, errorMessage)); return; } @@ -475,7 +478,7 @@ public void delete(String modelId, ActionListener listener) get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> { if (exception instanceof ResourceNotFoundException) { String errorMessage = String.format("Unable to delete model \"%s\". Model does not exist", modelId); - listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); + listener.onFailure(new ResourceNotFoundException(errorMessage)); return; } listener.onFailure(exception); @@ -483,15 +486,14 @@ public void delete(String modelId, ActionListener listener) getModelStep.whenComplete(getModelResponse -> { // If model is in Training state, fail delete model request - if (ModelState.TRAINING.equals(getModelResponse.getModel().getModelMetadata().getState())) { + if (ModelState.TRAINING == getModelResponse.getModel().getModelMetadata().getState()) { String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); - logger.error(errorMessage); - listener.onResponse(new DeleteModelResponse(modelId, "failed", errorMessage)); + listener.onResponse(new DeleteModelResponse(modelId, FAILED, errorMessage)); return; } - // Add modelId to blocked set until delete model request is processed - updateBlockedModelToDelete(modelId, false, blockModelIdStep); + // Add modelId to model graveyard until delete model request is processed + updateModelGraveyardToDelete(modelId, false, blockModelIdStep, null); }, listener::onFailure); // Remove the metadata asynchronously @@ -512,9 +514,9 @@ public void delete(String modelId, ActionListener listener) ); deleteModelFromIndexStep.whenComplete(deleteResponse -> { - // If model is not deleted, unblock modelId and return with error message + // If model is not deleted, remove modelId from model graveyard and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { - updateBlockedModelToDelete(modelId, true, unblockModelIdStep); + updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, null); String errorMessage = String.format("Model \" %s \" does not exist", modelId); listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage)); return; @@ -522,26 +524,28 @@ public void delete(String modelId, ActionListener listener) // After model is deleted from the index, make sure the model is evicted from every cache in the cluster removeModelFromCache(modelId, clearModelFromCacheStep); - }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); + }, e -> listener.onFailure(new OpenSearchException(e))); clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> { - // Remove modelId from blocked set - updateBlockedModelToDelete(modelId, true, unblockModelIdStep); - unblockModelIdStep.whenComplete(acknowledgedResponse -> { + // If there are any failures while removing model from the cache build the error message + OpenSearchException exception = null; + if (removeModelFromCacheResponse.hasFailures()) { + String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); + exception = new OpenSearchException(failureMessage); + } - // After clearing the cache, if there are no errors return the response - if (!removeModelFromCacheResponse.hasFailures()) { - listener.onResponse(new DeleteModelResponse(modelId, "deleted", null)); - return; - } + // Remove modelId from model graveyard + updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, exception); + + }, e -> listener.onFailure(new OpenSearchException(e))); + + unblockModelIdStep.whenComplete(acknowledgedResponse -> { + // After clearing the cache, if there are no errors return the response + listener.onResponse(new DeleteModelResponse(modelId, DELETED, null)); + + }, listener::onFailure); - // Build the error message if there are any failures in model cache response and return response - String failureMessage = buildRemoveModelErrorMessage(modelId, removeModelFromCacheResponse); - listener.onResponse(new DeleteModelResponse(modelId, "failed", failureMessage)); - return; - }, listener::onFailure); - }, e -> listener.onResponse(new DeleteModelResponse(modelId, "failed", e.getMessage()))); } // Remove model from cache in the cluster @@ -551,7 +555,7 @@ private void removeModelFromCache(String modelId, StepListener unblockModelIdOnFailure(modelId, exception, clearModelFromCacheStep) + exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, clearModelFromCacheStep) ) ); } @@ -565,17 +569,36 @@ private void deleteModelFromIndex( deleteRequestBuilder.execute( ActionListener.wrap( deleteModelFromIndexStep::onResponse, - exception -> unblockModelIdOnFailure(modelId, exception, deleteModelFromIndexStep) + exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, deleteModelFromIndexStep) ) ); } - // Update blocked model set to add/remove modelId from that set - private void updateBlockedModelToDelete(String modelId, boolean isRemoveRequest, StepListener step) { + // Update model graveyard to add/remove modelId + private void updateModelGraveyardToDelete( + String modelId, + boolean isRemoveRequest, + StepListener step, + Exception exception + ) { + client.execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, isRemoveRequest), - ActionListener.wrap(step::onResponse, step::onFailure) + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, isRemoveRequest), + ActionListener.wrap(acknowledgedResponse -> { + if (exception == null) { + step.onResponse(acknowledgedResponse); + return; + } + throw exception; + + }, e -> { + if (exception == null) { + step.onFailure(e); + return; + } + step.onFailure(exception); + }) ); } @@ -586,21 +609,19 @@ private void clearModelMetadata(String modelId, StepListener unblockModelIdOnFailure(modelId, exception, clearModelMetadataStep) + exception -> removeModelIdFromGraveyardOnFailure(modelId, exception, clearModelMetadataStep) ) ); } - // This function helps to remove the model from blocked list when the delete request fails - // while executing after adding modelId to blocked list - private void unblockModelIdOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener step) { - // If modelId is unblocked successfully, then we will just return the exception received from failed stepListener - // If unblocking modelId request fails, then we will return this exception along with the one received from failed stepListener + // This function helps to remove the model from model graveyard and return the exception from previous step + // when the delete request fails while executing after adding modelId to model graveyard + private void removeModelIdFromGraveyardOnFailure(String modelId, Exception exceptionFromPreviousStep, StepListener step) { client.execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, true), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, true), ActionListener.wrap( - acknowledgedResponse -> step.onFailure(exceptionFromPreviousStep), + acknowledgedResponse -> { throw exceptionFromPreviousStep; }, unblockingFailedException -> step.onFailure(exceptionFromPreviousStep) ) ); diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java index f2d695ffc..78499c1f3 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -5,6 +5,7 @@ package org.opensearch.knn.indices; +import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; import org.opensearch.Version; import org.opensearch.cluster.Diff; @@ -28,24 +29,18 @@ * The modelIds of the models that are under deletion are added to this set and later removed from this set after deletion. * Also, this class implements the methods to perform operations on this set (like add, remove, contains) */ + +@AllArgsConstructor @Log4j2 public class ModelGraveyard implements Metadata.Custom { public static final String TYPE = "opensearch-knn-blocked-models"; - private final Set modelGraveyard; - - /** - * Constructor - * @param modelGraveyard Set which contains blocked model Ids - */ - public ModelGraveyard(Set modelGraveyard) { - this.modelGraveyard = modelGraveyard; - } + private final Set modelIds; /** * Default Constructor to initialize object when it is null */ public ModelGraveyard() { - this.modelGraveyard = new HashSet<>(); + this.modelIds = new HashSet<>(); } /** @@ -53,7 +48,7 @@ public ModelGraveyard() { * @throws IOException if read from stream fails */ public ModelGraveyard(StreamInput in) throws IOException { - this.modelGraveyard = new HashSet<>(in.readStringList()); + this.modelIds = new HashSet<>(in.readStringList()); } @Override @@ -80,7 +75,7 @@ public Version getMinimalSupportedVersion() { */ @Override public void writeTo(StreamOutput out) throws IOException { - out.writeStringCollection(modelGraveyard); + out.writeStringCollection(modelIds); } @Override @@ -89,39 +84,39 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } /** - * @param modelId id of the model that needs to be removed from modelGraveyard set + * @param modelId id of the model that needs to be removed from modelIds set */ public void remove(String modelId) { - modelGraveyard.remove(modelId); + modelIds.remove(modelId); } /** - * @param modelId id of the model that needs to be added to modelGraveyard set + * @param modelId id of the model that needs to be added to modelIds set */ public void add(String modelId) { - modelGraveyard.add(modelId); + modelIds.add(modelId); } /** * @return Set of modelIds in modelGraveyard */ - public Set getModelGraveyard() { - return modelGraveyard; + public Set getModelIds() { + return modelIds; } /** - * @return number of modelIds in modelGraveyard set + * @return number of modelIds in modelGraveyard */ public int size() { - return modelGraveyard.size(); + return modelIds.size(); } /** - * @param modelId to check if the id of given model is there in modelGraveyard set - * @return true if the modelId is in the modelGraveyard set, otherwise false + * @param modelId to check if the id of given model is there in modelIds set + * @return true if the modelId is in the modelIds set, otherwise false */ public boolean contains(String modelId) { - return modelGraveyard.contains(modelId); + return modelIds.contains(modelId); } /** @@ -177,21 +172,21 @@ public ModelGraveyardDiff(StreamInput inp) throws IOException { * entries that are deleted from previous object and the deleted entries count */ public ModelGraveyardDiff(ModelGraveyard previous, ModelGraveyard current) { - final Set previousModelGraveyard = previous.modelGraveyard; - final Set currentModelGraveyard = current.modelGraveyard; + final Set previousModelIdsSet = previous.modelIds; + final Set currentModelIdsSet = current.modelIds; final Set added, removed; - if (previousModelGraveyard.isEmpty()) { + if (previousModelIdsSet.isEmpty()) { // nothing will have been removed in previous object, and all entries in current object are new - added = new HashSet<>(currentModelGraveyard); + added = new HashSet<>(currentModelIdsSet); removed = new HashSet<>(); - } else if (currentModelGraveyard.isEmpty()) { + } else if (currentModelIdsSet.isEmpty()) { // nothing will have been added to current object, and all entries in previous object are removed added = new HashSet<>(); - removed = new HashSet<>(previousModelGraveyard); + removed = new HashSet<>(previousModelIdsSet); } else { // some entries in previous object are removed and few entries are added to current object - removed = Sets.difference(previousModelGraveyard, currentModelGraveyard); - added = Sets.difference(currentModelGraveyard, previousModelGraveyard); + removed = Sets.difference(previousModelIdsSet, currentModelIdsSet); + added = Sets.difference(currentModelIdsSet, previousModelIdsSet); } this.added = Collections.unmodifiableSet(added); this.removed = Collections.unmodifiableSet(removed); @@ -210,7 +205,7 @@ public ModelGraveyard apply(Metadata.Custom previous) { "ModelGraveyardDiff cannot remove [" + removedCount + "] entries from [" + old.size() + "] modelIds." ); } - Set updatedOldGraveyardSet = Sets.difference(old.modelGraveyard, removed); + Set updatedOldGraveyardSet = Sets.difference(old.modelIds, removed); Set modelGraveyardDiffSet = new HashSet<>(); modelGraveyardDiffSet.addAll(added); modelGraveyardDiffSet.addAll(updatedOldGraveyardSet); diff --git a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java index 675edfec9..f3b7d8197 100644 --- a/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java +++ b/src/main/java/org/opensearch/knn/plugin/KNNPlugin.java @@ -8,8 +8,6 @@ import org.opensearch.cluster.NamedDiff; import org.opensearch.cluster.metadata.Metadata; import org.opensearch.common.ParseField; -import org.opensearch.common.io.stream.NamedWriteable; -import org.opensearch.common.io.stream.Writeable; import org.opensearch.index.codec.CodecServiceFactory; import org.opensearch.index.engine.EngineFactory; import org.opensearch.knn.index.KNNCircuitBreaker; @@ -71,10 +69,10 @@ import org.opensearch.knn.plugin.transport.TrainingModelAction; import org.opensearch.knn.plugin.transport.TrainingModelRequest; import org.opensearch.knn.plugin.transport.TrainingModelTransportAction; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelTransportAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataTransportAction; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardTransportAction; import org.opensearch.knn.training.TrainingJobRunner; import org.opensearch.knn.training.VectorReader; import org.opensearch.plugins.ActionPlugin; @@ -245,7 +243,7 @@ public List getRestHandlers( new ActionHandler<>(TrainingModelAction.INSTANCE, TrainingModelTransportAction.class), new ActionHandler<>(RemoveModelFromCacheAction.INSTANCE, RemoveModelFromCacheTransportAction.class), new ActionHandler<>(SearchModelAction.INSTANCE, SearchModelTransportAction.class), - new ActionHandler<>(UpdateBlockedModelAction.INSTANCE, UpdateBlockedModelTransportAction.class) + new ActionHandler<>(UpdateModelGraveyardAction.INSTANCE, UpdateModelGraveyardTransportAction.class) ); } @@ -307,7 +305,9 @@ public List> getExecutorBuilders(Settings settings) { @Override public List getNamedWriteables() { List entries = new ArrayList<>(); - registerMetadataCustom(entries, ModelGraveyard.TYPE, ModelGraveyard::new, ModelGraveyard::readDiffFrom); + + entries.add(new NamedWriteableRegistry.Entry(Metadata.Custom.class, ModelGraveyard.TYPE, ModelGraveyard::new)); + entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, ModelGraveyard.TYPE, ModelGraveyard::readDiffFrom)); return entries; } @@ -321,23 +321,4 @@ public List getNamedXContent() { return entries; } - private static void registerMetadataCustom( - List entries, - String name, - Writeable.Reader reader, - Writeable.Reader diffReader - ) { - registerCustom(entries, Metadata.Custom.class, name, reader, diffReader); - } - - private static void registerCustom( - List entries, - Class category, - String name, - Writeable.Reader reader, - Writeable.Reader diffReader - ) { - entries.add(new NamedWriteableRegistry.Entry(category, name, reader)); - entries.add(new NamedWriteableRegistry.Entry(NamedDiff.class, name, diffReader)); - } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java index fd1629afc..9b7066c81 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/TrainingModelRequest.java @@ -254,19 +254,20 @@ public ActionRequestValidationException validate() { ActionRequestValidationException exception = null; // Check if model id exists via model metadata - if (modelDao.getMetadata(modelId) != null) { + // Also, check if model is not in model graveyard to make sure it is not being deleted + if (modelDao.getMetadata(modelId) != null && !modelDao.isModelInGraveyard(modelId)) { exception = new ActionRequestValidationException(); exception.addValidationError("Model with id=\"" + modelId + "\" already exists"); return exception; } - // Check if modelId is in blocked model set - // ModelId is added to blocked set if that model is undergoing deletion - // and will be removed from blocked set after model is deleted - if (modelDao.isModelBlockedForDelete(modelId)) { + // Check if modelId is in model graveyard + // ModelId is added to model graveyard if that model is undergoing deletion + // and will be removed from it after model is deleted + if (modelDao.isModelInGraveyard(modelId)) { exception = new ActionRequestValidationException(); String errorMessage = String.format( - "\"%s\" is blocked. Cannot create a model with same modelID until that model is deleted", + "Model with id = \"%s\" is being deleted. Cannot create a model with same modelID until that model is deleted", modelId ); exception.addValidationError(errorMessage); diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java similarity index 59% rename from src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java rename to src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java index 1b19db3fd..a9897f711 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelAction.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardAction.java @@ -10,12 +10,12 @@ import org.opensearch.common.io.stream.Writeable; /** - * Action to update blocked modelIds list + * Action to update model graveyard */ -public class UpdateBlockedModelAction extends ActionType { +public class UpdateModelGraveyardAction extends ActionType { - public static final String NAME = "cluster:admin/knn_update_blocked_model_action"; - public static final UpdateBlockedModelAction INSTANCE = new UpdateBlockedModelAction(NAME, AcknowledgedResponse::new); + public static final String NAME = "cluster:admin/knn_update_model_graveyard_action"; + public static final UpdateModelGraveyardAction INSTANCE = new UpdateModelGraveyardAction(NAME, AcknowledgedResponse::new); /** * Constructor. @@ -23,7 +23,7 @@ public class UpdateBlockedModelAction extends ActionType { * @param name name of action * @param acknowledgedResponseReader reader for acknowledged response */ - public UpdateBlockedModelAction(String name, Writeable.Reader acknowledgedResponseReader) { + public UpdateModelGraveyardAction(String name, Writeable.Reader acknowledgedResponseReader) { super(name, acknowledgedResponseReader); } } diff --git a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java similarity index 82% rename from src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java rename to src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java index b80330908..f8ca38507 100644 --- a/src/main/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequest.java +++ b/src/main/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequest.java @@ -16,9 +16,9 @@ import static org.opensearch.action.ValidateActions.addValidationError; /** - * Request for updating blocked modelIds list while processing delete model request + * Request for updating model graveyard while processing delete model request */ -public class UpdateBlockedModelRequest extends AcknowledgedRequest { +public class UpdateModelGraveyardRequest extends AcknowledgedRequest { @Getter private final String modelId; @@ -31,7 +31,7 @@ public class UpdateBlockedModelRequest extends AcknowledgedRequest { - private UpdateBlockedModelExecutor updateBlockedModelExecutor; +public class UpdateModelGraveyardTransportAction extends TransportMasterNodeAction { + private UpdateModelGraveyardExecutor updateModelGraveyardExecutor; @Inject - public UpdateBlockedModelTransportAction( + public UpdateModelGraveyardTransportAction( TransportService transportService, ClusterService clusterService, ThreadPool threadPool, @@ -51,15 +50,15 @@ public UpdateBlockedModelTransportAction( IndexNameExpressionResolver indexNameExpressionResolver ) { super( - UpdateBlockedModelAction.NAME, + UpdateModelGraveyardAction.NAME, transportService, clusterService, threadPool, actionFilters, - UpdateBlockedModelRequest::new, + UpdateModelGraveyardRequest::new, indexNameExpressionResolver ); - this.updateBlockedModelExecutor = new UpdateBlockedModelExecutor(); + this.updateModelGraveyardExecutor = new UpdateModelGraveyardExecutor(); } @Override @@ -74,16 +73,16 @@ protected AcknowledgedResponse read(StreamInput streamInput) throws IOException @Override protected void masterOperation( - UpdateBlockedModelRequest request, + UpdateModelGraveyardRequest request, ClusterState clusterState, ActionListener actionListener ) { - // ClusterManager updates blocked modelIds list based on request parameters + // ClusterManager updates model graveyard based on request parameters clusterService.submitStateUpdateTask( PLUGIN_NAME, - new UpdateBlockedModelTask(request.getModelId(), request.isRemoveRequest()), + new UpdateModelGraveyardTask(request.getModelId(), request.isRemoveRequest()), ClusterStateTaskConfig.build(Priority.NORMAL), - updateBlockedModelExecutor, + updateModelGraveyardExecutor, new ClusterStateTaskListener() { @Override public void onFailure(String s, Exception e) { @@ -99,30 +98,30 @@ public void clusterStateProcessed(String source, ClusterState oldState, ClusterS } @Override - protected ClusterBlockException checkBlock(UpdateBlockedModelRequest request, ClusterState clusterState) { + protected ClusterBlockException checkBlock(UpdateModelGraveyardRequest request, ClusterState clusterState) { return null; } /** - * UpdateBlockedModelTask is used to provide the executor with the information it needs to perform its task + * UpdateModelGraveyardTask is used to provide the executor with the information it needs to perform its task */ - @AllArgsConstructor - private static class UpdateBlockedModelTask { - private String modelId; - private boolean isRemoveRequest; + @Value + private static class UpdateModelGraveyardTask { + String modelId; + boolean isRemoveRequest; } /** - * Updates the cluster state based on the UpdateBlockedModelTask + * Updates the cluster state based on the UpdateModelGraveyardTask */ - private static class UpdateBlockedModelExecutor implements ClusterStateTaskExecutor { + private static class UpdateModelGraveyardExecutor implements ClusterStateTaskExecutor { /** * @param clusterState ClusterState - * @param taskList contains the list of UpdateBlockedModelTask request parameters (modelId and isRemoveRequest) - * @return Represents the result of a batched execution of cluster state update tasks (UpdateBlockedModelTasks) + * @param taskList contains the list of UpdateModelGraveyardTask request parameters (modelId and isRemoveRequest) + * @return Represents the result of a batched execution of cluster state update tasks (UpdateModelGraveyardTasks) */ @Override - public ClusterTasksResult execute(ClusterState clusterState, List taskList) { + public ClusterTasksResult execute(ClusterState clusterState, List taskList) { // Check if the objects are not null and throw a customized NullPointerException Objects.requireNonNull(clusterState, "Cluster state must not be null"); @@ -136,23 +135,23 @@ public ClusterTasksResult execute(ClusterState clusterSt } else { // Deep Copy to copy all the modelIds in ModelGraveyard to local object // to avoid copying the reference - copySet = new HashSet<>(immutableModelGraveyard.getModelGraveyard()); + copySet = new HashSet<>(immutableModelGraveyard.getModelIds()); modelGraveyard = new ModelGraveyard(copySet); } - for (UpdateBlockedModelTask task : taskList) { - if (task.isRemoveRequest) { - modelGraveyard.remove(task.modelId); + for (UpdateModelGraveyardTask task : taskList) { + if (task.isRemoveRequest()) { + modelGraveyard.remove(task.getModelId()); continue; } - modelGraveyard.add(task.modelId); + modelGraveyard.add(task.getModelId()); } Metadata.Builder metaDataBuilder = Metadata.builder(clusterState.metadata()); metaDataBuilder.putCustom(ModelGraveyard.TYPE, modelGraveyard); ClusterState updatedClusterState = ClusterState.builder(clusterState).metadata(metaDataBuilder).build(); - return new ClusterTasksResult.Builder().successes(taskList).build(updatedClusterState); + return new ClusterTasksResult.Builder().successes(taskList).build(updatedClusterState); } } } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 9bcdd14c4..9ea025cf0 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -12,6 +12,7 @@ package org.opensearch.knn.indices; import org.junit.AfterClass; +import org.junit.Assert; import org.junit.BeforeClass; import org.opensearch.ExceptionsHelper; import org.opensearch.ResourceAlreadyExistsException; @@ -38,10 +39,10 @@ import org.opensearch.knn.plugin.transport.RemoveModelFromCacheAction; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheRequest; import org.opensearch.knn.plugin.transport.RemoveModelFromCacheResponse; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelAction; -import org.opensearch.knn.plugin.transport.UpdateBlockedModelRequest; import org.opensearch.knn.plugin.transport.UpdateModelMetadataAction; import org.opensearch.knn.plugin.transport.UpdateModelMetadataRequest; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardAction; +import org.opensearch.knn.plugin.transport.UpdateModelGraveyardRequest; import org.opensearch.rest.RestStatus; import java.io.IOException; @@ -69,6 +70,7 @@ public class ModelDaoTests extends KNNSingleNodeTestCase { private static ExecutorService modelGetterExecutor; + private static final String FAILED = "failed"; @BeforeClass public static void setup() { @@ -499,7 +501,7 @@ public void testDelete() throws IOException, InterruptedException { final CountDownLatch inProgressLatch = new CountDownLatch(1); ActionListener deleteModelIndexDoesNotExistListener = ActionListener.wrap(response -> { - assertEquals("failed", response.getResult()); + assertEquals(FAILED, response.getResult()); inProgressLatch.countDown(); }, exception -> fail("Unable to delete the model: " + exception)); // model index doesnt exist @@ -508,21 +510,22 @@ public void testDelete() throws IOException, InterruptedException { createIndex(MODEL_INDEX_NAME); + // Model does not exist final CountDownLatch inProgressLatch1 = new CountDownLatch(1); - ActionListener deleteModelDoesNotExistListener = ActionListener.wrap(response -> { - assertEquals(modelId, response.getModelID()); - assertEquals("failed", response.getResult()); - assertNotNull(response.getErrorMessage()); + ActionListener deleteModelDoesNotExistListener = ActionListener.wrap(Assert::assertNull, exception -> { + assertNotNull(exception); + assertTrue(exception.getMessage().contains(modelId)); + assertTrue(exception.getMessage().contains("Model does not exist")); inProgressLatch1.countDown(); - }, exception -> fail(exception.getMessage())); + }); modelDao.delete(modelId, deleteModelDoesNotExistListener); - assertTrue(inProgressLatch1.await(100, TimeUnit.SECONDS)); + assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); ActionListener deleteModelTrainingListener = ActionListener.wrap(response -> { assertEquals(modelId, response.getModelID()); - assertEquals("failed", response.getResult()); + assertEquals(FAILED, response.getResult()); String errorMessage = String.format("Cannot delete model \"%s\". Model is still in training", modelId); assertEquals(errorMessage, response.getErrorMessage()); inProgressLatch2.countDown(); @@ -667,15 +670,15 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti assertNotEquals(ModelState.TRAINING.getName(), getModelResponse.getModel().getModelMetadata().getState().toString()); client().execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, false), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, false), ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) ); }, exception -> fail(exception.getMessage())); blockModelIdStep.whenComplete(acknowledgedResponse -> { // Asserting that modelId is in blocked list - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); client().execute( UpdateModelMetadataAction.INSTANCE, @@ -709,16 +712,17 @@ public void testDeleteWithStepListeners() throws IOException, InterruptedExcepti }, exception -> fail(exception.getMessage())); clearModelFromCacheStep.whenComplete(removeModelFromCacheResponse -> { + assertFalse(removeModelFromCacheResponse.hasFailures()); + client().execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, true), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, true), ActionListener.wrap(unblockModelIdStep::onResponse, unblockModelIdStep::onFailure) ); unblockModelIdStep.whenComplete(acknowledgedResponse -> { // Asserting that model is unblocked - assertFalse(modelDao.isModelBlockedForDelete(modelId)); - assertFalse(removeModelFromCacheResponse.hasFailures()); + assertFalse(modelDao.isModelInGraveyard(modelId)); inProgressLatch.countDown(); }, exception -> fail(exception.getMessage())); }, exception -> fail(exception.getMessage())); @@ -757,14 +761,14 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt // Add modelId to blocked list client().execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, false), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, false), ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) ); // Asserting that the modelId is blocked blockModelIdStep.whenComplete(acknowledgedResponse -> { - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); // Sending empty string for modelId to fail the clear model metadata request client().execute( @@ -773,14 +777,14 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt ActionListener.wrap(clearModelMetadataStep::onResponse, exp -> { // Asserting that modelId is still blocked and clearModelMetadata throws an exception assertNotNull(exp.getMessage()); - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); client().execute( // OnFailure sending request to unblock modelId - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, true), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, true), ActionListener.wrap(ackResponse -> { // Asserting that model is unblocked - assertFalse(modelDao.isModelBlockedForDelete(modelId)); + assertFalse(modelDao.isModelInGraveyard(modelId)); assertNotNull(exp.getMessage()); }, exception -> fail(exception.getMessage())) ); @@ -799,14 +803,14 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt // Add modelId to blocked list client().execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest(modelId, false), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, false), ActionListener.wrap(blockModelIdStep1::onResponse, blockModelIdStep1::onFailure) ); // Asserting that the modelId is blocked blockModelIdStep1.whenComplete(acknowledgedResponse -> { - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); // Sending empty string for modelId to fail the clear model metadata request client().execute( @@ -814,15 +818,15 @@ public void testDeleteWithStepListenersOnFailure() throws IOException, Interrupt new UpdateModelMetadataRequest("", true, null), ActionListener.wrap(clearModelMetadataStep1::onResponse, exp -> { assertNotNull(exp.getMessage()); - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); // Failing unblock modelId request by sending modelId as an empty string client().execute( - UpdateBlockedModelAction.INSTANCE, - new UpdateBlockedModelRequest("", true), + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest("", true), ActionListener.wrap(ackResponse -> {}, unblockingFailedException -> { // Asserting that model is still blocked and returns both exceptions in response - assertTrue(modelDao.isModelBlockedForDelete(modelId)); + assertTrue(modelDao.isModelInGraveyard(modelId)); assertNotNull(exp.getMessage()); assertNotNull(unblockingFailedException.getMessage()); }) diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index 3cb70ca86..fdccebaac 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -14,6 +14,7 @@ import org.apache.http.util.EntityUtils; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.client.ResponseException; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.SpaceType; @@ -90,20 +91,13 @@ public void testDeleteTrainingModel() throws IOException { } public void testDeleteModelFailsInvalid() throws IOException { + String modelId = "invalid-model-id"; createModelSystemIndex(); - String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, "invalid-model-id"); + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); Request request = new Request("DELETE", restURI); - Response response = client().performRequest(request); - assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); - String responseBody = EntityUtils.toString(response.getEntity()); - assertNotNull(responseBody); - - Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); - - assertEquals("invalid-model-id", responseMap.get(MODEL_ID)); - assertEquals("failed", responseMap.get(DeleteModelResponse.RESULT)); - assertNotNull(responseMap.get(DeleteModelResponse.ERROR_MSG)); + ResponseException ex = expectThrows(ResponseException.class, () -> client().performRequest(request)); + assertTrue(ex.getMessage().contains(modelId)); } } diff --git a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java index cdbb0c76f..2acce6662 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/TrainingModelRequestTests.java @@ -172,6 +172,9 @@ public void testValidation_invalid_modelIdAlreadyExists() { ); when(modelDao.getMetadata(modelId)).thenReturn(modelMetadata); + // ModelId is not added to model graveyard + when(modelDao.isModelInGraveyard(modelId)).thenReturn(false); + // This cluster service will result in no validation exceptions ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); @@ -187,7 +190,7 @@ public void testValidation_invalid_modelIdAlreadyExists() { } // Check that the validation produces an exception when we are - // training a model with modelId that is in blocked set + // training a model with modelId that is in model graveyard public void testValidation_blocked_modelId() { // Setup the training request @@ -209,9 +212,9 @@ public void testValidation_blocked_modelId() { null ); - // Mock the model dao to return true to recognize that the modelId is blocked + // Mock the model dao to return true to recognize that the modelId is in graveyard ModelDao modelDao = mock(ModelDao.class); - when(modelDao.isModelBlockedForDelete(modelId)).thenReturn(true); + when(modelDao.isModelInGraveyard(modelId)).thenReturn(true); // This cluster service will result in no validation exceptions ClusterService clusterService = getClusterServiceForValidReturns(trainingIndex, trainingField, dimension); @@ -219,12 +222,12 @@ public void testValidation_blocked_modelId() { // Initialize static components with the mocks TrainingModelRequest.initialize(modelDao, clusterService); - // Test that validation produces an error message that modelId is blocked + // Test that validation produces an error message that modelId is being deleted ActionRequestValidationException exception = trainingModelRequest.validate(); assertNotNull(exception); List validationErrors = exception.validationErrors(); assertEquals(1, validationErrors.size()); - assertTrue(validationErrors.get(0).contains("is blocked")); + assertTrue(validationErrors.get(0).contains("is being deleted")); } public void testValidation_invalid_invalidMethodContext() { diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java deleted file mode 100644 index cc136234e..000000000 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelRequestTests.java +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.knn.plugin.transport; - -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.knn.KNNTestCase; -import java.io.IOException; - -public class UpdateBlockedModelRequestTests extends KNNTestCase { - - public void testStreams() throws IOException { - String modelId = "test-model-id"; - boolean isRemoveRequest = false; - - UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, isRemoveRequest); - - BytesStreamOutput streamOutput = new BytesStreamOutput(); - updateBlockedModelRequest.writeTo(streamOutput); - - UpdateBlockedModelRequest updateBlockedModelRequest1 = new UpdateBlockedModelRequest(streamOutput.bytes().streamInput()); - - assertEquals(updateBlockedModelRequest.getModelId(), updateBlockedModelRequest1.getModelId()); - assertEquals(updateBlockedModelRequest.isRemoveRequest(), updateBlockedModelRequest1.isRemoveRequest()); - } - - public void testValidate() { - String modelId = "test-model-id"; - UpdateBlockedModelRequest updateBlockedModelRequest1 = new UpdateBlockedModelRequest(modelId, false); - assertNull(updateBlockedModelRequest1.validate()); - - UpdateBlockedModelRequest updateBlockedModelRequest2 = new UpdateBlockedModelRequest(modelId, true); - assertNull(updateBlockedModelRequest2.validate()); - - UpdateBlockedModelRequest updateBlockedModelRequest3 = new UpdateBlockedModelRequest("", false); - assertNotNull(updateBlockedModelRequest3.validate()); - - UpdateBlockedModelRequest updateBlockedModelRequest4 = new UpdateBlockedModelRequest("", true); - assertNotNull(updateBlockedModelRequest4.validate()); - } - - public void testGetModelId() { - String modelId = "test-model-id"; - UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, false); - - assertEquals(modelId, updateBlockedModelRequest.getModelId()); - } - - public void testIsRemoveRequest() { - String modelId = "test-model-id"; - boolean isRemoveRequest = false; - UpdateBlockedModelRequest updateBlockedModelRequest = new UpdateBlockedModelRequest(modelId, isRemoveRequest); - - assertEquals(isRemoveRequest, updateBlockedModelRequest.isRemoveRequest()); - } -} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequestTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequestTests.java new file mode 100644 index 000000000..7c38adc36 --- /dev/null +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardRequestTests.java @@ -0,0 +1,58 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.plugin.transport; + +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.knn.KNNTestCase; +import java.io.IOException; + +public class UpdateModelGraveyardRequestTests extends KNNTestCase { + + public void testStreams() throws IOException { + String modelId = "test-model-id"; + boolean isRemoveRequest = false; + + UpdateModelGraveyardRequest updateModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, isRemoveRequest); + + BytesStreamOutput streamOutput = new BytesStreamOutput(); + updateModelGraveyardRequest.writeTo(streamOutput); + + UpdateModelGraveyardRequest updateModelGraveyardRequest1 = new UpdateModelGraveyardRequest(streamOutput.bytes().streamInput()); + + assertEquals(updateModelGraveyardRequest.getModelId(), updateModelGraveyardRequest1.getModelId()); + assertEquals(updateModelGraveyardRequest.isRemoveRequest(), updateModelGraveyardRequest1.isRemoveRequest()); + } + + public void testValidate() { + String modelId = "test-model-id"; + UpdateModelGraveyardRequest updateModelGraveyardRequest1 = new UpdateModelGraveyardRequest(modelId, false); + assertNull(updateModelGraveyardRequest1.validate()); + + UpdateModelGraveyardRequest updateModelGraveyardRequest2 = new UpdateModelGraveyardRequest(modelId, true); + assertNull(updateModelGraveyardRequest2.validate()); + + UpdateModelGraveyardRequest updateModelGraveyardRequest3 = new UpdateModelGraveyardRequest("", false); + assertNotNull(updateModelGraveyardRequest3.validate()); + + UpdateModelGraveyardRequest updateModelGraveyardRequest4 = new UpdateModelGraveyardRequest("", true); + assertNotNull(updateModelGraveyardRequest4.validate()); + } + + public void testGetModelId() { + String modelId = "test-model-id"; + UpdateModelGraveyardRequest updateModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, false); + + assertEquals(modelId, updateModelGraveyardRequest.getModelId()); + } + + public void testIsRemoveRequest() { + String modelId = "test-model-id"; + boolean isRemoveRequest = false; + UpdateModelGraveyardRequest updateModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, isRemoveRequest); + + assertEquals(isRemoveRequest, updateModelGraveyardRequest.isRemoveRequest()); + } +} diff --git a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java similarity index 77% rename from src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java rename to src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java index 8139ef294..6216f985d 100644 --- a/src/test/java/org/opensearch/knn/plugin/transport/UpdateBlockedModelTransportActionTests.java +++ b/src/test/java/org/opensearch/knn/plugin/transport/UpdateModelGraveyardTransportActionTests.java @@ -17,21 +17,21 @@ import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -public class UpdateBlockedModelTransportActionTests extends KNNSingleNodeTestCase { +public class UpdateModelGraveyardTransportActionTests extends KNNSingleNodeTestCase { public void testExecutor() { - UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() - .getInstance(UpdateBlockedModelTransportAction.class); - assertEquals(ThreadPool.Names.SAME, updateBlockedModelTransportAction.executor()); + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() + .getInstance(UpdateModelGraveyardTransportAction.class); + assertEquals(ThreadPool.Names.SAME, updateModelGraveyardTransportAction.executor()); } public void testRead() throws IOException { - UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() - .getInstance(UpdateBlockedModelTransportAction.class); + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() + .getInstance(UpdateModelGraveyardTransportAction.class); AcknowledgedResponse acknowledgedResponse = new AcknowledgedResponse(true); BytesStreamOutput streamOutput = new BytesStreamOutput(); acknowledgedResponse.writeTo(streamOutput); - AcknowledgedResponse acknowledgedResponse1 = updateBlockedModelTransportAction.read(streamOutput.bytes().streamInput()); + AcknowledgedResponse acknowledgedResponse1 = updateModelGraveyardTransportAction.read(streamOutput.bytes().streamInput()); assertEquals(acknowledgedResponse, acknowledgedResponse1); } @@ -41,18 +41,18 @@ public void testClusterManagerOperation() throws InterruptedException { String modelId = "test-model-id"; // Get update transport action - UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() - .getInstance(UpdateBlockedModelTransportAction.class); + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() + .getInstance(UpdateModelGraveyardTransportAction.class); - // Generate update request to add modelId to blocked set (ModelGraveyard) - UpdateBlockedModelRequest addBlockedModelRequest = new UpdateBlockedModelRequest(modelId, false); + // Generate update request to add modelId to model graveyard + UpdateModelGraveyardRequest addModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, false); // Get cluster state, update metadata, check cluster state - all asynchronously final CountDownLatch inProgressLatch1 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); - updateBlockedModelTransportAction.masterOperation( - addBlockedModelRequest, + updateModelGraveyardTransportAction.masterOperation( + addModelGraveyardRequest, clusterState1, ActionListener.wrap(acknowledgedResponse -> { assertTrue(acknowledgedResponse.isAcknowledged()); @@ -75,14 +75,14 @@ public void testClusterManagerOperation() throws InterruptedException { assertTrue(inProgressLatch1.await(60, TimeUnit.SECONDS)); String modelId1 = "test-model-id-1"; - // Generate update request to add modelId1 to blocked set (ModelGraveyard) - UpdateBlockedModelRequest addBlockedModelRequest1 = new UpdateBlockedModelRequest(modelId1, false); + // Generate update request to add modelId1 to model graveyard + UpdateModelGraveyardRequest addModelGraveyardRequest1 = new UpdateModelGraveyardRequest(modelId1, false); final CountDownLatch inProgressLatch2 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); - updateBlockedModelTransportAction.masterOperation( - addBlockedModelRequest1, + updateModelGraveyardTransportAction.masterOperation( + addModelGraveyardRequest1, clusterState1, ActionListener.wrap(acknowledgedResponse -> { assertTrue(acknowledgedResponse.isAcknowledged()); @@ -117,14 +117,14 @@ public void testClusterManagerOperation() throws InterruptedException { assertTrue(inProgressLatch2.await(60, TimeUnit.SECONDS)); - // Generate remove request to remove the modelId from blocked set (ModelGraveyard) - UpdateBlockedModelRequest removeBlockedModelRequest = new UpdateBlockedModelRequest(modelId, true); + // Generate remove request to remove the modelId from model graveyard + UpdateModelGraveyardRequest removeModelGraveyardRequest = new UpdateModelGraveyardRequest(modelId, true); final CountDownLatch inProgressLatch3 = new CountDownLatch(1); client().admin().cluster().prepareState().execute(ActionListener.wrap(stateResponse1 -> { ClusterState clusterState1 = stateResponse1.getState(); - updateBlockedModelTransportAction.masterOperation( - removeBlockedModelRequest, + updateModelGraveyardTransportAction.masterOperation( + removeModelGraveyardRequest, clusterState1, ActionListener.wrap(acknowledgedResponse -> { assertTrue(acknowledgedResponse.isAcknowledged()); @@ -161,8 +161,8 @@ public void testClusterManagerOperation() throws InterruptedException { } public void testCheckBlock() { - UpdateBlockedModelTransportAction updateBlockedModelTransportAction = node().injector() - .getInstance(UpdateBlockedModelTransportAction.class); - assertNull(updateBlockedModelTransportAction.checkBlock(null, null)); + UpdateModelGraveyardTransportAction updateModelGraveyardTransportAction = node().injector() + .getInstance(UpdateModelGraveyardTransportAction.class); + assertNull(updateModelGraveyardTransportAction.checkBlock(null, null)); } } From 0f178ea5bf0c52bfc65a9343423ecdf092f3a0bd Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 14 Jul 2022 14:12:14 -0500 Subject: [PATCH 9/9] Remove Model from Model Graveyard even if model does not exist Signed-off-by: Naveen Tatikonda --- .../org/opensearch/knn/indices/ModelDao.java | 41 ++++--- .../opensearch/knn/indices/ModelDaoTests.java | 35 ++++++ .../action/RestDeleteModelHandlerIT.java | 101 ++++++++++++++++++ 3 files changed, 162 insertions(+), 15 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelDao.java b/src/main/java/org/opensearch/knn/indices/ModelDao.java index 8014e73ad..0d5d75d30 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelDao.java +++ b/src/main/java/org/opensearch/knn/indices/ModelDao.java @@ -60,6 +60,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.concurrent.ExecutionException; import static java.util.Objects.isNull; @@ -478,10 +479,11 @@ public void delete(String modelId, ActionListener listener) get(modelId, ActionListener.wrap(getModelStep::onResponse, exception -> { if (exception instanceof ResourceNotFoundException) { String errorMessage = String.format("Unable to delete model \"%s\". Model does not exist", modelId); - listener.onFailure(new ResourceNotFoundException(errorMessage)); - return; + ResourceNotFoundException resourceNotFoundException = new ResourceNotFoundException(errorMessage); + removeModelIdFromGraveyardOnFailure(modelId, resourceNotFoundException, getModelStep); + } else { + removeModelIdFromGraveyardOnFailure(modelId, exception, getModelStep); } - listener.onFailure(exception); })); getModelStep.whenComplete(getModelResponse -> { @@ -493,7 +495,7 @@ public void delete(String modelId, ActionListener listener) } // Add modelId to model graveyard until delete model request is processed - updateModelGraveyardToDelete(modelId, false, blockModelIdStep, null); + updateModelGraveyardToDelete(modelId, false, blockModelIdStep, Optional.empty()); }, listener::onFailure); // Remove the metadata asynchronously @@ -516,7 +518,7 @@ public void delete(String modelId, ActionListener listener) deleteModelFromIndexStep.whenComplete(deleteResponse -> { // If model is not deleted, remove modelId from model graveyard and return with error message if (deleteResponse.getResult() != DocWriteResponse.Result.DELETED) { - updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, null); + updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, Optional.empty()); String errorMessage = String.format("Model \" %s \" does not exist", modelId); listener.onResponse(new DeleteModelResponse(modelId, deleteResponse.getResult().getLowercase(), errorMessage)); return; @@ -536,7 +538,7 @@ public void delete(String modelId, ActionListener listener) } // Remove modelId from model graveyard - updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, exception); + updateModelGraveyardToDelete(modelId, true, unblockModelIdStep, Optional.ofNullable(exception)); }, e -> listener.onFailure(new OpenSearchException(e))); @@ -579,25 +581,30 @@ private void updateModelGraveyardToDelete( String modelId, boolean isRemoveRequest, StepListener step, - Exception exception + Optional exception ) { client.execute( UpdateModelGraveyardAction.INSTANCE, new UpdateModelGraveyardRequest(modelId, isRemoveRequest), ActionListener.wrap(acknowledgedResponse -> { - if (exception == null) { + if (exception.isEmpty()) { step.onResponse(acknowledgedResponse); return; } - throw exception; + throw exception.get(); }, e -> { - if (exception == null) { + // If it fails to remove the modelId from Model Graveyard, then log the error message + String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId); + String failureMessage = String.format("%s%s%s", errorMessage, "\n", e.getMessage()); + logger.error(failureMessage); + + if (exception.isEmpty()) { step.onFailure(e); return; } - step.onFailure(exception); + step.onFailure(exception.get()); }) ); } @@ -620,10 +627,14 @@ private void removeModelIdFromGraveyardOnFailure(String modelId, Exception excep client.execute( UpdateModelGraveyardAction.INSTANCE, new UpdateModelGraveyardRequest(modelId, true), - ActionListener.wrap( - acknowledgedResponse -> { throw exceptionFromPreviousStep; }, - unblockingFailedException -> step.onFailure(exceptionFromPreviousStep) - ) + ActionListener.wrap(acknowledgedResponse -> { throw exceptionFromPreviousStep; }, unblockingFailedException -> { + // If it fails to remove the modelId from Model Graveyard, then log the error message and + // throw the exception that was passed as a parameter from previous step + String errorMessage = String.format("Failed to remove \" %s \" from Model Graveyard", modelId); + String failureMessage = String.format("%s%s%s", errorMessage, "\n", unblockingFailedException.getMessage()); + logger.error(failureMessage); + step.onFailure(exceptionFromPreviousStep); + }) ); } diff --git a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java index 9ea025cf0..73d285886 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelDaoTests.java @@ -516,6 +516,7 @@ public void testDelete() throws IOException, InterruptedException { assertNotNull(exception); assertTrue(exception.getMessage().contains(modelId)); assertTrue(exception.getMessage().contains("Model does not exist")); + assertFalse(modelDao.isModelInGraveyard(modelId)); inProgressLatch1.countDown(); }); @@ -589,6 +590,40 @@ public void testDelete() throws IOException, InterruptedException { assertTrue(inProgressLatch3.await(100, TimeUnit.SECONDS)); } + // Test Delete Model when modelId is in Model Graveyard (previous delete model request which failed to + // remove modelId from model graveyard). But, the model does not exist + public void testDeleteModelWithModelInGraveyardModelDoesNotExist() throws InterruptedException { + ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); + String modelId = "test-model-in-graveyard"; + createIndex(MODEL_INDEX_NAME); + + // Model does not exist + final CountDownLatch inProgressLatch = new CountDownLatch(1); + StepListener blockModelIdStep = new StepListener<>(); + ActionListener deleteModelDoesNotExistListener1 = ActionListener.wrap(Assert::assertNull, exception -> { + assertNotNull(exception); + assertTrue(exception.getMessage().contains(modelId)); + assertTrue(exception.getMessage().contains("Model does not exist")); + // Assert that modelId is removed from graveyard even when the model does not exist + assertFalse(modelDao.isModelInGraveyard(modelId)); + inProgressLatch.countDown(); + }); + + // Adding the modelId to model graveyard + client().execute( + UpdateModelGraveyardAction.INSTANCE, + new UpdateModelGraveyardRequest(modelId, false), + ActionListener.wrap(blockModelIdStep::onResponse, blockModelIdStep::onFailure) + ); + + blockModelIdStep.whenComplete(acknowledgedResponse -> { + // Assert that model is in graveyard + assertTrue(modelDao.isModelInGraveyard(modelId)); + modelDao.delete(modelId, deleteModelDoesNotExistListener1); + }, exception -> fail(exception.getMessage())); + assertTrue(inProgressLatch.await(60, TimeUnit.SECONDS)); + } + public void testDeleteModelInTrainingWithStepListeners() throws IOException, ExecutionException, InterruptedException { String modelId = "test-model-id-training"; ModelDao modelDao = ModelDao.OpenSearchKNNModelDao.getInstance(); diff --git a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java index fdccebaac..f2c31fa72 100644 --- a/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java +++ b/src/test/java/org/opensearch/knn/plugin/action/RestDeleteModelHandlerIT.java @@ -15,6 +15,8 @@ import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.client.ResponseException; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.common.xcontent.XContentType; import org.opensearch.knn.KNNRestTestCase; import org.opensearch.knn.index.SpaceType; @@ -28,9 +30,17 @@ import java.io.IOException; import java.util.Map; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_CODE_SIZE; +import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; +import static org.opensearch.knn.common.KNNConstants.KNN_ENGINE; import static org.opensearch.knn.common.KNNConstants.MODEL_ID; import static org.opensearch.knn.common.KNNConstants.MODELS; import static org.opensearch.knn.common.KNNConstants.MODEL_INDEX_NAME; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_SPACE_TYPE; +import static org.opensearch.knn.common.KNNConstants.METHOD_PARAMETER_NLIST; +import static org.opensearch.knn.common.KNNConstants.METHOD_ENCODER_PARAMETER; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; /** * Integration tests to check the correctness of {@link org.opensearch.knn.plugin.rest.RestDeleteModelHandler} @@ -100,4 +110,95 @@ public void testDeleteModelFailsInvalid() throws IOException { assertTrue(ex.getMessage().contains(modelId)); } + // Test Train Model -> Delete Model -> Train Model with same modelId + public void testTrainingDeletedModel() throws IOException, InterruptedException { + String modelId = "test-model-id1"; + String trainingIndexName1 = "train-index-1"; + String trainingIndexName2 = "train-index-2"; + String trainingFieldName = "train-field"; + int dimension = 8; + + // Train Model + trainModel(modelId, trainingIndexName1, trainingFieldName, dimension); + + // Delete Model + String restURI = String.join("/", KNNPlugin.KNN_BASE_URI, MODELS, modelId); + Request request = new Request("DELETE", restURI); + + Response response = client().performRequest(request); + assertEquals(request.getEndpoint() + ": failed", RestStatus.OK, RestStatus.fromCode(response.getStatusLine().getStatusCode())); + + assertEquals(0, getDocCount(MODEL_INDEX_NAME)); + + // Train Model again with same ModelId + trainModel(modelId, trainingIndexName2, trainingFieldName, dimension); + } + + private void trainModel(String modelId, String trainingIndexName, String trainingFieldName, int dimension) throws IOException, + InterruptedException { + + // Create a training index and randomly ingest data into it + createBasicKnnIndex(trainingIndexName, trainingFieldName, dimension); + int trainingDataCount = 200; + bulkIngestRandomVectors(trainingIndexName, trainingFieldName, trainingDataCount, dimension); + + // Call the train API with this definition: + /* + { + "training_index": "train_index", + "training_field": "train_field", + "dimension": 8, + "description": "this should be allowed to be null", + "method": { + "name":"ivf", + "engine":"faiss", + "space_type": "l2", + "parameters":{ + "nlist":1, + "encoder":{ + "name":"pq", + "parameters":{ + "code_size":2, + "m": 2 + } + } + } + } + } + */ + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, "ivf") + .field(KNN_ENGINE, "faiss") + .field(METHOD_PARAMETER_SPACE_TYPE, "l2") + .startObject(PARAMETERS) + .field(METHOD_PARAMETER_NLIST, 1) + .startObject(METHOD_ENCODER_PARAMETER) + .field(NAME, "pq") + .startObject(PARAMETERS) + .field(ENCODER_PARAMETER_PQ_CODE_SIZE, 2) + .field(ENCODER_PARAMETER_PQ_M, 2) + .endObject() + .endObject() + .endObject() + .endObject(); + Map method = xContentBuilderToMap(builder); + + Response trainResponse = trainModel(modelId, trainingIndexName, trainingFieldName, dimension, method, "dummy description"); + + assertEquals(RestStatus.OK, RestStatus.fromCode(trainResponse.getStatusLine().getStatusCode())); + + // Confirm that the model gets created + Response getResponse = getModel(modelId, null); + String responseBody = EntityUtils.toString(getResponse.getEntity()); + assertNotNull(responseBody); + + Map responseMap = createParser(XContentType.JSON.xContent(), responseBody).map(); + + assertEquals(modelId, responseMap.get(MODEL_ID)); + + // Make sure training succeeds after 30 seconds + assertTrainingSucceeds(modelId, 30, 1000); + } + }