Skip to content
This repository has been archived by the owner on Aug 2, 2022. It is now read-only.

Commit

Permalink
add async stopModel (#93)
Browse files Browse the repository at this point in the history
  • Loading branch information
wnbts authored May 6, 2020
1 parent 3b4d21f commit a852539
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -538,10 +538,50 @@ private <T> void stopModel(Map<String, ModelState<T>> models, String modelId, Fu
Optional
.ofNullable(models.remove(modelId))
.filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now))
.ifPresent(model -> {
checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel()));
model.setLastCheckpointTime(now);
});
.ifPresent(model -> { checkpointDao.putModelCheckpoint(modelId, toCheckpoint.apply(model.getModel())); });
}

/**
* Stops hosting the model and creates a checkpoint.
*
* @param detectorId ID of the detector
* @param modelId ID of the model to stop hosting
* @param listener onResponse is called with null when the operation is completed
*/
public void stopModel(String detectorId, String modelId, ActionListener<Void> listener) {
logger.info(String.format("Stopping detector %s model %s", detectorId, modelId));
stopModel(
forests,
modelId,
this::toCheckpoint,
ActionListener.wrap(r -> stopModel(thresholds, modelId, this::toCheckpoint, listener), listener::onFailure)
);
}

private <T> void stopModel(
Map<String, ModelState<T>> models,
String modelId,
Function<T, String> toCheckpoint,
ActionListener<Void> listener
) {
Instant now = clock.instant();
Optional<ModelState<T>> modelState = Optional
.ofNullable(models.remove(modelId))
.filter(model -> model.getLastCheckpointTime().plus(checkpointInterval).isBefore(now));
if (modelState.isPresent()) {
modelState
.ifPresent(
model -> checkpointDao
.putModelCheckpoint(
modelId,
toCheckpoint.apply(model.getModel()),
ActionListener.wrap(r -> listener.onResponse(null), listener::onFailure)
)
);
} else {
listener.onResponse(null);
}
;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,68 @@ public void stopModel_saveThresholdCheckpoint() {
verify(checkpointDao).putModelCheckpoint(thresholdModelId, checkpoint);
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_returnExpectedToListener_whenRcfStop() {
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, rcfModelId, listener);

verify(listener).onResponse(eq(null));
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_returnExpectedToListener_whenThresholdStop() {
when(checkpointDao.getModelCheckpoint(thresholdModelId)).thenReturn(Optional.of(checkpoint));
PowerMockito.doReturn(hybridThresholdingModel).when(gson).fromJson(checkpoint, thresholdingModelClass);
PowerMockito.doReturn(checkpoint).when(gson).toJson(hybridThresholdingModel);
modelManager.getThresholdingResult(detectorId, thresholdModelId, 0);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onResponse(null);
return null;
}).when(checkpointDao).putModelCheckpoint(eq(thresholdModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, thresholdModelId, listener);

verify(listener).onResponse(eq(null));
}

@Test
@SuppressWarnings("unchecked")
public void stopModel_throwToListener_whenCheckpointFail() {
RandomCutForest forest = mock(RandomCutForest.class);
when(checkpointDao.getModelCheckpoint(rcfModelId)).thenReturn(Optional.of(checkpoint));
when(rcfSerde.fromJson(checkpoint)).thenReturn(forest);
when(rcfSerde.toJson(forest)).thenReturn(checkpoint);
modelManager.getRcfResult(detectorId, rcfModelId, new double[0]);
when(clock.instant()).thenReturn(Instant.EPOCH);
doAnswer(invocation -> {
ActionListener<Void> listener = invocation.getArgument(2);
listener.onFailure(new RuntimeException());
return null;
}).when(checkpointDao).putModelCheckpoint(eq(rcfModelId), eq(checkpoint), any(ActionListener.class));

ActionListener<Void> listener = mock(ActionListener.class);
modelManager.stopModel(detectorId, rcfModelId, listener);

verify(listener).onFailure(any(Exception.class));
}

@Test
public void clear_deleteRcfCheckpoint() {

Expand Down

0 comments on commit a852539

Please sign in to comment.