From a8525396a26dc0369c2b1ff2aba5d60cad246536 Mon Sep 17 00:00:00 2001 From: Lai <57818076+wnbts@users.noreply.github.com> Date: Wed, 6 May 2020 12:47:48 -0700 Subject: [PATCH] add async stopModel (#93) --- .../ad/ml/ModelManager.java | 48 ++++++++++++-- .../ad/ml/ModelManagerTests.java | 62 +++++++++++++++++++ 2 files changed, 106 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java index e4aa6f38..5698455f 100644 --- a/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java +++ b/src/main/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManager.java @@ -538,10 +538,50 @@ private void stopModel(Map> models, String modelId, Fu Optional .ofNullable(models.remove(modelId)) .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)) - .ifPresent(model -> { - checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel())); - model.setLastCheckpointTime(now); - }); + .ifPresent(model -> { checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel())); }); + } + + /** + * Stops hosting the model and creates a checkpoint. + * + * @param detectorId ID of the detector + * @param modelId ID of the model to stop hosting + * @param listener onResponse is called with null when the operation is completed + */ + public void stopModel(String detectorId, String modelId, ActionListener listener) { + logger.info(String.format("Stopping detector %s model %s", detectorId, modelId)); + stopModel( + forests, + modelId, + this::toCheckpoint, + ActionListener.wrap(r -> stopModel(thresholds, modelId, this::toCheckpoint, listener), listener::onFailure) + ); + } + + private void stopModel( + Map> models, + String modelId, + Function toCheckpoint, + ActionListener listener + ) { + Instant now = clock.instant(); + Optional> modelState = Optional + .ofNullable(models.remove(modelId)) + .filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now)); + if (modelState.isPresent()) { + modelState + .ifPresent( + model -> checkpointDao + .putModelCheckpoint( + modelId, + toCheckpoint.apply(model.getModel()), + ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure) + ) + ); + } else { + listener.onResponse(null); + } + ; } /** diff --git a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java index c92c05f4..1a9f309e 100644 --- a/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java +++ b/src/test/java/com/amazon/opendistroforelasticsearch/ad/ml/ModelManagerTests.java @@ -555,6 +555,68 @@ public void stopModel_saveThresholdCheckpoint() { verify(checkpointDao).putModelCheckpoint(thresholdModelId, checkpoint); } + @Test + @SuppressWarnings("unchecked") + public void stopModel_returnExpectedToListener_whenRcfStop() { + RandomCutForest forest = mock(RandomCutForest.class); + when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint)); + when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); + when(rcfSerde.toJson(forest)).thenReturn(checkpoint); + modelManager.getRcfResult(detectorId, rcfModelId, new double[0]); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, rcfModelId, listener); + + verify(listener).onResponse(eq(null)); + } + + @Test + @SuppressWarnings("unchecked") + public void stopModel_returnExpectedToListener_whenThresholdStop() { + when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint)); + PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass); + PowerMockito.doReturn(checkpoint).when(gson).toJson(hybridThresholdingModel); + modelManager.getThresholdingResult(detectorId, thresholdModelId, 0); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onResponse(null); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(thresholdModelId), eq(checkpoint), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, thresholdModelId, listener); + + verify(listener).onResponse(eq(null)); + } + + @Test + @SuppressWarnings("unchecked") + public void stopModel_throwToListener_whenCheckpointFail() { + RandomCutForest forest = mock(RandomCutForest.class); + when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint)); + when(rcfSerde.fromJson(checkpoint)).thenReturn(forest); + when(rcfSerde.toJson(forest)).thenReturn(checkpoint); + modelManager.getRcfResult(detectorId, rcfModelId, new double[0]); + when(clock.instant()).thenReturn(Instant.EPOCH); + doAnswer(invocation -> { + ActionListener listener = invocation.getArgument(2); + listener.onFailure(new RuntimeException()); + return null; + }).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class)); + + ActionListener listener = mock(ActionListener.class); + modelManager.stopModel(detectorId, rcfModelId, listener); + + verify(listener).onFailure(any(Exception.class)); + } + @Test public void clear_deleteRcfCheckpoint() {