Skip to content

Commit

Permalink
[TEST] Added unit tests for diversified sampler aggregator.
Browse files Browse the repository at this point in the history
  • Loading branch information
martijnvg committed Mar 12, 2017
1 parent 9d4aff5 commit b01070a
Show file tree
Hide file tree
Showing 14 changed files with 446 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
* specific language governing permissions and limitations
* under the License.
*/

package org.elasticsearch.search.aggregations.bucket;
package org.elasticsearch.search.aggregations.bucket.sampler;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.DocIdSetIterator;
Expand All @@ -34,11 +33,11 @@
import org.elasticsearch.common.util.ObjectArray;
import org.elasticsearch.search.aggregations.BucketCollector;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

/**
Expand All @@ -48,13 +47,11 @@
* {@link BestDocsDeferringCollector#createTopDocsCollector(int)} is designed to
* be overridden and allows subclasses to choose a custom collector
* implementation for determining the top N matches.
*
*/

public class BestDocsDeferringCollector extends DeferringBucketCollector implements Releasable {
final List<PerSegmentCollects> entries = new ArrayList<>();
BucketCollector deferred;
ObjectArray<PerParentBucketSamples> perBucketSamples;
private final List<PerSegmentCollects> entries = new ArrayList<>();
private BucketCollector deferred;
private ObjectArray<PerParentBucketSamples> perBucketSamples;
private int shardSize;
private PerSegmentCollects perSegCollector;
private final BigArrays bigArrays;
Expand All @@ -65,14 +62,12 @@ public class BestDocsDeferringCollector extends DeferringBucketCollector impleme
* @param shardSize
* The number of top-scoring docs to collect for each bucket
*/
public BestDocsDeferringCollector(int shardSize, BigArrays bigArrays) {
BestDocsDeferringCollector(int shardSize, BigArrays bigArrays) {
this.shardSize = shardSize;
this.bigArrays = bigArrays;
perBucketSamples = bigArrays.newObjectArray(1);
}



@Override
public boolean needsScores() {
return true;
Expand Down Expand Up @@ -126,7 +121,6 @@ public void prepareSelectedBuckets(long... selectedBuckets) throws IOException {
}

private void runDeferredAggs() throws IOException {

List<ScoreDoc> allDocs = new ArrayList<>(shardSize);
for (int i = 0; i < perBucketSamples.size(); i++) {
PerParentBucketSamples perBucketSample = perBucketSamples.get(i);
Expand All @@ -138,15 +132,12 @@ private void runDeferredAggs() throws IOException {

// Sort the top matches by docID for the benefit of deferred collector
ScoreDoc[] docsArr = allDocs.toArray(new ScoreDoc[allDocs.size()]);
Arrays.sort(docsArr, new Comparator<ScoreDoc>() {
@Override
public int compare(ScoreDoc o1, ScoreDoc o2) {
if(o1.doc == o2.doc){
return o1.shardIndex - o2.shardIndex;
}
return o1.doc - o2.doc;
}
});
Arrays.sort(docsArr, (o1, o2) -> {
if(o1.doc == o2.doc){
return o1.shardIndex - o2.shardIndex;
}
return o1.doc - o2.doc;
});
try {
for (PerSegmentCollects perSegDocs : entries) {
perSegDocs.replayRelatedMatches(docsArr);
Expand Down Expand Up @@ -295,7 +286,6 @@ public void collect(int docId, long parentBucket) throws IOException {
}
}


public int getDocCount(long parentBucket) {
PerParentBucketSamples sampler = perBucketSamples.get((int) parentBucket);
if (sampler == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ public class DiversifiedAggregatorFactory extends ValuesSourceAggregatorFactory<
private final int maxDocsPerValue;
private final String executionHint;

public DiversifiedAggregatorFactory(String name, ValuesSourceConfig<ValuesSource> config, int shardSize, int maxDocsPerValue,
DiversifiedAggregatorFactory(String name, ValuesSourceConfig<ValuesSource> config, int shardSize, int maxDocsPerValue,
String executionHint, SearchContext context, AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder,
Map<String, Object> metaData) throws IOException {
super(name, config, context, parent, subFactoriesBuilder, metaData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.BestDocsDeferringCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource;
Expand All @@ -48,7 +47,7 @@ public class DiversifiedBytesHashSamplerAggregator extends SamplerAggregator {
private ValuesSource valuesSource;
private int maxDocsPerValue;

public DiversifiedBytesHashSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
DiversifiedBytesHashSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData,
ValuesSource valuesSource,
int maxDocsPerValue) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import org.elasticsearch.index.fielddata.SortedBinaryDocValues;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.BestDocsDeferringCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource;
Expand All @@ -47,7 +46,7 @@ public class DiversifiedMapSamplerAggregator extends SamplerAggregator {
private int maxDocsPerValue;
private BytesRefHash bucketOrds;

public DiversifiedMapSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
DiversifiedMapSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData,
ValuesSource valuesSource, int maxDocsPerValue) throws IOException {
super(name, shardSize, factories, context, parent, pipelineAggregators, metaData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.BestDocsDeferringCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource;
Expand All @@ -43,7 +42,7 @@ public class DiversifiedNumericSamplerAggregator extends SamplerAggregator {
private ValuesSource.Numeric valuesSource;
private int maxDocsPerValue;

public DiversifiedNumericSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
DiversifiedNumericSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData,
ValuesSource.Numeric valuesSource, int maxDocsPerValue) throws IOException {
super(name, shardSize, factories, context, parent, pipelineAggregators, metaData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import org.apache.lucene.search.TopDocsCollector;
import org.elasticsearch.search.aggregations.Aggregator;
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.bucket.BestDocsDeferringCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
import org.elasticsearch.search.aggregations.support.ValuesSource;
Expand All @@ -44,7 +43,7 @@ public class DiversifiedOrdinalsSamplerAggregator extends SamplerAggregator {
private ValuesSource.Bytes.WithOrdinals.FieldData valuesSource;
private int maxDocsPerValue;

public DiversifiedOrdinalsSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
DiversifiedOrdinalsSamplerAggregator(String name, int shardSize, AggregatorFactories factories,
SearchContext context, Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData,
ValuesSource.Bytes.WithOrdinals.FieldData valuesSource, int maxDocsPerValue) throws IOException {
super(name, shardSize, factories, context, parent, pipelineAggregators, metaData);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.search.aggregations.AggregatorFactories;
import org.elasticsearch.search.aggregations.InternalAggregation;
import org.elasticsearch.search.aggregations.LeafBucketCollector;
import org.elasticsearch.search.aggregations.bucket.BestDocsDeferringCollector;
import org.elasticsearch.search.aggregations.bucket.DeferringBucketCollector;
import org.elasticsearch.search.aggregations.bucket.SingleBucketAggregator;
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
Expand All @@ -53,7 +52,6 @@ public class SamplerAggregator extends SingleBucketAggregator {
public static final ParseField MAX_DOCS_PER_VALUE_FIELD = new ParseField("max_docs_per_value");
public static final ParseField EXECUTION_HINT_FIELD = new ParseField("execution_hint");


public enum ExecutionMode {

MAP(new ParseField("map")) {
Expand Down Expand Up @@ -141,7 +139,7 @@ public String toString() {
protected final int shardSize;
protected BestDocsDeferringCollector bdd;

public SamplerAggregator(String name, int shardSize, AggregatorFactories factories, SearchContext context,
SamplerAggregator(String name, int shardSize, AggregatorFactories factories, SearchContext context,
Aggregator parent, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) throws IOException {
super(name, factories, context, parent, pipelineAggregators, metaData);
this.shardSize = shardSize;
Expand All @@ -156,10 +154,8 @@ public boolean needsScores() {
public DeferringBucketCollector getDeferringCollector() {
bdd = new BestDocsDeferringCollector(shardSize, context.bigArrays());
return bdd;

}


@Override
protected boolean shouldDefer(Aggregator aggregator) {
return true;
Expand Down Expand Up @@ -193,4 +189,3 @@ protected void doClose() {
}

}

Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public class SamplerAggregatorFactory extends AggregatorFactory<SamplerAggregato

private final int shardSize;

public SamplerAggregatorFactory(String name, int shardSize, SearchContext context, AggregatorFactory<?> parent,
SamplerAggregatorFactory(String name, int shardSize, SearchContext context, AggregatorFactory<?> parent,
AggregatorFactories.Builder subFactories, Map<String, Object> metaData) throws IOException {
super(name, context, parent, subFactories, metaData);
this.shardSize = shardSize;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
public class UnmappedSampler extends InternalSampler {
public static final String NAME = "unmapped_sampler";

public UnmappedSampler(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
UnmappedSampler(String name, List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
super(name, 0, InternalAggregations.EMPTY, pipelineAggregators, metaData);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.lookup.SearchLookup;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.IndexSettingsModule;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -75,15 +74,31 @@ public abstract class AggregatorTestCase extends ESTestCase {
protected <A extends Aggregator, B extends AggregationBuilder> A createAggregator(B aggregationBuilder,
IndexSearcher indexSearcher,
MappedFieldType... fieldTypes) throws IOException {
IndexSettings indexSettings = new IndexSettings(
IndexMetaData.builder("_index").settings(Settings.builder().put(IndexMetaData.SETTING_VERSION_CREATED, Version.CURRENT))
.numberOfShards(1)
.numberOfReplicas(0)
.creationDate(System.currentTimeMillis())
.build(),
Settings.EMPTY
);
IndexSettings indexSettings = createIndexSettings();
SearchContext searchContext = createSearchContext(indexSearcher, indexSettings);
CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(Settings.EMPTY, circuitBreakerService));
// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
MapperService mapperService = mapperServiceMock();
when(mapperService.hasNested()).thenReturn(false);
when(searchContext.mapperService()).thenReturn(mapperService);
IndexFieldDataService ifds = new IndexFieldDataService(indexSettings,
new IndicesFieldDataCache(Settings.EMPTY, new IndexFieldDataCache.Listener() {
}), circuitBreakerService, mapperService);
when(searchContext.fieldData()).thenReturn(ifds);

SearchLookup searchLookup = new SearchLookup(mapperService, ifds, new String[]{"type"});
when(searchContext.lookup()).thenReturn(searchLookup);

QueryShardContext queryShardContext = queryShardContextMock(fieldTypes, indexSettings, circuitBreakerService);
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);

@SuppressWarnings("unchecked")
A aggregator = (A) aggregationBuilder.build(searchContext, null).create(null, true);
return aggregator;
}

protected SearchContext createSearchContext(IndexSearcher indexSearcher, IndexSettings indexSettings) {
Engine.Searcher searcher = new Engine.Searcher("aggregator_test", indexSearcher);
QueryCache queryCache = new DisabledQueryCache(indexSettings);
QueryCachingPolicy queryCachingPolicy = new QueryCachingPolicy() {
Expand All @@ -99,38 +114,29 @@ public boolean shouldCache(Query query) throws IOException {
};
ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(searcher, queryCache, queryCachingPolicy);

CircuitBreakerService circuitBreakerService = new NoneCircuitBreakerService();
SearchContext searchContext = mock(SearchContext.class);
when(searchContext.numberOfShards()).thenReturn(1);
when(searchContext.searcher()).thenReturn(contextIndexSearcher);
when(searchContext.bigArrays()).thenReturn(new MockBigArrays(Settings.EMPTY, circuitBreakerService));
when(searchContext.fetchPhase())
.thenReturn(new FetchPhase(Arrays.asList(new FetchSourceSubPhase(), new DocValueFieldsFetchSubPhase())));
.thenReturn(new FetchPhase(Arrays.asList(new FetchSourceSubPhase(), new DocValueFieldsFetchSubPhase())));
doAnswer(invocation -> {
/* Store the releasables so we can release them at the end of the test case. This is important because aggregations don't
* close their sub-aggregations. This is fairly similar to what the production code does. */
releasables.add((Releasable) invocation.getArguments()[0]);
return null;
}).when(searchContext).addReleasable(anyObject(), anyObject());
return searchContext;
}

// TODO: now just needed for top_hits, this will need to be revised for other agg unit tests:
MapperService mapperService = mapperServiceMock();
when(mapperService.hasNested()).thenReturn(false);
when(searchContext.mapperService()).thenReturn(mapperService);
IndexFieldDataService ifds = new IndexFieldDataService(IndexSettingsModule.newIndexSettings("test", Settings.EMPTY),
new IndicesFieldDataCache(Settings.EMPTY, new IndexFieldDataCache.Listener() {
}), circuitBreakerService, mapperService);
when(searchContext.fieldData()).thenReturn(ifds);

SearchLookup searchLookup = new SearchLookup(mapperService, ifds, new String[]{"type"});
when(searchContext.lookup()).thenReturn(searchLookup);

QueryShardContext queryShardContext = queryShardContextMock(fieldTypes, indexSettings, circuitBreakerService);
when(searchContext.getQueryShardContext()).thenReturn(queryShardContext);

@SuppressWarnings("unchecked")
A aggregator = (A) aggregationBuilder.build(searchContext, null).create(null, true);
return aggregator;
protected IndexSettings createIndexSettings() {
return new IndexSettings(
IndexMetaData.builder("_index").settings(Settings.builder().put(IndexMetaData.SETTING_VERSION_CREATED, Version.CURRENT))
.numberOfShards(1)
.numberOfReplicas(0)
.creationDate(System.currentTimeMillis())
.build(),
Settings.EMPTY
);
}

/**
Expand Down
Loading

0 comments on commit b01070a

Please sign in to comment.