Fix inference logic and standardize config index mapping (#1284)
This commit fixes several issues and makes related improvements in the inference logic and the Config index mapping:

1. Fixes in RealTimeInferencer:

* Previously, we checked whether the last update time of the model state fell within the current interval and skipped inference if it did. However, this led to excessive skipping of inference, because the last update time is refreshed whenever the model state is retrieved from the cache.
* Introduced lastSeenExecutionEndTime in the model state, which specifically tracks the last time a sample was processed during inference (not training). This gives more accurate control over when inference should be skipped (see the sketch after the Testing notes below).

2. Consistent Naming in Config Index Mapping:

* To maintain consistency across the codebase, changed defaultFill to default_fill in the Config index mapping, following the underscore naming convention used elsewhere.

3. Additional Null Checks:

* Added more null checks for the defaultFill field in the Config constructor to improve robustness.

Testing:
* Added a smoke test that allows the job scheduler to trigger anomaly detection inferencing, successfully reproducing and verifying the fix for item #1.
* Added unit tests for item #3.
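For illustration, here is a minimal, self-contained sketch of the item #1 mechanism. The class below is a hypothetical stand-in for the plugin's ModelState, trimmed to the one relevant field; the guard mirrors the tryProcess diff shown later in this commit:

```java
import java.time.Instant;

public class LastSeenSketch {
    // Instant.MIN is the sentinel for "no inference has completed yet".
    private Instant lastSeenExecutionEndTime = Instant.MIN;

    // Mirrors the new guard in tryProcess: skip only if an inference already covered this interval.
    boolean shouldSkip(long curExecutionEndMillis) {
        return lastSeenExecutionEndTime != Instant.MIN
            && curExecutionEndMillis < lastSeenExecutionEndTime.toEpochMilli();
    }

    // Updated only when a sample is actually scored, never on a cache read.
    void markInferenceEnd(Instant now) {
        lastSeenExecutionEndTime = now;
    }

    public static void main(String[] args) {
        LastSeenSketch state = new LastSeenSketch();
        long interval = Instant.now().toEpochMilli();
        System.out.println(state.shouldSkip(interval)); // false: nothing scored yet
        state.markInferenceEnd(Instant.now().plusSeconds(60));
        System.out.println(state.shouldSkip(interval)); // true: interval already covered
    }
}
```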

Signed-off-by: Kaituo Li <kaituo@amazon.com>
kaituo committed Aug 23, 2024
1 parent dc85dc4 commit 2922bbd
Showing 17 changed files with 314 additions and 67 deletions.
60 changes: 60 additions & 0 deletions .github/workflows/long_running.yml
@@ -0,0 +1,60 @@
name: Run long running tests
on:
  push:
    branches:
      - "*"
  pull_request:
    branches:
      - "*"

env:
  ACTIONS_ALLOW_USE_UNSECURE_NODE_VERSION: true

jobs:
  Get-CI-Image-Tag:
    uses: opensearch-project/opensearch-build/.github/workflows/get-ci-image-tag.yml@main
    with:
      product: opensearch

  Run-Tests:
    needs: Get-CI-Image-Tag
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # each test scenario (rule, hc, single_stream) is treated as a separate job.
        test: [smoke]
      fail-fast: false
    concurrency:
      # The concurrency setting is used to limit the concurrency of each test scenario group to ensure they do not run concurrently on the same machine.
      group: ${{ github.workflow }}-${{ matrix.test }}
    name: Run long running tests

    container:
      # using the same image which is used by opensearch-build team to build the OpenSearch Distribution
      # this image tag is subject to change as more dependencies and updates will arrive over time
      image: ${{ needs.Get-CI-Image-Tag.outputs.ci-image-version-linux }}
      # need to switch to root so that github actions can install runner binary on container without permission issues.
      options: --user root

    steps:
      - name: Setup Java
        uses: actions/setup-java@v3
        with:
          distribution: 'temurin'
          java-version: 21

      - name: Checkout AD
        uses: actions/checkout@v3

      - name: Build and Run Tests
        run: |
          chown -R 1000:1000 `pwd`
          case ${{ matrix.test }} in
            smoke)
              su `id -un 1000` -c "./gradlew integTest --tests 'org.opensearch.ad.e2e.SingleStreamSmokeIT' \
                -Dtests.seed=B4BA12CCF1D9E825 -Dtests.security.manager=false \
                -Dtests.jvm.argline='-XX:TieredStopAtLevel=1 -XX:ReservedCodeCacheSize=64m' \
                -Dtests.locale=ar-JO -Dtests.timezone=Asia/Samarkand -Dlong-running=true \
                -Dtests.timeoutSuite=3600000! -Dtest.logs=true"
              ;;
          esac
6 changes: 6 additions & 0 deletions build.gradle
@@ -360,6 +360,12 @@ integTest {
         }
     }
 
+    if (System.getProperty("long-running") == null || System.getProperty("long-running") == "false") {
+        filter {
+            excludeTestsMatching "org.opensearch.ad.e2e.SingleStreamSmokeIT"
+        }
+    }
+
     // The 'doFirst' delays till execution time.
     doFirst {
         // Tell the test JVM if the cluster JVM is running under a debugger so that tests can
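The Gradle filter above excludes SingleStreamSmokeIT from regular integTest runs unless -Dlong-running=true is passed. As a point of comparison only (this is not how the repo gates the test), a hypothetical JUnit 4 assumption achieves a similar per-test gate on the same system property:

```java
import org.junit.Assume;
import org.junit.Test;

public class LongRunningGateExample {
    @Test
    public void longRunningSmoke() {
        // Boolean.getBoolean reads the JVM system property and is true only when it equals "true".
        Assume.assumeTrue("skipped: pass -Dlong-running=true to enable", Boolean.getBoolean("long-running"));
        // ... the long-running scenario would run here ...
    }
}
```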
@@ -26,7 +26,7 @@
 public class ImputationOption implements Writeable, ToXContent {
     // field name in toXContent
     public static final String METHOD_FIELD = "method";
-    public static final String DEFAULT_FILL_FIELD = "defaultFill";
+    public static final String DEFAULT_FILL_FIELD = "default_fill";
 
     private final ImputationMethod method;
     private final Map<String, Double> defaultFill;
@@ -152,7 +152,7 @@ public int hashCode() {

     @Override
     public String toString() {
-        return new ToStringBuilder(this).append("method", method).append("defaultFill", defaultFill).toString();
+        return new ToStringBuilder(this).append("method", method).append("default_fill", defaultFill).toString();
     }
 
     public ImputationMethod getMethod() {
@@ -169,6 +169,7 @@ public <RCFDescriptor extends AnomalyDescriptor> IntermediateResultType score(
             throw e;
         } finally {
             modelState.setLastUsedTime(clock.instant());
+            modelState.setLastSeenExecutionEndTime(clock.instant());
         }
         return createEmptyResult();
     }
10 changes: 10 additions & 0 deletions src/main/java/org/opensearch/timeseries/ml/ModelState.java
@@ -36,6 +36,7 @@ public class ModelState<T> implements org.opensearch.timeseries.ExpiringState {
     // time when the ML model was used last time
     protected Instant lastUsedTime;
     protected Instant lastCheckpointTime;
+    protected Instant lastSeenExecutionEndTime;
     protected Clock clock;
     protected float priority;
     protected Deque<Sample> samples;
@@ -74,6 +75,7 @@ public ModelState(
         this.priority = priority;
         this.entity = entity;
         this.samples = samples;
+        this.lastSeenExecutionEndTime = Instant.MIN;
     }
 
     /**
@@ -249,4 +251,12 @@ public Map<String, Object> getModelStateAsMap() {
             }
         };
     }
+
+    public Instant getLastSeenExecutionEndTime() {
+        return lastSeenExecutionEndTime;
+    }
+
+    public void setLastSeenExecutionEndTime(Instant lastSeenExecutionEndTime) {
+        this.lastSeenExecutionEndTime = lastSeenExecutionEndTime;
+    }
 }
@@ -5,6 +5,7 @@

 package org.opensearch.timeseries.ml;
 
+import java.time.Instant;
 import java.util.Collections;
 import java.util.Locale;
 import java.util.Map;
@@ -134,11 +135,14 @@ private boolean processWithTimeout(
     }
 
     private boolean tryProcess(Sample sample, ModelState<RCFModelType> modelState, Config config, String taskId, long curExecutionEnd) {
-        // execution end time (when job starts execution in this interval) >= last used time => the model state is updated in
+        // execution end time (when job starts execution in this interval) >= last seen execution end time => the model state is updated in
         // previous intervals
-        // This can happen while scheduled to waiting some other threads have already scored the same interval (e.g., during tests
-        // when everything happens fast)
-        if (curExecutionEnd < modelState.getLastUsedTime().toEpochMilli()) {
+        // This can happen when, while this run was scheduled or waiting, other threads have already scored the same interval
+        // (e.g., during tests when everything happens fast).
+        // We cannot use last used time, as it is updated whenever we update the model's priority in CacheBuffer.update during a
+        // PriorityCache.get.
+        if (modelState.getLastSeenExecutionEndTime() != Instant.MIN
+            && curExecutionEnd < modelState.getLastSeenExecutionEndTime().toEpochMilli()) {
             return false;
         }
         String modelId = modelState.getModelId();
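The rewritten comment above is the heart of item #1: PriorityCache.get bumps lastUsedTime as a side effect of a cache read, so that timestamp cannot distinguish "already scored this interval" from "merely fetched from the cache". A hypothetical timeline (simplified to bare Instants, not the plugin's classes) contrasting the old and new checks:

```java
import java.time.Instant;

public class SkipCheckTimeline {
    public static void main(String[] args) {
        Instant intervalStart = Instant.parse("2024-08-23T10:00:00Z");
        long curExecutionEnd = intervalStart.plusSeconds(60).toEpochMilli(); // job fires at 10:01

        // Old check: fetching the state from the cache bumped lastUsedTime past curExecutionEnd,
        // so the job concluded the interval was already handled.
        Instant lastUsedTime = intervalStart.plusSeconds(65); // side effect of a cache read
        boolean oldSkips = curExecutionEnd < lastUsedTime.toEpochMilli();

        // New check: lastSeenExecutionEndTime moves only when a sample is actually scored.
        Instant lastSeenExecutionEndTime = Instant.MIN; // nothing scored yet
        boolean newSkips = lastSeenExecutionEndTime != Instant.MIN
            && curExecutionEnd < lastSeenExecutionEndTime.toEpochMilli();

        System.out.println("old check skips inference: " + oldSkips); // true (inference wrongly skipped)
        System.out.println("new check skips inference: " + newSkips); // false (inference proceeds)
    }
}
```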
25 changes: 16 additions & 9 deletions src/main/java/org/opensearch/timeseries/model/Config.java
@@ -225,29 +225,36 @@ protected Config(
             : features.stream().filter(Feature::getEnabled).collect(Collectors.toList());
 
         Map<String, Double> defaultFill = imputationOption.getDefaultFill();
-        if (defaultFill.isEmpty() && enabledFeatures.size() > 0) {
+
+        // Case 1: enabledFeatures == null && defaultFill != null
+        if (enabledFeatures == null && defaultFill != null && !defaultFill.isEmpty()) {
+            issueType = ValidationIssueType.IMPUTATION;
+            errorMessage = "Enabled features list is null, but default fill values are provided.";
+            return;
+        }
+
+        // Case 2: enabledFeatures != null && defaultFill == null
+        if (enabledFeatures != null && (defaultFill == null || defaultFill.isEmpty())) {
             issueType = ValidationIssueType.IMPUTATION;
-            errorMessage = "No given values for fixed value imputation";
+            errorMessage = "Enabled features are present, but no default fill values are provided.";
             return;
         }
 
-        // Check if the length of the defaultFill array matches the number of expected features
-        if (enabledFeatures == null || defaultFill.size() != enabledFeatures.size()) {
+        // Case 3: enabledFeatures.size() != defaultFill.size()
+        if (enabledFeatures != null && defaultFill != null && defaultFill.size() != enabledFeatures.size()) {
             issueType = ValidationIssueType.IMPUTATION;
             errorMessage = String
                 .format(
                     Locale.ROOT,
-                    "Incorrect number of values to fill. Got: %d. Expected: %d.",
+                    "Mismatch between the number of enabled features and default fill values. Number of default fill values: %d. Number of enabled features: %d.",
                     defaultFill.size(),
-                    enabledFeatures == null ? 0 : enabledFeatures.size()
+                    enabledFeatures.size()
                 );
             return;
         }
 
-        Map<String, Double> defaultFills = imputationOption.getDefaultFill();
-
         for (int i = 0; i < enabledFeatures.size(); i++) {
-            if (!defaultFills.containsKey(enabledFeatures.get(i).getName())) {
+            if (!defaultFill.containsKey(enabledFeatures.get(i).getName())) {
                 issueType = ValidationIssueType.IMPUTATION;
                 errorMessage = String.format(Locale.ROOT, "Missing feature name: %s.", enabledFeatures.get(i).getName());
                 return;
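To see which inputs trigger which message, here is a hypothetical standalone rendition of the constructor's four checks (plain Java, detached from Config; the feature names are made up):

```java
import java.util.List;
import java.util.Map;

public class ImputationValidationSketch {

    // Returns the first applicable error message, or null when the combination is valid.
    static String validate(List<String> enabledFeatures, Map<String, Double> defaultFill) {
        // Case 1: no enabled features, but fill values were provided.
        if (enabledFeatures == null && defaultFill != null && !defaultFill.isEmpty()) {
            return "Enabled features list is null, but default fill values are provided.";
        }
        // Case 2: enabled features exist, but no fill values.
        if (enabledFeatures != null && (defaultFill == null || defaultFill.isEmpty())) {
            return "Enabled features are present, but no default fill values are provided.";
        }
        // Case 3: the counts differ.
        if (enabledFeatures != null && defaultFill != null && defaultFill.size() != enabledFeatures.size()) {
            return String.format("Mismatch: %d fill values vs %d enabled features.", defaultFill.size(), enabledFeatures.size());
        }
        // Finally, every enabled feature needs a fill value registered under its own name.
        if (enabledFeatures != null) {
            for (String name : enabledFeatures) {
                if (!defaultFill.containsKey(name)) {
                    return "Missing feature name: " + name + ".";
                }
            }
        }
        return null;
    }

    public static void main(String[] args) {
        System.out.println(validate(null, Map.of("cpu", 0.0)));                  // case 1
        System.out.println(validate(List.of("cpu"), Map.of()));                  // case 2
        System.out.println(validate(List.of("cpu", "mem"), Map.of("cpu", 0.0))); // case 3
        System.out.println(validate(List.of("cpu"), Map.of("mem", 1.0)));        // missing name
        System.out.println(validate(List.of("cpu"), Map.of("cpu", 0.0)));        // null (valid)
    }
}
```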
2 changes: 1 addition & 1 deletion src/main/resources/mappings/config.json
@@ -174,7 +174,7 @@
         "method": {
           "type": "keyword"
         },
-        "defaultFill": {
+        "default_fill": {
           "type": "nested",
           "properties": {
             "feature_name": {
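With the mapping renamed, a serialized config's imputation block carries default_fill entries keyed by feature_name. A sketch of the expected JSON shape (the feature names and the data key are assumptions for illustration; the mapping excerpt above is truncated after feature_name):

```java
public class DefaultFillShapeSketch {
    public static void main(String[] args) {
        // Hypothetical imputation block after the rename; note the snake_case default_fill key.
        String imputationJson = "{\n"
            + "  \"method\": \"FIXED_VALUES\",\n"
            + "  \"default_fill\": [\n"
            + "    { \"feature_name\": \"cpu\", \"data\": 0.0 },\n"
            + "    { \"feature_name\": \"mem\", \"data\": 1.5 }\n"
            + "  ]\n"
            + "}";
        System.out.println(imputationJson);
    }
}
```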
104 changes: 83 additions & 21 deletions src/test/java/org/opensearch/ad/AbstractADSyntheticDataTest.java
@@ -19,6 +19,7 @@
 import java.nio.charset.Charset;
 import java.time.Duration;
 import java.time.Instant;
+import java.time.format.DateTimeFormatter;
 import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
 import java.util.List;
@@ -163,13 +164,59 @@ protected Map<String, Object> previewWithFailure(String detector, Instant begin,
         return entityAsMap(response);
     }
 
+    protected List<JsonObject> getAnomalyResultByDataTime(
+        String detectorId,
+        Instant end,
+        int entitySize,
+        RestClient client,
+        boolean approximateEndTime,
+        long rangeDurationMillis
+    ) throws InterruptedException {
+        return getAnomalyResult(
+            detectorId,
+            end,
+            entitySize,
+            client,
+            approximateEndTime,
+            rangeDurationMillis,
+            "data_end_time",
+            (h, eSize) -> h.size() == eSize,
+            entitySize
+        );
+    }
+
+    protected List<JsonObject> getAnomalyResultByExecutionTime(
+        String detectorId,
+        Instant end,
+        int entitySize,
+        RestClient client,
+        boolean approximateEndTime,
+        long rangeDurationMillis,
+        int expectedResultSize
+    ) throws InterruptedException {
+        return getAnomalyResult(
+            detectorId,
+            end,
+            entitySize,
+            client,
+            approximateEndTime,
+            rangeDurationMillis,
+            "execution_end_time",
+            (h, eSize) -> h.size() >= eSize,
+            expectedResultSize
+        );
+    }
+
     protected List<JsonObject> getAnomalyResult(
         String detectorId,
         Instant end,
         int entitySize,
         RestClient client,
-        boolean approximateDataEndTime,
-        long intervalMillis
+        boolean approximateEndTime,
+        long rangeDurationMillis,
+        String endTimeField,
+        ConditionChecker checker,
+        int expectedResultSize
     ) throws InterruptedException {
         Request request = new Request("POST", "/_plugins/_anomaly_detection/detectors/results/_search");
 
@@ -191,12 +238,12 @@ protected List<JsonObject> getAnomalyResult(
             + " },\n"
             + " {\n"
             + " \"range\": {\n"
-            + " \"data_end_time\": {\n";
+            + " \"%s\": {\n";
 
         StringBuilder jsonTemplate = new StringBuilder();
         jsonTemplate.append(jsonTemplatePrefix);
 
-        if (approximateDataEndTime) {
+        if (approximateEndTime) {
             // we may get two interval results if using gte
             jsonTemplate.append(" \"gt\": %d,\n \"lte\": %d\n");
         } else {
@@ -217,10 +264,11 @@ protected List<JsonObject> getAnomalyResult(
         long dateEndTime = end.toEpochMilli();
         String formattedJson = null;
 
-        if (approximateDataEndTime) {
-            formattedJson = String.format(Locale.ROOT, jsonTemplate.toString(), detectorId, dateEndTime - intervalMillis, dateEndTime);
+        if (approximateEndTime) {
+            formattedJson = String
+                .format(Locale.ROOT, jsonTemplate.toString(), detectorId, endTimeField, dateEndTime - rangeDurationMillis, dateEndTime);
         } else {
-            formattedJson = String.format(Locale.ROOT, jsonTemplate.toString(), detectorId, dateEndTime, dateEndTime);
+            formattedJson = String.format(Locale.ROOT, jsonTemplate.toString(), detectorId, endTimeField, dateEndTime, dateEndTime);
         }
 
         request.setJsonEntity(formattedJson);
@@ -231,25 +279,16 @@ protected List<JsonObject> getAnomalyResult(
         do {
             try {
                 JsonArray hits = getHits(client, request);
-                if (hits != null && hits.size() == entitySize) {
-                    assertTrue("empty response", hits != null);
-                    assertTrue("returned more than " + hits.size() + " results.", hits.size() == entitySize);
+                if (hits != null && checker.checkCondition(hits, entitySize)) {
                     List<JsonObject> res = new ArrayList<>();
-                    for (int i = 0; i < entitySize; i++) {
+                    for (int i = 0; i < hits.size(); i++) {
                         JsonObject source = hits.get(i).getAsJsonObject().get("_source").getAsJsonObject();
                         res.add(source);
                     }
 
                     return res;
                 } else {
-                    LOG
-                        .info(
-                            "wait for result, previous result: {}, size: {}, eval result {}, expected {}",
-                            hits,
-                            hits.size(),
-                            hits != null && hits.size() == entitySize,
-                            entitySize
-                        );
+                    LOG.info("wait for result, previous result: {}, size: {}", hits, hits.size());
                     client.performRequest(new Request("POST", String.format(Locale.ROOT, "/%s/_refresh", ".opendistro-anomaly-results*")));
                 }
                 Thread.sleep(2_000 * entitySize);
@@ -275,7 +314,7 @@

     protected List<JsonObject> getRealTimeAnomalyResult(String detectorId, Instant end, int entitySize, RestClient client)
         throws InterruptedException {
-        return getAnomalyResult(detectorId, end, entitySize, client, false, 0);
+        return getAnomalyResultByDataTime(detectorId, end, entitySize, client, false, 0);
     }
}

public double getAnomalyGrade(JsonObject source) {
Expand Down Expand Up @@ -462,7 +501,7 @@ protected List<JsonObject> waitForHistoricalDetector(

Thread.sleep(1_000);

List<JsonObject> sourceList = getAnomalyResult(detectorId, end, entitySize, client, true, intervalMillis);
List<JsonObject> sourceList = getAnomalyResultByDataTime(detectorId, end, entitySize, client, true, intervalMillis);
if (sourceList.size() > 0 && getAnomalyGrade(sourceList.get(0)) >= 0) {
return sourceList;
}
Expand Down Expand Up @@ -624,7 +663,30 @@ protected List<JsonObject> startHistoricalDetector(
);
}

+    protected long getWindowDelayMinutes(List<JsonObject> data, int trainTestSplit, String timestamp) {
+        // e.g., "2019-11-02T00:59:00Z"
+        String trainTimeStr = data.get(trainTestSplit - 1).get("timestamp").getAsString();
+        Instant trainTime = Instant.from(DateTimeFormatter.ISO_INSTANT.parse(trainTimeStr));
+        /*
+         * The {@code CompositeRetriever.PageIterator.hasNext()} method checks whether a request has expired
+         * relative to the current system time. This method is designed to ensure that the execution time
+         * is set to either the current time or a future time, to prevent premature expirations in our tests.
+         *
+         * Also, AD accepts windowDelay in minutes, so we convert the delay to minutes. This makes it easier
+         * to search for results based on data end time; otherwise the real data end time and the data end
+         * time reconstructed from the request time would differ.
+         * Assume x = real data time, y = real window delay, and y' = the window delay rounded to minutes.
+         * If y and y' differ, then x + y - y' != x.
+         */
+        return Duration.between(trainTime, Instant.now()).toMinutes();
+    }
+
     public static boolean areDoublesEqual(double d1, double d2) {
         return Math.abs(d1 - d2) < EPSILON;
     }
+
+    @FunctionalInterface
+    public interface ConditionChecker {
+        boolean checkCondition(JsonArray hits, int expectedSize);
+    }
 }
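To make the comment's "x + y - y'" point concrete, here is a self-contained example with hypothetical timestamps, showing how truncating the window delay to whole minutes shifts the reconstructed data end time:

```java
import java.time.Duration;
import java.time.Instant;

public class WindowDelayRoundingExample {
    public static void main(String[] args) {
        Instant realDataTime = Instant.parse("2019-11-02T00:59:00Z");  // x
        Instant now = Instant.parse("2019-11-02T02:29:30Z");           // test runtime
        Duration realDelay = Duration.between(realDataTime, now);      // y = 90 min 30 s
        long delayMinutes = realDelay.toMinutes();                     // y' = 90 (truncated)

        // Reconstructing the data time from "now" with the rounded delay misses x by 30 seconds:
        Instant reconstructed = now.minus(Duration.ofMinutes(delayMinutes));
        System.out.println("x          = " + realDataTime);
        System.out.println("x + y - y' = " + reconstructed); // 2019-11-02T00:59:30Z != x
    }
}
```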