From a989229bcd70888b3cb9e98ab58fb10ef47f62bc Mon Sep 17 00:00:00 2001 From: Martijn van Groningen Date: Mon, 26 May 2014 17:56:07 +0200 Subject: [PATCH] Added pagination support to `top_hits` aggregation by adding `from` option. Closes #6299 --- .../bucket/tophits-aggregation.asciidoc | 1 + .../bucket/tophits/InternalTopHits.java | 8 +- .../aggregations/bucket/tophits/TopHits.java | 4 + .../bucket/tophits/TopHitsAggregator.java | 6 +- .../bucket/tophits/TopHitsBuilder.java | 11 +- .../bucket/tophits/TopHitsContext.java | 6 +- .../bucket/tophits/TopHitsParser.java | 3 + .../aggregations/bucket/TopHitsTests.java | 154 +++++++++++++++++- 8 files changed, 184 insertions(+), 9 deletions(-) diff --git a/docs/reference/search/aggregations/bucket/tophits-aggregation.asciidoc b/docs/reference/search/aggregations/bucket/tophits-aggregation.asciidoc index 51e4686fc19d2..24230cccdf0d8 100644 --- a/docs/reference/search/aggregations/bucket/tophits-aggregation.asciidoc +++ b/docs/reference/search/aggregations/bucket/tophits-aggregation.asciidoc @@ -13,6 +13,7 @@ This aggregator can't hold any sub-aggregators and therefor can only be used as ==== Options +* `from` - The index from which to include matching hits. * `size` - The maximum number of top matching hits to return per bucket. By default the top three matching hits are returned. * `sort` - How the top matching hits should be sorted. By default the hits are sorted by the score of the main query. diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/InternalTopHits.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/InternalTopHits.java index 3f74391529bff..76b27fa65b9d5 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/InternalTopHits.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/InternalTopHits.java @@ -54,6 +54,7 @@ public static void registerStreams() { AggregationStreams.registerStream(STREAM, TYPE.stream()); } + private int from; private int size; private Sort sort; private TopDocs topDocs; @@ -62,8 +63,9 @@ public static void registerStreams() { InternalTopHits() { } - public InternalTopHits(String name, int size, Sort sort, TopDocs topDocs, InternalSearchHits searchHits) { + public InternalTopHits(String name, int from, int size, Sort sort, TopDocs topDocs, InternalSearchHits searchHits) { this.name = name; + this.from = from; this.size = size; this.sort = sort; this.topDocs = topDocs; @@ -104,7 +106,7 @@ public InternalAggregation reduce(ReduceContext reduceContext) { try { int[] tracker = new int[shardHits.length]; - TopDocs reducedTopDocs = TopDocs.merge(sort, size, shardDocs); + TopDocs reducedTopDocs = TopDocs.merge(sort, from, size, shardDocs); InternalSearchHit[] hits = new InternalSearchHit[reducedTopDocs.scoreDocs.length]; for (int i = 0; i < reducedTopDocs.scoreDocs.length; i++) { ScoreDoc scoreDoc = reducedTopDocs.scoreDocs[i]; @@ -119,6 +121,7 @@ public InternalAggregation reduce(ReduceContext reduceContext) { @Override public void readFrom(StreamInput in) throws IOException { name = in.readString(); + from = in.readVInt(); size = in.readVInt(); topDocs = Lucene.readTopDocs(in); if (topDocs instanceof TopFieldDocs) { @@ -130,6 +133,7 @@ public void readFrom(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(name); + out.writeVInt(from); out.writeVInt(size); Lucene.writeTopDocs(out, topDocs, 0); searchHits.writeTo(out); diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHits.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHits.java index 853a8a1dad6c1..4c20e430b9af1 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHits.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHits.java @@ -22,9 +22,13 @@ import org.elasticsearch.search.aggregations.Aggregation; /** + * Accumulation of the most relevant hits for a bucket this aggregation falls into. */ public interface TopHits extends Aggregation { + /** + * @return The top matching hits for the bucket + */ SearchHits getHits(); } diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsAggregator.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsAggregator.java index 165f9e5198d63..0e09360a33409 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsAggregator.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsAggregator.java @@ -88,7 +88,7 @@ public InternalAggregation buildAggregation(long owningBucketOrdinal) { searchHitFields.sortValues(fieldDoc.fields); } } - return new InternalTopHits(name, topHitsContext.size(), topHitsContext.sort(), topDocs, fetchResult.hits()); + return new InternalTopHits(name, topHitsContext.from(), topHitsContext.size(), topHitsContext.sort(), topDocs, fetchResult.hits()); } } @@ -102,10 +102,10 @@ public void collect(int docId, long bucketOrdinal) throws IOException { TopDocsCollector topDocsCollector = topDocsCollectors.get(bucketOrdinal); if (topDocsCollector == null) { Sort sort = topHitsContext.sort(); - int size = topHitsContext.size(); + int topN = topHitsContext.from() + topHitsContext.size(); topDocsCollectors.put( bucketOrdinal, - topDocsCollector = sort != null ? TopFieldCollector.create(sort, size, true, topHitsContext.trackScores(), true, false) : TopScoreDocCollector.create(size, false) + topDocsCollector = sort != null ? TopFieldCollector.create(sort, topN, true, topHitsContext.trackScores(), true, false) : TopScoreDocCollector.create(topN, false) ); topDocsCollector.setNextReader(currentContext); topDocsCollector.setScorer(currentScorer); diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsBuilder.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsBuilder.java index 785fadff2d703..2a49b5afd2c91 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsBuilder.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsBuilder.java @@ -41,8 +41,17 @@ public TopHitsBuilder(String name) { } /** - * The number of search hits to return. Defaults to 10. + * The index to start to return hits from. Defaults to 0. */ + public TopHitsBuilder setFrom(int from) { + sourceBuilder().from(from); + return this; + } + + + /** + * The number of search hits to return. Defaults to 10. + */ public TopHitsBuilder setSize(int size) { sourceBuilder().size(size); return this; diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsContext.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsContext.java index 848c6ca414dab..2023b4c96689f 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsContext.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsContext.java @@ -71,6 +71,7 @@ public class TopHitsContext extends SearchContext { // the to hits are returned per bucket. private final static int DEFAULT_SIZE = 3; + private int from; private int size = DEFAULT_SIZE; private Sort sort; @@ -440,12 +441,13 @@ public SearchContext updateRewriteQuery(Query rewriteQuery) { @Override public int from() { - return context.from(); + return from; } @Override public SearchContext from(int from) { - throw new UnsupportedOperationException("Not supported"); + this.from = from; + return this; } @Override diff --git a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsParser.java b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsParser.java index eac833f567251..cfd93bb4f2da1 100644 --- a/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsParser.java +++ b/src/main/java/org/elasticsearch/search/aggregations/bucket/tophits/TopHitsParser.java @@ -72,6 +72,9 @@ public AggregatorFactory parse(String aggregationName, XContentParser parser, Se currentFieldName = parser.currentName(); } else if (token.isValue()) { switch (currentFieldName) { + case "from": + topHitsContext.from(parser.intValue()); + break; case "size": topHitsContext.size(parser.intValue()); break; diff --git a/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java b/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java index c990d745c5d64..c67cf999d5c84 100644 --- a/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java +++ b/src/test/java/org/elasticsearch/search/aggregations/bucket/TopHitsTests.java @@ -75,7 +75,6 @@ public void setupSuiteScopeCluster() throws Exception { .endObject())); } - // Use routing to make sure all docs are in the same shard for consistent scoring builders.add(client().prepareIndex("idx", "field-collapsing", "1").setSource(jsonBuilder() .startObject() .field("group", "a") @@ -168,6 +167,159 @@ public void testBasics() throws Exception { } } + @Test + public void testPagination() throws Exception { + SearchResponse response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)).setSize(2) + ) + ) + .get(); + + assertSearchResponse(response); + + Terms terms = response.getAggregations().get("terms"); + assertThat(terms, notNullValue()); + assertThat(terms.getName(), equalTo("terms")); + assertThat(terms.getBuckets().size(), equalTo(5)); + + Terms.Bucket bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + TopHits topHits = bucket.getAggregations().get("hits"); + SearchHits hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(2)); + assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(10l)); + assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(9l)); + + response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)) + .setSize(2) + .setFrom(2) + ) + ) + .get(); + + assertSearchResponse(response); + + terms = response.getAggregations().get("terms"); + bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(2)); + assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(8l)); + assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(7l)); + + response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)) + .setSize(2) + .setFrom(4) + ) + ) + .get(); + + assertSearchResponse(response); + + terms = response.getAggregations().get("terms"); + bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(2)); + assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(6l)); + assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(5l)); + + response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)) + .setSize(2) + .setFrom(6) + ) + ) + .get(); + + assertSearchResponse(response); + + terms = response.getAggregations().get("terms"); + bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(2)); + assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(4l)); + assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(3l)); + + response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)) + .setSize(2) + .setFrom(8) + ) + ) + .get(); + + assertSearchResponse(response); + + terms = response.getAggregations().get("terms"); + bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(2)); + assertThat((Long) hits.getAt(0).sortValues()[0], equalTo(2l)); + assertThat((Long) hits.getAt(1).sortValues()[0], equalTo(1l)); + + response = client().prepareSearch("idx").setTypes("type") + .addAggregation(terms("terms") + .executionHint(randomExecutionHint()) + .field(TERMS_AGGS_FIELD) + .subAggregation( + topHits("hits").addSort(SortBuilders.fieldSort(SORT_FIELD).order(SortOrder.DESC)) + .setSize(2) + .setFrom(10) + ) + ) + .get(); + + assertSearchResponse(response); + + terms = response.getAggregations().get("terms"); + bucket = terms.getBucketByKey("val0"); + assertThat(bucket, notNullValue()); + assertThat(bucket.getDocCount(), equalTo(10l)); + topHits = bucket.getAggregations().get("hits"); + hits = topHits.getHits(); + assertThat(hits.totalHits(), equalTo(10l)); + assertThat(hits.getHits().length, equalTo(0)); + } + @Test public void testSortByBucket() throws Exception { SearchResponse response = client().prepareSearch("idx").setTypes("type")