Skip to content

Commit

Permalink
Added unit tests for MatrixStatsAggregator
Browse files Browse the repository at this point in the history
  • Loading branch information
martijnvg committed May 23, 2017
1 parent b2ccb6b commit 3409373
Show file tree
Hide file tree
Showing 5 changed files with 108 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ public double getCorrelation(String fieldX, String fieldY) {
return results.getCorrelation(fieldX, fieldY);
}

/**
 * Package-private accessor for the {@link RunningStats} held in the {@code stats} field.
 * The reference is handed back as-is; no copy is made.
 *
 * NOTE(review): the value can apparently be {@code null} when nothing was collected
 * (the unit test added in this commit asserts that for an empty index) - confirm before dereferencing.
 *
 * @return the {@code stats} field
 */
RunningStats getStats() {
return stats;
}

/**
 * Package-private accessor for the {@link MatrixStatsResults} held in the {@code results} field.
 * The reference is handed back as-is; no copy is made.
 *
 * @return the {@code results} field
 */
MatrixStatsResults getResults() {
return results;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@
/**
* Metric aggregation that computes the Pearson product-moment correlation coefficient between multiple fields
**/
public class MatrixStatsAggregator extends MetricsAggregator {
final class MatrixStatsAggregator extends MetricsAggregator {
/** Multiple ValuesSource with field names */
final NumericMultiValuesSource valuesSources;
private final NumericMultiValuesSource valuesSources;

/** array of descriptive stats, per shard, needed to compute the correlation */
ObjectArray<RunningStats> stats;

public MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
MatrixStatsAggregator(String name, Map<String, ValuesSource.Numeric> valuesSources, SearchContext context,
Aggregator parent, MultiValueMode multiValueMode, List<PipelineAggregator> pipelineAggregators,
Map<String,Object> metaData) throws IOException {
super(name, context, parent, pipelineAggregators, metaData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@
import java.util.List;
import java.util.Map;

public class MatrixStatsAggregatorFactory
final class MatrixStatsAggregatorFactory
extends MultiValuesSourceAggregatorFactory<ValuesSource.Numeric, MatrixStatsAggregatorFactory> {

private final MultiValueMode multiValueMode;

public MatrixStatsAggregatorFactory(String name,
MatrixStatsAggregatorFactory(String name,
Map<String, ValuesSourceConfig<ValuesSource.Numeric>> configs, MultiValueMode multiValueMode,
SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.elasticsearch.search.aggregations.matrix.stats;

import org.apache.lucene.document.Document;
import org.apache.lucene.document.Field;
import org.apache.lucene.document.SortedNumericDocValuesField;
import org.apache.lucene.document.StringField;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.store.Directory;
import org.apache.lucene.util.NumericUtils;
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.NumberFieldMapper;
import org.elasticsearch.search.aggregations.AggregatorTestCase;

import java.util.Arrays;
import java.util.Collections;

/**
 * Unit tests that run a {@link MatrixStatsAggregationBuilder} against a throw-away Lucene index
 * and inspect the {@link InternalMatrixStats} that comes back.
 */
public class MatrixStatsAggregatorTests extends AggregatorTestCase {

    /** Builds a {@code DOUBLE} number field type and assigns it the given name. */
    private static MappedFieldType namedDoubleFieldType(String name) {
        MappedFieldType fieldType = new NumberFieldMapper.NumberFieldType(NumberFieldMapper.NumberType.DOUBLE);
        fieldType.setName(name);
        return fieldType;
    }

    /**
     * Aggregates on "field" over an index that is either empty or holds a single document
     * carrying only "another_field"; the aggregation is expected to expose no running stats.
     */
    public void testNoData() throws Exception {
        final MappedFieldType fieldType = namedDoubleFieldType("field");

        try (Directory dir = newDirectory();
             RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {
            // Randomly index one document that does not contain the aggregated field.
            if (randomBoolean()) {
                StringField unrelated = new StringField("another_field", "value", Field.Store.NO);
                writer.addDocument(Collections.singleton(unrelated));
            }
            try (IndexReader indexReader = writer.getReader()) {
                IndexSearcher indexSearcher = new IndexSearcher(indexReader);
                MatrixStatsAggregationBuilder builder = new MatrixStatsAggregationBuilder("my_agg")
                    .fields(Collections.singletonList("field"));
                InternalMatrixStats result = search(indexSearcher, new MatchAllDocsQuery(), builder, fieldType);
                assertNull(result.getStats());
            }
        }
    }

    /**
     * Indexes a random number of documents with random double values in fields "a" and "b",
     * then checks the aggregated stats against a {@link MultiPassStats} computed from the same values.
     */
    public void testTwoFields() throws Exception {
        final String nameA = "a";
        final MappedFieldType typeA = namedDoubleFieldType(nameA);
        final String nameB = "b";
        final MappedFieldType typeB = namedDoubleFieldType(nameB);

        try (Directory dir = newDirectory();
             RandomIndexWriter writer = new RandomIndexWriter(random(), dir)) {

            final int docCount = scaledRandomIntBetween(8192, 16384);
            final Double[] valuesA = new Double[docCount];
            final Double[] valuesB = new Double[docCount];
            for (int i = 0; i < docCount; i++) {
                Document doc = new Document();

                // Draw order (A first, then B) is kept so seeded runs stay reproducible.
                valuesA[i] = randomDouble();
                long sortableA = NumericUtils.doubleToSortableLong(valuesA[i]);
                doc.add(new SortedNumericDocValuesField(nameA, sortableA));

                valuesB[i] = randomDouble();
                long sortableB = NumericUtils.doubleToSortableLong(valuesB[i]);
                doc.add(new SortedNumericDocValuesField(nameB, sortableB));

                writer.addDocument(doc);
            }

            // Reference statistics computed directly from the generated values.
            MultiPassStats expected = new MultiPassStats(nameA, nameB);
            expected.computeStats(Arrays.asList(valuesA), Arrays.asList(valuesB));

            try (IndexReader indexReader = writer.getReader()) {
                IndexSearcher indexSearcher = new IndexSearcher(indexReader);
                MatrixStatsAggregationBuilder builder = new MatrixStatsAggregationBuilder("my_agg")
                    .fields(Arrays.asList(nameA, nameB));
                InternalMatrixStats result = search(indexSearcher, new MatchAllDocsQuery(), builder, typeA, typeB);
                MatrixStatsResults actual = new MatrixStatsResults(result.getStats());
                expected.assertNearlyEqual(actual);
            }
        }
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ protected AggregatorFactory<?> createAggregatorFactory(AggregationBuilder aggreg

QueryShardContext queryShardContext = queryShardContextMock(mapperService, fieldTypes, circuitBreakerService);
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);
for (MappedFieldType fieldType : fieldTypes) {
when(searchContext.smartNameFieldType(fieldType.name())).thenReturn(fieldType);
}

return aggregationBuilder.build(searchContext, null);
}
Expand Down

0 comments on commit 3409373

Please sign in to comment.