diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java index 5b7d2cf288d8b..df6727317e208 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/InternalMatrixStats.java @@ -139,6 +139,10 @@ public double getCorrelation(String fieldX, String fieldY) { return results.getCorrelation(fieldX, fieldY); } + RunningStats getStats() { + return stats; + } + MatrixStatsResults getResults() { return results; } diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java index 5c193828c5536..578116d7b5eb2 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregator.java @@ -41,14 +41,14 @@ /** * Metric Aggregation for computing the pearson product correlation coefficient between multiple fields **/ -public class MatrixStatsAggregator extends MetricsAggregator { +final class MatrixStatsAggregator extends MetricsAggregator { /** Multiple ValuesSource with field names */ - final NumericMultiValuesSource valuesSources; + private final NumericMultiValuesSource valuesSources; /** array of descriptive stats, per shard, needed to compute the correlation */ ObjectArray stats; - public MatrixStatsAggregator(String name, Map valuesSources, SearchContext context, + MatrixStatsAggregator(String name, Map valuesSources, SearchContext context, Aggregator parent, MultiValueMode multiValueMode, List pipelineAggregators, Map metaData) throws IOException { super(name, context, parent, pipelineAggregators, metaData); diff --git a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java index c991e2c5c8655..2c3ac82a0c1a8 100644 --- a/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java +++ b/modules/aggs-matrix-stats/src/main/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorFactory.java @@ -32,12 +32,12 @@ import java.util.List; import java.util.Map; -public class MatrixStatsAggregatorFactory +final class MatrixStatsAggregatorFactory extends MultiValuesSourceAggregatorFactory { private final MultiValueMode multiValueMode; - public MatrixStatsAggregatorFactory(String name, + MatrixStatsAggregatorFactory(String name, Map> configs, MultiValueMode multiValueMode, SearchContext context, AggregatorFactory parent, AggregatorFactories.Builder subFactoriesBuilder, Map metaData) throws IOException { diff --git a/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorTests.java b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorTests.java new file mode 100644 index 0000000000000..aa778e6f704f9 --- /dev/null +++ b/modules/aggs-matrix-stats/src/test/java/org/elasticsearch/search/aggregations/matrix/stats/MatrixStatsAggregatorTests.java @@ -0,0 +1,96 @@ +/* + * Licensed to Elasticsearch under one or more contributor + * license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright + * ownership. Elasticsearch licenses this file to you under + * the Apache License, Version 2.0 (the "License"); you may + * not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.elasticsearch.search.aggregations.matrix.stats; + +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field; +import org.apache.lucene.document.SortedNumericDocValuesField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.RandomIndexWriter; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.MatchAllDocsQuery; +import org.apache.lucene.store.Directory; +import org.apache.lucene.util.NumericUtils; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.NumberFieldMapper; +import org.elasticsearch.search.aggregations.AggregatorTestCase; + +import java.util.Arrays; +import java.util.Collections; + +public class MatrixStatsAggregatorTests extends AggregatorTestCase { + + public void testNoData() throws Exception { + MappedFieldType ft = + new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE); + ft.setName("field"); + + try (Directory directory = newDirectory(); + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + if (randomBoolean()) { + indexWriter.addDocument(Collections.singleton(new StringField("another_field", "value", Field.Store.NO))); + } + try (IndexReader reader = indexWriter.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg") + .fields(Collections.singletonList("field")); + InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ft); + assertNull(stats.getStats()); + } + } + } + + public void testTwoFields() throws Exception { + String fieldA = "a"; + MappedFieldType ftA = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE); + ftA.setName(fieldA); + String fieldB = "b"; + MappedFieldType ftB = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE); + ftB.setName(fieldB); + + try (Directory directory = newDirectory(); + RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory)) { + + int numDocs = scaledRandomIntBetween(8192, 16384); + Double[] fieldAValues = new Double[numDocs]; + Double[] fieldBValues = new Double[numDocs]; + for (int docId = 0; docId < numDocs; docId++) { + Document document = new Document(); + fieldAValues[docId] = randomDouble(); + document.add(new SortedNumericDocValuesField(fieldA, NumericUtils.doubleToSortableLong(fieldAValues[docId]))); + + fieldBValues[docId] = randomDouble(); + document.add(new SortedNumericDocValuesField(fieldB, NumericUtils.doubleToSortableLong(fieldBValues[docId]))); + indexWriter.addDocument(document); + } + + MultiPassStats multiPassStats = new MultiPassStats(fieldA, fieldB); + multiPassStats.computeStats(Arrays.asList(fieldAValues), Arrays.asList(fieldBValues)); + try (IndexReader reader = indexWriter.getReader()) { + IndexSearcher searcher = new IndexSearcher(reader); + MatrixStatsAggregationBuilder aggBuilder = new MatrixStatsAggregationBuilder("my_agg") + .fields(Arrays.asList(fieldA, fieldB)); + InternalMatrixStats stats = search(searcher, new MatchAllDocsQuery(), aggBuilder, ftA, ftB); + multiPassStats.assertNearlyEqual(new MatrixStatsResults(stats.getStats())); + } + } + } + +} diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index 897e1513e410a..32513a51c132e 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -110,6 +110,9 @@ protected AggregatorFactory createAggregatorFactory(AggregationBuilder aggreg QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService); when(searchContext.getQueryShardContext()).thenReturn(queryShardContext); + for (MappedFieldType fieldType : fieldTypes) { + when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType); + } return aggregationBuilder.build(searchContext, null); }