Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

add async stopModel #93

Merged
merged 1 commit into from
May 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -537,10 +537,50 @@ private <T> void stopModel(Map<String, ModelState<T>> models, String modelId, Fu
Optional
.ofNullable(models.remove(modelId))
.filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now))
.ifPresent(model -> {
checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel()));
model.setLastCheckpointTime(now);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is setting last checkpoint time not necessary any more?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because the model has been removed from the managed models at line 538 and has become garbage at the end of the method.

});
.ifPresent(model -> { checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel())); });
}

/**
* Stops hosting the model and creates a checkpoint.
*
* @param detectorId ID of the detector
* @param modelId ID of the model to stop hosting
* @param listener onResponse is called with null when the operation is completed
*/
public void stopModel(String detectorId, String modelId, ActionListener<Void> listener) {
logger.info(String.format("Stopping detector %s model %s", detectorId, modelId));
stopModel(
forests,
modelId,
this::toCheckpoint,
ActionListener.wrap(r -> stopModel(thresholds, modelId, this::toCheckpoint, listener), listener::onFailure)
);
}

private <T> void stopModel(
Map<String, ModelState<T>> models,
String modelId,
Function<T, String> toCheckpoint,
ActionListener<Void> listener
) {
Instant now = clock.instant();
Optional<ModelState<T>> modelState = Optional
.ofNullable(models.remove(modelId))
.filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now));
if (modelState.isPresent()) {
modelState
.ifPresent(
model -> checkpointDao
.putModelCheckpoint(
modelId,
toCheckpoint.apply(model.getModel()),
ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
)
);
} else {
listener.onResponse(null);
}
;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,68 @@ public void stopModel_saveThresholdCheckpoint() {
verify(checkpointDao).putModelCheckpoint(thresholdModelId, checkpoint);
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_returnExpectedToListener_whenRcfStop() {
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, rcfModelId, listener);

verify(listener).onResponse(eq(null));
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_returnExpectedToListener_whenThresholdStop() {
when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint));
PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
PowerMockito.doReturn(checkpoint).when(gson).toJson(hybridThresholdingModel);
modelManager.getThresholdingResult(detectorId, thresholdModelId, 0);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(thresholdModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, thresholdModelId, listener);

verify(listener).onResponse(eq(null));
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_throwToListener_whenCheckpointFail() {
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, rcfModelId, listener);

verify(listener).onFailure(any(Exception.class));
}

@Test
public void clear_deleteRcfCheckpoint() {
String checkpoint = "checkpoint";
Expand Down