diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
index dd2e7458d81d2..6f66fc64f6dc7 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/CardinalityAggregator.java
@@ -35,7 +35,15 @@
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedNumericDocValues;
 import org.apache.lucene.index.SortedSetDocValues;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DocIdSetIterator;
 import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
 import org.apache.lucene.util.BytesRef;
 import org.apache.lucene.util.FixedBitSet;
 import org.apache.lucene.util.RamUsageEstimator;
@@ -59,6 +67,7 @@
 import org.opensearch.search.internal.SearchContext;
 
 import java.io.IOException;
+import java.util.HashMap;
 import java.util.Map;
 import java.util.function.BiConsumer;
 
@@ -137,8 +146,13 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
             // only use ordinals if they don't increase memory usage by more than 25%
             if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
                 ordinalsCollectorsUsed++;
-                return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
-                    context, ctx, fieldContext, source);
+                return new CompetitiveCollector(
+                    new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+                    source,
+                    ctx,
+                    context,
+                    fieldContext
+                );
             }
             ordinalsCollectorsOverheadTooHigh++;
         }
@@ -217,6 +231,120 @@ abstract static class Collector extends LeafBucketCollector implements Releasable {
 
     }
 
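+    /**
+     * Collector that delegates actual collection to the ordinals-based collector, while exposing
+     * a competitive iterator built from one term scorer per distinct value of the field, so that
+     * collection can skip ahead over docs whose terms have all been seen already.
+     */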
+    private static class CompetitiveCollector extends Collector {
+
+        private final Collector delegate;
+        private final DisiPriorityQueue pq;
+
+        CompetitiveCollector(
+            Collector delegate,
+            ValuesSource.Bytes.WithOrdinals source,
+            LeafReaderContext ctx,
+            SearchContext context,
+            FieldContext fieldContext
+        ) throws IOException {
+            this.delegate = delegate;
+
+            final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
+            TermsEnum terms = ordinalValues.termsEnum();
+            Map<BytesRef, Scorer> postingMap = new HashMap<>();
+            while (terms.next() != null) {
+                BytesRef term = BytesRef.deepCopyOf(terms.term()); // deep copy: the enum may reuse the returned BytesRef
+
+                TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term));
+                Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
+                Scorer scorer = subWeight.scorer(ctx);
+
+                postingMap.put(term, scorer);
+            }
+            this.pq = new DisiPriorityQueue(postingMap.size());
+            for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
+                pq.add(new DisiWrapper(entry.getValue()));
+            }
+        }
+
+        @Override
+        public void close() {
+            delegate.close();
+        }
+
+        @Override
+        public void collect(int doc, long owningBucketOrd) throws IOException {
+            delegate.collect(doc, owningBucketOrd);
+        }
+
+        @Override
+        public DocIdSetIterator competitiveIterator() throws IOException {
+            return new DisjunctionDISIWithPruning(pq);
+        }
+
+        @Override
+        public void postCollect() throws IOException {
+            delegate.postCollect();
+        }
+    }
+
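+    /**
+     * Disjunction over the per-term iterators that also prunes: on each {@code advance}, every
+     * iterator still positioned on the current (already collected) doc is advanced to exhaustion
+     * first, since a single match per term is enough for cardinality.
+     */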
+    private static class DisjunctionDISIWithPruning extends DocIdSetIterator {
+
+        final DisiPriorityQueue queue;
+
+        public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
+            this.queue = queue;
+        }
+
+        @Override
+        public int docID() {
+            return queue.top().doc;
+        }
+
+        @Override
+        public int nextDoc() throws IOException {
+            // don't expect this to be called
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public int advance(int target) throws IOException {
+            // besides advancing to the next doc >= target,
+            // we also prune away the current doc here
+
+            DisiWrapper top = queue.top();
+
+            // after collecting the doc, and before advancing to target,
+            // we can safely exhaust all the iterators positioned on this doc
+            if (top.doc != -1) {
+                int curTopDoc = top.doc;
+                do {
+                    top.doc = top.approximation.advance(Integer.MAX_VALUE);
+                    top = queue.updateTop();
+                } while (top.doc == curTopDoc);
+            }
+
+            if (top.doc >= target) return top.doc;
+            do {
+                top.doc = top.approximation.advance(target);
+                top = queue.updateTop();
+            } while (top.doc < target);
+            return top.doc;
+        }
+
+        @Override
+        public long cost() {
+            // don't expect this to be called
+            throw new UnsupportedOperationException();
+        }
+    }
+
     /**
      * Empty Collector for the Cardinality agg
      *
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
index 9b6a0f42e0fa9..93a01aaa1e053 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DisjunctionWithDynamicPruningScorer.java
@@ -22,7 +22,6 @@
 
 import java.util.Collection;
 import java.util.List;
-
 /**
  * Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in lucene with following modifications
  *
@@ -45,8 +44,7 @@ public class DisjunctionWithDynamicPruningScorer extends Scorer {
 
     private Integer docID;
 
-    public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers)
-        throws IOException {
+    public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers) throws IOException {
         super(weight);
         if (subScorers.size() <= 1) {
             throw new IllegalArgumentException("There must be at least 2 subScorers");
@@ -90,9 +88,9 @@ public void removeAllDISIsOnCurrentDoc() {
 
     @Override
     public DocIdSetIterator iterator() {
-      DocIdSetIterator disi = getIterator();
-      docID = disi.docID();
-      return new SlowDocIdPropagatorDISI(getIterator(), docID);
+        DocIdSetIterator disi = getIterator();
+        docID = disi.docID();
+        return new SlowDocIdPropagatorDISI(getIterator(), docID);
     }
 
     private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
@@ -157,13 +155,12 @@ private class TwoPhase extends TwoPhaseIterator {
         private TwoPhase(DocIdSetIterator approximation, float matchCost) {
             super(approximation);
             this.matchCost = matchCost;
-            unverifiedMatches =
-                new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
-                    @Override
-                    protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
-                        return a.matchCost < b.matchCost;
-                    }
-                };
+            unverifiedMatches = new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
+                @Override
+                protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
+                    return a.matchCost < b.matchCost;
+                }
+            };
         }
 
         DisiWrapper getSubMatches() throws IOException {
@@ -183,7 +180,7 @@ public boolean matches() throws IOException {
             verifiedMatches = null;
            unverifiedMatches.clear();
 
-            for (DisiWrapper w = subScorers.topList(); w != null; ) {
+            for (DisiWrapper w = subScorers.topList(); w != null;) {
                 DisiWrapper next = w.next;
 
                 if (w.twoPhaseView == null) {
diff --git a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
index f4c3d59a3833f..cb735a3257289 100644
--- a/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
+++ b/server/src/main/java/org/opensearch/search/aggregations/metrics/DynamicPruningCollectorWrapper.java
@@ -8,7 +8,6 @@
 
 package org.opensearch.search.aggregations.metrics;
 
-import org.apache.lucene.index.LeafReader;
 import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.index.SortedSetDocValues;
 import org.apache.lucene.index.Term;
@@ -38,9 +37,13 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
     private final DocIdSetIterator disi;
     private final CardinalityAggregator.Collector delegateCollector;
 
-    DynamicPruningCollectorWrapper(CardinalityAggregator.Collector delegateCollector,
-                                   SearchContext context, LeafReaderContext ctx, FieldContext fieldContext,
-                                   ValuesSource.Bytes.WithOrdinals source) throws IOException {
+    DynamicPruningCollectorWrapper(
+        CardinalityAggregator.Collector delegateCollector,
+        SearchContext context,
+        LeafReaderContext ctx,
+        FieldContext fieldContext,
+        ValuesSource.Bytes.WithOrdinals source
+    ) throws IOException {
         this.ctx = ctx;
         this.delegateCollector = delegateCollector;
         final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
@@ -52,7 +55,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
             // this logic should be pluggable depending on the type of leaf bucket collector by CardinalityAggregator
             TermsEnum terms = ordinalValues.termsEnum();
             Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f);
-            Map<Long, Boolean> found = new HashMap<>();
+            Map<Long, Boolean> found = new HashMap<>(); // ord : found or not
             List<Scorer> subScorers = new ArrayList<>();
             while (terms.next() != null && !found.containsKey(terms.ord())) {
                 // TODO can we get rid of terms previously encountered in other segments?
diff --git a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
index a9966c9e70e76..d21e7f6ed8550 100644
--- a/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
+++ b/server/src/test/java/org/opensearch/search/aggregations/metrics/CardinalityAggregatorTests.java
@@ -105,44 +105,59 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
         MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
         final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
         testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "1", Field.Store.NO),
-                new KeywordField(fieldName, "2", Field.Store.NO),
-                new KeywordField(filterFieldName, "foo", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("1")),
-                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "2", Field.Store.NO),
-                new KeywordField(filterFieldName, "foo", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "1", Field.Store.NO),
-                new KeywordField(filterFieldName, "foo", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("1"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "2", Field.Store.NO),
-                new KeywordField(filterFieldName, "foo", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "3", Field.Store.NO),
-                new KeywordField(filterFieldName, "foo", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("3"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "4", Field.Store.NO),
-                new KeywordField(filterFieldName, "bar", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("4"))
-            ));
-            iw.addDocument(asList(
-                new KeywordField(fieldName, "5", Field.Store.NO),
-                new KeywordField(filterFieldName, "bar", Field.Store.NO),
-                new SortedSetDocValuesField(fieldName, new BytesRef("5"))
-            ));
-        }, card -> {
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "1", Field.Store.NO),
+                    new KeywordField(fieldName, "2", Field.Store.NO),
+                    new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("1")),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+                )
+            );
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "2", Field.Store.NO),
+                    new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+                )
+            );
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "1", Field.Store.NO),
+                    new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("1"))
+                )
+            );
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "2", Field.Store.NO),
+                    new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+                )
+            );
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "3", Field.Store.NO),
+                    new KeywordField(filterFieldName, "foo", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("3"))
+                )
+            );
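+            // the remaining docs do not match the "foo" filter, so values "4" and "5" must not be counted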
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "4", Field.Store.NO),
+                    new KeywordField(filterFieldName, "bar", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("4"))
+                )
+            );
+            iw.addDocument(
+                asList(
+                    new KeywordField(fieldName, "5", Field.Store.NO),
+                    new KeywordField(filterFieldName, "bar", Field.Store.NO),
+                    new SortedSetDocValuesField(fieldName, new BytesRef("5"))
+                )
+            );
+        }, card -> {
             assertEquals(3.0, card.getValue(), 0);
             assertTrue(AggregationInspectionHelper.hasValue(card));
         }, fieldType);