From d39848f8bdc04dd57fcf116e1b7f37f9c27058f8 Mon Sep 17 00:00:00 2001 From: lai <57818076+wnbts@users.noreply.github.com> Date: Fri, 7 Jan 2022 13:22:34 -0800 Subject: [PATCH] anomaly localization integration step 3 Signed-off-by: lai <57818076+wnbts@users.noreply.github.com> --- ...ut.java => AnomalyLocalizationOutput.java} | 55 +++++++++----- .../anomalylocalization/AnomalyLocalizer.java | 2 +- .../AnomalyLocalizerImpl.java | 76 +++++++++---------- ...va => AnomalyLocalizationOutputTests.java} | 16 ++-- .../AnomalyLocalizerImplTests.java | 40 +++++----- 5 files changed, 101 insertions(+), 88 deletions(-) rename ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/{Output.java => AnomalyLocalizationOutput.java} (83%) rename ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/{OutputTests.java => AnomalyLocalizationOutputTests.java} (73%) diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/Output.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutput.java similarity index 83% rename from ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/Output.java rename to ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutput.java index 2db2cd79a4..1c8b3f59db 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/Output.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutput.java @@ -25,6 +25,7 @@ import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentParser; +import org.opensearch.ml.common.parameter.Output; import lombok.Data; import lombok.EqualsAndHashCode; @@ -39,7 +40,7 @@ */ @Data @NoArgsConstructor -public class Output implements org.opensearch.ml.common.parameter.Output { +public class AnomalyLocalizationOutput implements Output { public static final String FIELD_RESULTS = "results"; public static final String FIELD_NAME = "name"; @@ -47,7 +48,7 @@ public class Output implements org.opensearch.ml.common.parameter.Output { private Map results = new HashMap<>(); // aggregation name to result. - public Output(StreamInput in) throws IOException { + public AnomalyLocalizationOutput(StreamInput in) throws IOException { this.results = in.readMap(StreamInput::readString, Result::new); } @@ -197,26 +198,29 @@ public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params par return builder; } - public static Output parse(XContentParser parser) throws IOException { - Output output = new Output(); + public static AnomalyLocalizationOutput parse(XContentParser parser) throws IOException { + AnomalyLocalizationOutput output = new AnomalyLocalizationOutput(); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { switch (parser.currentName()) { case FIELD_RESULTS: parseResultMapEntry(parser, output); break; + default: + parser.skipChildren(); + break; } } ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.currentToken(), parser); return output; } - private static void parseResultMapEntry(XContentParser parser, Output output) throws IOException { + private static void parseResultMapEntry(XContentParser parser, AnomalyLocalizationOutput output) throws IOException { ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); String key = null; - Output.Result result = new Output.Result(); + AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result(); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { switch (parser.currentName()) { case FIELD_NAME: @@ -226,6 +230,9 @@ private static void parseResultMapEntry(XContentParser parser, Output output) th case FIELD_RESULT: parseResult(parser, result); break; + default: + parser.skipChildren(); + break; } } ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.currentToken(), parser); @@ -234,14 +241,14 @@ private static void parseResultMapEntry(XContentParser parser, Output output) th ensureExpectedToken(XContentParser.Token.END_ARRAY, parser.currentToken(), parser); } - private static void parseResult(XContentParser parser, Output.Result result) throws IOException { + private static void parseResult(XContentParser parser, AnomalyLocalizationOutput.Result result) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { switch (parser.currentName()) { - case Output.Result.FIELD_BUCKETS: + case AnomalyLocalizationOutput.Result.FIELD_BUCKETS: ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - Output.Bucket bucket = new Output.Bucket(); + AnomalyLocalizationOutput.Bucket bucket = new AnomalyLocalizationOutput.Bucket(); parseBucket(parser, bucket); result.getBuckets().add(bucket); } @@ -252,39 +259,42 @@ private static void parseResult(XContentParser parser, Output.Result result) thr ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.currentToken(), parser); } - private static void parseBucket(XContentParser parser, Output.Bucket bucket) throws IOException { + private static void parseBucket(XContentParser parser, AnomalyLocalizationOutput.Bucket bucket) throws IOException { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { switch (parser.currentName()) { - case Output.Bucket.FIELD_START_TIME: + case AnomalyLocalizationOutput.Bucket.FIELD_START_TIME: parser.nextToken(); bucket.setStartTime(parser.longValue()); break; - case Output.Bucket.FIELD_END_TIME: + case AnomalyLocalizationOutput.Bucket.FIELD_END_TIME: parser.nextToken(); bucket.setEndTime(parser.longValue()); break; - case Output.Bucket.FIELD_OVERALL_VALUE: + case AnomalyLocalizationOutput.Bucket.FIELD_OVERALL_VALUE: parser.nextToken(); bucket.setOverallAggValue(parser.doubleValue()); break; - case Output.Bucket.FIELD_ENTITIES: + case AnomalyLocalizationOutput.Bucket.FIELD_ENTITIES: parseEntities(parser, bucket); break; + default: + parser.skipChildren(); + break; } } ensureExpectedToken(XContentParser.Token.END_OBJECT, parser.currentToken(), parser); } - private static void parseEntities(XContentParser parser, Output.Bucket bucket) throws IOException { - List entities = new ArrayList<>(); + private static void parseEntities(XContentParser parser, AnomalyLocalizationOutput.Bucket bucket) throws IOException { + List entities = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - Output.Entity entity = new Output.Entity(); + AnomalyLocalizationOutput.Entity entity = new AnomalyLocalizationOutput.Entity(); while (parser.nextToken() != XContentParser.Token.END_OBJECT) { switch (parser.currentName()) { - case Output.Entity.FIELD_KEY: + case AnomalyLocalizationOutput.Entity.FIELD_KEY: List key = new ArrayList<>(); ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.nextToken(), parser); while (parser.nextToken() != XContentParser.Token.END_ARRAY) { @@ -292,18 +302,21 @@ private static void parseEntities(XContentParser parser, Output.Bucket bucket) t } entity.setKey(key); break; - case Output.Entity.FIELD_CONTRIBUTION_VALUE: + case AnomalyLocalizationOutput.Entity.FIELD_CONTRIBUTION_VALUE: parser.nextToken(); entity.setContributionValue(parser.doubleValue()); break; - case Output.Entity.FIELD_BASE_VALUE: + case AnomalyLocalizationOutput.Entity.FIELD_BASE_VALUE: parser.nextToken(); entity.setBaseValue(parser.doubleValue()); break; - case Output.Entity.FIELD_NEW_VALUE: + case AnomalyLocalizationOutput.Entity.FIELD_NEW_VALUE: parser.nextToken(); entity.setNewValue(parser.doubleValue()); break; + default: + parser.skipChildren(); + break; } } entities.add(entity); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizer.java index 7e88cdf681..32b41fbe6d 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizer.java @@ -25,5 +25,5 @@ public interface AnomalyLocalizer { * @param input Information about aggregation and metadata. * @param listener Listener to localized details or exception. */ - void getLocalizationResults(Input input, ActionListener listener); + void getLocalizationResults(Input input, ActionListener listener); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java index c48c901279..29ffb51558 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImpl.java @@ -93,34 +93,34 @@ public AnomalyLocalizerImpl( */ @Override @SneakyThrows - public void getLocalizationResults(Input input, ActionListener listener) { - Output output = new Output(); + public void getLocalizationResults(Input input, ActionListener listener) { + AnomalyLocalizationOutput output = new AnomalyLocalizationOutput(); input.getAggregations().stream().forEach(agg -> localizeByBuckets(input, agg, output, notifyOnce(listener))); } /** * Bucketizes data by time and get overall aggregates. */ - private void localizeByBuckets(Input input, AggregationBuilder agg, Output output, ActionListener listener) { + private void localizeByBuckets(Input input, AggregationBuilder agg, AnomalyLocalizationOutput output, ActionListener listener) { LocalizationTimeBuckets timeBuckets = getTimeBuckets(input); getOverallAggregates(input, timeBuckets, agg, output, listener); } - private void getOverallAggregates(Input input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, Output output, - ActionListener listener) { + private void getOverallAggregates(Input input, LocalizationTimeBuckets timeBuckets, AggregationBuilder agg, AnomalyLocalizationOutput output, + ActionListener listener) { MultiSearchRequest searchRequest = newSearchRequestForOverallAggregates(input, agg, timeBuckets); client.multiSearch(searchRequest, wrap(r -> onOverallAggregatesResponse(r, input, agg, output, timeBuckets, listener), listener::onFailure)); } - private void onOverallAggregatesResponse(MultiSearchResponse response, Input input, AggregationBuilder agg, Output output, - LocalizationTimeBuckets timeBuckets, ActionListener listener) { - Output.Result result = new Output.Result(); + private void onOverallAggregatesResponse(MultiSearchResponse response, Input input, AggregationBuilder agg, AnomalyLocalizationOutput output, + LocalizationTimeBuckets timeBuckets, ActionListener listener) { + AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result(); List> intervals = timeBuckets.getAllIntervals(); for (int i = 0; i < intervals.size(); i++) { double value = getDoubleValue((SingleValue) response.getResponses()[i].getResponse().getAggregations().get(agg.getName())); - Output.Bucket bucket = new Output.Bucket(); + AnomalyLocalizationOutput.Bucket bucket = new AnomalyLocalizationOutput.Bucket(); bucket.setStartTime(intervals.get(i).getKey()); bucket.setEndTime(intervals.get(i).getValue()); bucket.setOverallAggValue(value); @@ -134,8 +134,8 @@ private void onOverallAggregatesResponse(MultiSearchResponse response, Input inp /** * Identifies buckets of data that need localization and localizes entities in the bucket. */ - private void getLocalizedEntities(Input input, AggregationBuilder agg, Output.Result result, Output output, - ActionListener listener) { + private void getLocalizedEntities(Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput output, + ActionListener listener) { if (setBase(result, input)) { Counter counter = new HybridCounter(); result.getBuckets().stream().filter(e -> e.getBase().isPresent() && e.getBase().get().equals(e)) @@ -144,19 +144,19 @@ private void getLocalizedEntities(Input input, AggregationBuilder agg, Output.Re outputIfResultsAreComplete(output, listener); } - private void outputIfResultsAreComplete(Output output, ActionListener listener) { + private void outputIfResultsAreComplete(AnomalyLocalizationOutput output, ActionListener listener) { if (output.getResults().values().stream().allMatch(this::isResultComplete)) { listener.onResponse(output); } } - private boolean isResultComplete(Output.Result result) { + private boolean isResultComplete(AnomalyLocalizationOutput.Result result) { // When completed is null, the bucket does not localization, base bucket for example. return result.getBuckets().stream().allMatch(e -> e.getCompleted() == null || e.getCompleted().get() == true); } - private void processBaseEntry(Input input, AggregationBuilder agg, Output.Result result, Output.Bucket bucket, Counter counter, - Optional> afterKey, Output output, ActionListener listener) { + private void processBaseEntry(Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Counter counter, + Optional> afterKey, AnomalyLocalizationOutput output, ActionListener listener) { SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey); client.search(request, wrap(r -> onBaseEntryResponse(r, input, agg, result, bucket, counter, output, listener), listener::onFailure)); @@ -165,8 +165,8 @@ private void processBaseEntry(Input input, AggregationBuilder agg, Output.Result /** * Keeps info from entities in the base bucket to compare entities from new buckets against. */ - private void onBaseEntryResponse(SearchResponse response, Input input, AggregationBuilder agg, Output.Result result, - Output.Bucket bucket, Counter counter, Output output, ActionListener listener) { + private void onBaseEntryResponse(SearchResponse response, Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, Counter counter, AnomalyLocalizationOutput output, ActionListener listener) { Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName())); respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList()).stream().forEach(b -> { @@ -179,12 +179,12 @@ private void onBaseEntryResponse(SearchResponse response, Input input, Aggregati bucket.setCounter(Optional.of(counter)); result.getBuckets().stream().filter(e -> e.getCompleted() != null && e.getCompleted().get() == false) .forEach(e -> { - PriorityQueue queue; + PriorityQueue queue; if (e.getOverallAggValue() > 0) { - queue = new PriorityQueue(input.getNumOutputs(), + queue = new PriorityQueue(input.getNumOutputs(), (a, b) -> (int) Math.signum(a.getContributionValue() - b.getContributionValue())); } else { - queue = new PriorityQueue(input.getNumOutputs(), + queue = new PriorityQueue(input.getNumOutputs(), (a, b) -> (int) Math.signum(b.getContributionValue() - a.getContributionValue())); } ; @@ -193,8 +193,8 @@ private void onBaseEntryResponse(SearchResponse response, Input input, Aggregati } } - private void processNewEntry(Input input, AggregationBuilder agg, Output.Result result, Output.Bucket bucket, Optional> afterKey, PriorityQueue queue, Output output, ActionListener listener) { + private void processNewEntry(Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, AnomalyLocalizationOutput.Bucket bucket, Optional> afterKey, PriorityQueue queue, AnomalyLocalizationOutput output, ActionListener listener) { SearchRequest request = newSearchRequestForEntry(input, agg, bucket, afterKey); client.search(request, wrap(r -> onNewEntryResponse(r, input, agg, result, bucket, queue, output, listener), listener::onFailure)); } @@ -202,14 +202,14 @@ private void processNewEntry(Input input, AggregationBuilder agg, Output.Result /** * Chooses entities from the new bucket that contribute the most to the overall change. */ - private void onNewEntryResponse(SearchResponse response, Input input, AggregationBuilder agg, Output.Result result, - Output.Bucket outputBucket, PriorityQueue queue, Output output, - ActionListener listener) { + private void onNewEntryResponse(SearchResponse response, Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket outputBucket, PriorityQueue queue, AnomalyLocalizationOutput output, + ActionListener listener) { Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (CompositeAggregation) aggs.get(agg.getName())); for (CompositeAggregation.Bucket bucket : respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList())) { List key = toStringKey(bucket.getKey(), input); - Output.Entity entity = new Output.Entity(); + AnomalyLocalizationOutput.Entity entity = new AnomalyLocalizationOutput.Entity(); entity.setKey(key); entity.setNewValue(getDoubleValue((SingleValue) bucket.getAggregations().get(agg.getName()))); entity.setBaseValue(outputBucket.getBase().get().getCounter().get().estimate(key)); @@ -225,7 +225,7 @@ private void onNewEntryResponse(SearchResponse response, Input input, Aggregatio if (afterKey.isPresent()) { processNewEntry(input, agg, result, outputBucket, afterKey, queue, output, listener); } else { - List> keys = queue.stream().map(Output.Entity::getKey).collect(Collectors.toList()); + List> keys = queue.stream().map(AnomalyLocalizationOutput.Entity::getKey).collect(Collectors.toList()); SearchRequest request = newSearchRequestForEntityKeys(input, agg, outputBucket, keys); client.search(request, wrap(r -> onEntityKeysResponse(r, input, agg, result, outputBucket, queue, output, listener), listener::onFailure)); @@ -235,16 +235,16 @@ private void onNewEntryResponse(SearchResponse response, Input input, Aggregatio /** * Updates to date entity contribution values in final output. */ - private void onEntityKeysResponse(SearchResponse response, Input input, AggregationBuilder agg, Output.Result result, - Output.Bucket bucket, PriorityQueue queue, Output output, - ActionListener listener) { - List entities = new ArrayList(queue); + private void onEntityKeysResponse(SearchResponse response, Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Result result, + AnomalyLocalizationOutput.Bucket bucket, PriorityQueue queue, AnomalyLocalizationOutput output, + ActionListener listener) { + List entities = new ArrayList(queue); Optional respAgg = Optional.ofNullable(response.getAggregations()).map(aggs -> (Filters) aggs.get(agg.getName())); for (Filters.Bucket respBucket : respAgg.map(a -> a.getBuckets()).orElse(Collections.emptyList())) { int entityIndex = Integer.parseInt(respBucket.getKeyAsString()); double aggValue = getDoubleValue((SingleValue) respBucket.getAggregations().get(agg.getName())); - Output.Entity entity = entities.get(entityIndex); + AnomalyLocalizationOutput.Entity entity = entities.get(entityIndex); entity.setBaseValue(aggValue); entity.setContributionValue(entity.getNewValue() - entity.getBaseValue()); } @@ -257,7 +257,7 @@ private void onEntityKeysResponse(SearchResponse response, Input input, Aggregat outputIfResultsAreComplete(output, listener); } - private SearchRequest newSearchRequestForEntityKeys(Input input, AggregationBuilder agg, Output.Bucket bucket, + private SearchRequest newSearchRequestForEntityKeys(Input input, AggregationBuilder agg, AnomalyLocalizationOutput.Bucket bucket, List> keys) { RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName()) .from(bucket.getBase().get().getStartTime(), true) @@ -283,7 +283,7 @@ private List toStringKey(Map key, Input input) { return input.getAttributeFieldNames().stream().map(name -> key.get(name).toString()).collect(Collectors.toList()); } - private SearchRequest newSearchRequestForEntry(Input input, AggregationBuilder agg, Output.Bucket bucket, Optional> afterKey) { RangeQueryBuilder timeRangeFilter = new RangeQueryBuilder(input.getTimeFieldName()) .from(bucket.getStartTime(), true) @@ -301,14 +301,14 @@ private SearchRequest newSearchRequestForEntry(Input input, AggregationBuilder a return searchRequest; } - private boolean setBase(Output.Result result, Input input) { + private boolean setBase(AnomalyLocalizationOutput.Result result, Input input) { boolean newEntry = false; - List entries = result.getBuckets(); + List entries = result.getBuckets(); int baseEntryIndex = 0; - Output.Bucket baseEntry = entries.get(baseEntryIndex); + AnomalyLocalizationOutput.Bucket baseEntry = entries.get(baseEntryIndex); baseEntry.setBase(Optional.of(baseEntry)); for (int i = 1; i < entries.size(); i++) { - Output.Bucket currentEntry = entries.get(i); + AnomalyLocalizationOutput.Bucket currentEntry = entries.get(i); if (input.getAnomalyStartTime().isPresent()) { if (currentEntry.getEndTime() > input.getAnomalyStartTime().get()) { currentEntry.setBase(Optional.of(baseEntry)); diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/OutputTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutputTests.java similarity index 73% rename from ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/OutputTests.java rename to ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutputTests.java index cb832cf92c..94a209dc16 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/OutputTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizationOutputTests.java @@ -26,19 +26,19 @@ import static org.junit.Assert.assertEquals; -public class OutputTests { +public class AnomalyLocalizationOutputTests { - private Output output; + private AnomalyLocalizationOutput output; @Before public void setup() { - Output.Entity entity = new Output.Entity(); + AnomalyLocalizationOutput.Entity entity = new AnomalyLocalizationOutput.Entity(); entity.setKey(Arrays.asList("key1")); - Output.Bucket bucket = new Output.Bucket(); + AnomalyLocalizationOutput.Bucket bucket = new AnomalyLocalizationOutput.Bucket(); bucket.setEntities(Arrays.asList(entity)); - Output.Result result = new Output.Result(); + AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result(); result.setBuckets(Arrays.asList(bucket)); - output = new Output(); + output = new AnomalyLocalizationOutput(); output.getResults().put("agg", result); } @@ -47,7 +47,7 @@ public void testWriteable() throws Exception { BytesStreamOutput out = new BytesStreamOutput(); output.writeTo(out); - Output newOutput = new Output(out.bytes().streamInput()); + AnomalyLocalizationOutput newOutput = new AnomalyLocalizationOutput(out.bytes().streamInput()); assertEquals(output, newOutput); } @@ -60,7 +60,7 @@ public void testXContent() throws Exception { String json = Strings.toString(builder); XContentParser parser = XContentType.JSON.xContent().createParser(NamedXContentRegistry.EMPTY, null, json); - Output newOutput = Output.parse(parser); + AnomalyLocalizationOutput newOutput = AnomalyLocalizationOutput.parse(parser); assertEquals(output, newOutput); } diff --git a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java index 8685367b09..3ad42676f8 100644 --- a/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java +++ b/ml-algorithms/src/test/java/org/opensearch/ml/engine/algorithms/anomalylocalization/AnomalyLocalizerImplTests.java @@ -52,7 +52,7 @@ public class AnomalyLocalizerImplTests { private Client client; @Mock - private ActionListener outputListener; + private ActionListener outputListener; private Settings settings; @@ -69,16 +69,16 @@ public class AnomalyLocalizerImplTests { private int numOutput = 1; private Input input; - private Output expectedOutput; + private AnomalyLocalizationOutput expectedOutput; @Mock private SingleValue valueOne; @Mock private SingleValue valueTwo; @Mock private SingleValue valueThree; - private Output.Bucket expectedBucketOne; - private Output.Bucket expectedBucketTwo; - private Output.Entity entity; + private AnomalyLocalizationOutput.Bucket expectedBucketOne; + private AnomalyLocalizationOutput.Bucket expectedBucketTwo; + private AnomalyLocalizationOutput.Entity entity; @Before @SuppressWarnings({"unchecked", "rawtypes"}) @@ -198,17 +198,17 @@ public Object answer(InvocationOnMock invocation) { ). when(client).search(any(), any()); - expectedOutput = new Output(); - Output.Result result = new Output.Result(); - expectedBucketOne = new Output.Bucket(); + expectedOutput = new AnomalyLocalizationOutput(); + AnomalyLocalizationOutput.Result result = new AnomalyLocalizationOutput.Result(); + expectedBucketOne = new AnomalyLocalizationOutput.Bucket(); expectedBucketOne.setStartTime(0); expectedBucketOne.setEndTime(1); expectedBucketOne.setOverallAggValue(0); - expectedBucketTwo = new Output.Bucket(); + expectedBucketTwo = new AnomalyLocalizationOutput.Bucket(); expectedBucketTwo.setStartTime(1); expectedBucketTwo.setEndTime(2); expectedBucketTwo.setOverallAggValue(10); - entity = new Output.Entity(); + entity = new AnomalyLocalizationOutput.Entity(); entity.setKey(Arrays.asList(bucketOneKeyValue)); entity.setNewValue(valueTwo.value()); entity.setBaseValue(valueOne.value()); @@ -223,9 +223,9 @@ public Object answer(InvocationOnMock invocation) { public void testGetLocalizedResultsGivenNoAnomaly() { anomalyLocalizer.getLocalizationResults(input, outputListener); - ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(Output.class); + ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(AnomalyLocalizationOutput.class); verify(outputListener).onResponse(outputCaptor.capture()); - Output actualOutput = outputCaptor.getValue(); + AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); assertEquals(expectedOutput, actualOutput); } @@ -237,9 +237,9 @@ public void testGetLocalizedResultsGivenAnomaly() { anomalyLocalizer.getLocalizationResults(input, outputListener); - ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(Output.class); + ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(AnomalyLocalizationOutput.class); verify(outputListener).onResponse(outputCaptor.capture()); - Output actualOutput = outputCaptor.getValue(); + AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); assertEquals(expectedOutput, actualOutput); } @@ -278,9 +278,9 @@ public void testGetLocalizedResultsOverallDecrease() { anomalyLocalizer.getLocalizationResults(input, outputListener); - ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(Output.class); + ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(AnomalyLocalizationOutput.class); verify(outputListener).onResponse(outputCaptor.capture()); - Output actualOutput = outputCaptor.getValue(); + AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); expectedBucketOne.setOverallAggValue(valueOne.value()); expectedBucketTwo.setOverallAggValue(valueTwo.value()); entity.setNewValue(valueTwo.value()); @@ -296,9 +296,9 @@ public void testGetLocalizedResultsOverallUnchange() { anomalyLocalizer.getLocalizationResults(input, outputListener); - ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(Output.class); + ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(AnomalyLocalizationOutput.class); verify(outputListener).onResponse(outputCaptor.capture()); - Output actualOutput = outputCaptor.getValue(); + AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); expectedBucketOne.setOverallAggValue(valueOne.value()); expectedBucketOne.setEntities(null); expectedBucketTwo.setOverallAggValue(valueTwo.value()); @@ -313,9 +313,9 @@ public void testGetLocalizedResultsFilterEntity() { anomalyLocalizer.getLocalizationResults(input, outputListener); - ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(Output.class); + ArgumentCaptor outputCaptor = ArgumentCaptor.forClass(AnomalyLocalizationOutput.class); verify(outputListener).onResponse(outputCaptor.capture()); - Output actualOutput = outputCaptor.getValue(); + AnomalyLocalizationOutput actualOutput = outputCaptor.getValue(); assertEquals(expectedOutput, actualOutput); } }