Utilize competitive iterator API to perform pruning
Signed-off-by: bowenlan-amzn <bowenlan23@gmail.com>
bowenlan-amzn committed May 24, 2024
1 parent a18b597 commit 85133c4
Showing 4 changed files with 193 additions and 59 deletions.
@@ -35,7 +35,15 @@
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedNumericDocValues;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.DisiPriorityQueue;
import org.apache.lucene.search.DisiWrapper;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.Weight;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.FixedBitSet;
import org.apache.lucene.util.RamUsageEstimator;
@@ -59,6 +67,7 @@
import org.opensearch.search.internal.SearchContext;

import java.io.IOException;
+import java.util.HashMap;
import java.util.Map;
import java.util.function.BiConsumer;

@@ -137,8 +146,15 @@ private Collector pickCollector(LeafReaderContext ctx) throws IOException {
// only use ordinals if they don't increase memory usage by more than 25%
if (ordinalsMemoryUsage < countsMemoryUsage / 4) {
ordinalsCollectorsUsed++;
-return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
-    context, ctx, fieldContext, source);
+// return new DynamicPruningCollectorWrapper(new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+//     context, ctx, fieldContext, source);
+return new CompetitiveCollector(
+    new OrdinalsCollector(counts, ordinalValues, context.bigArrays()),
+    source,
+    ctx,
+    context,
+    fieldContext
+);
}
ordinalsCollectorsOverheadTooHigh++;
}
@@ -217,6 +233,110 @@ abstract static class Collector extends LeafBucketCollector implements Releasable

}

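/**
 * Wraps the ordinals collector and exposes, via {@link #competitiveIterator()}, a
 * disjunction over one term-query scorer per unique term of the field in this segment.
 * After a doc is collected, the iterators of all terms contained in that doc are
 * exhausted, so later docs that contain only already-counted terms can be skipped.
 */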
private static class CompetitiveCollector extends Collector {

private final Collector delegate;
private final DisiPriorityQueue pq;

CompetitiveCollector(
Collector delegate,
ValuesSource.Bytes.WithOrdinals source,
LeafReaderContext ctx,
SearchContext context,
FieldContext fieldContext
) throws IOException {
this.delegate = delegate;

final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
TermsEnum terms = ordinalValues.termsEnum();
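// build one term-query scorer (postings iterator) per unique term of the field in this segment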
Map<BytesRef, Scorer> postingMap = new HashMap<>();
while (terms.next() != null) {
BytesRef term = terms.term();

TermQuery termQuery = new TermQuery(new Term(fieldContext.field(), term));
Weight subWeight = context.searcher().createWeight(termQuery, ScoreMode.COMPLETE_NO_SCORES, 1f);
Scorer scorer = subWeight.scorer(ctx);

postingMap.put(term, scorer);
}
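// order the per-term iterators by their current doc id; the top of the queue is the
// next candidate doc among all terms that have not been pruned yet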
this.pq = new DisiPriorityQueue(postingMap.size());
for (Map.Entry<BytesRef, Scorer> entry : postingMap.entrySet()) {
pq.add(new DisiWrapper(entry.getValue()));
}
}

@Override
public void close() {
delegate.close();
}

@Override
public void collect(int doc, long owningBucketOrd) throws IOException {
delegate.collect(doc, owningBucketOrd);
}

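// the bulk scorer intersects the query's iterator with the iterator returned here,
// skipping docs in which every term has already been counted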
@Override
public DocIdSetIterator competitiveIterator() throws IOException {
return new DisjunctionDISIWithPruning(pq);
}

@Override
public void postCollect() throws IOException {
delegate.postCollect();
}
}

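/**
 * Disjunction over per-term iterators, ordered by their current doc id. advance() first
 * exhausts every iterator still positioned on the previously collected doc (all of that
 * doc's terms are counted by now), then advances the remaining iterators to the target.
 * Only docID() and advance() are supported, which is how the bulk scorer drives a
 * competitive iterator.
 */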
private static class DisjunctionDISIWithPruning extends DocIdSetIterator {

final DisiPriorityQueue queue;

public DisjunctionDISIWithPruning(DisiPriorityQueue queue) {
this.queue = queue;
}

@Override
public int docID() {
return queue.top().doc;
}

@Override
public int nextDoc() throws IOException {
// not expected to be called; the bulk scorer drives a competitive iterator only through advance()
throw new UnsupportedOperationException();
}

@Override
public int advance(int target) throws IOException {
// this does more than advance to the next doc >= target:
// it also prunes the iterators positioned on the current, already-collected doc

DisiWrapper top = queue.top();

// after the current doc has been collected and before advancing to target,
// we can safely exhaust all the iterators still positioned on that doc
if (top.doc != -1) {
int curTopDoc = top.doc;
do {
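// Integer.MAX_VALUE equals DocIdSetIterator.NO_MORE_DOCS, so this exhausts the
// per-term iterator and permanently removes its term from the disjunction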
top.doc = top.approximation.advance(Integer.MAX_VALUE);
top = queue.updateTop();
} while (top.doc == curTopDoc);
}

if (top.doc >= target) return top.doc;
do {
top.doc = top.approximation.advance(target);
top = queue.updateTop();
} while (top.doc < target);
return top.doc;
}

@Override
public long cost() {
// don't expect this to be called
throw new UnsupportedOperationException();
}
}
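
// For context, a minimal sketch of how a Lucene bulk scorer is expected to consume the
// competitive iterator exposed above. This is simplified; the real consumption happens
// inside Lucene's Weight.DefaultBulkScorer and differs in details, and the names
// queryIterator and leafCollector are illustrative, not part of this commit:
//
//     DocIdSetIterator competitive = leafCollector.competitiveIterator();
//     int doc = queryIterator.nextDoc();
//     while (doc != DocIdSetIterator.NO_MORE_DOCS) {
//         if (competitive.docID() < doc) {
//             competitive.advance(doc); // DisjunctionDISIWithPruning prunes here
//         }
//         if (competitive.docID() == doc) {
//             leafCollector.collect(doc); // doc still contains an uncounted term
//             doc = queryIterator.nextDoc();
//         } else {
//             doc = queryIterator.advance(competitive.docID()); // skip non-competitive docs
//         }
//     }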

/**
* Empty Collector for the Cardinality agg
*
@@ -22,7 +22,6 @@
import java.util.Collection;
import java.util.List;


/**
Clone of {@link org.apache.lucene.search} {@code DisjunctionScorer.java} in Lucene with the following modifications
* <p>
@@ -45,8 +44,7 @@ public class DisjunctionWithDynamicPruningScorer extends Scorer {

private Integer docID;

-public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers)
-    throws IOException {
+public DisjunctionWithDynamicPruningScorer(Weight weight, List<Scorer> subScorers) throws IOException {
super(weight);
if (subScorers.size() <= 1) {
throw new IllegalArgumentException("There must be at least 2 subScorers");
@@ -90,9 +88,9 @@ public void removeAllDISIsOnCurrentDoc() {

@Override
public DocIdSetIterator iterator() {
-DocIdSetIterator disi = getIterator();
-docID = disi.docID();
-return new SlowDocIdPropagatorDISI(getIterator(), docID);
+    DocIdSetIterator disi = getIterator();
+    docID = disi.docID();
+    return new SlowDocIdPropagatorDISI(getIterator(), docID);
}

private static class SlowDocIdPropagatorDISI extends DocIdSetIterator {
@@ -157,13 +155,12 @@ private class TwoPhase extends TwoPhaseIterator {
private TwoPhase(DocIdSetIterator approximation, float matchCost) {
super(approximation);
this.matchCost = matchCost;
-unverifiedMatches =
-    new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
-        @Override
-        protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
-            return a.matchCost < b.matchCost;
-        }
-    };
+unverifiedMatches = new PriorityQueue<DisiWrapper>(DisjunctionWithDynamicPruningScorer.this.subScorers.size()) {
+    @Override
+    protected boolean lessThan(DisiWrapper a, DisiWrapper b) {
+        return a.matchCost < b.matchCost;
+    }
+};
}

DisiWrapper getSubMatches() throws IOException {
@@ -183,7 +180,7 @@ public boolean matches() throws IOException {
verifiedMatches = null;
unverifiedMatches.clear();

-for (DisiWrapper w = subScorers.topList(); w != null; ) {
+for (DisiWrapper w = subScorers.topList(); w != null;) {
DisiWrapper next = w.next;

if (w.twoPhaseView == null) {
@@ -8,7 +8,6 @@

package org.opensearch.search.aggregations.metrics;

-import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.SortedSetDocValues;
import org.apache.lucene.index.Term;
@@ -38,9 +37,13 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
private final DocIdSetIterator disi;
private final CardinalityAggregator.Collector delegateCollector;

-DynamicPruningCollectorWrapper(CardinalityAggregator.Collector delegateCollector,
-    SearchContext context, LeafReaderContext ctx, FieldContext fieldContext,
-    ValuesSource.Bytes.WithOrdinals source) throws IOException {
+DynamicPruningCollectorWrapper(
+    CardinalityAggregator.Collector delegateCollector,
+    SearchContext context,
+    LeafReaderContext ctx,
+    FieldContext fieldContext,
+    ValuesSource.Bytes.WithOrdinals source
+) throws IOException {
this.ctx = ctx;
this.delegateCollector = delegateCollector;
final SortedSetDocValues ordinalValues = source.ordinalsValues(ctx);
@@ -52,7 +55,7 @@ class DynamicPruningCollectorWrapper extends CardinalityAggregator.Collector {
// this logic should be pluggable, depending on the type of leaf bucket collector used by CardinalityAggregator
TermsEnum terms = ordinalValues.termsEnum();
Weight weight = context.searcher().createWeight(context.searcher().rewrite(context.query()), ScoreMode.COMPLETE_NO_SCORES, 1f);
-Map<Long, Boolean> found = new HashMap<>();
+Map<Long, Boolean> found = new HashMap<>(); // term ordinal -> whether a matching doc has been found
List<Scorer> subScorers = new ArrayList<>();
while (terms.next() != null && !found.containsKey(terms.ord())) {
// TODO can we get rid of terms previously encountered in other segments?
@@ -105,44 +105,58 @@ public void testDynamicPruningOrdinalCollector() throws IOException {
MappedFieldType fieldType = new KeywordFieldMapper.KeywordFieldType(fieldName);
final CardinalityAggregationBuilder aggregationBuilder = new CardinalityAggregationBuilder("_name").field(fieldName);
testAggregation(aggregationBuilder, new TermQuery(new Term(filterFieldName, "foo")), iw -> {
-iw.addDocument(asList(
-    new KeywordField(fieldName, "1", Field.Store.NO),
-    new KeywordField(fieldName, "2", Field.Store.NO),
-    new KeywordField(filterFieldName, "foo", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("1")),
-    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "2", Field.Store.NO),
-    new KeywordField(filterFieldName, "foo", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "1", Field.Store.NO),
-    new KeywordField(filterFieldName, "foo", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("1"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "2", Field.Store.NO),
-    new KeywordField(filterFieldName, "foo", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("2"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "3", Field.Store.NO),
-    new KeywordField(filterFieldName, "foo", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("3"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "4", Field.Store.NO),
-    new KeywordField(filterFieldName, "bar", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("4"))
-));
-iw.addDocument(asList(
-    new KeywordField(fieldName, "5", Field.Store.NO),
-    new KeywordField(filterFieldName, "bar", Field.Store.NO),
-    new SortedSetDocValuesField(fieldName, new BytesRef("5"))
-));
-}, card -> {
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "1", Field.Store.NO),
+        new KeywordField(fieldName, "2", Field.Store.NO),
+        new KeywordField(filterFieldName, "foo", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("1")),
+        new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "2", Field.Store.NO),
+        new KeywordField(filterFieldName, "foo", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "1", Field.Store.NO),
+        new KeywordField(filterFieldName, "foo", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("1"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "2", Field.Store.NO),
+        new KeywordField(filterFieldName, "foo", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("2"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "3", Field.Store.NO),
+        new KeywordField(filterFieldName, "foo", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("3"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "4", Field.Store.NO),
+        new KeywordField(filterFieldName, "bar", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("4"))
+    )
+);
+iw.addDocument(
+    asList(
+        new KeywordField(fieldName, "5", Field.Store.NO),
+        new KeywordField(filterFieldName, "bar", Field.Store.NO),
+        new SortedSetDocValuesField(fieldName, new BytesRef("5"))
+    )
+);
+}, card -> {
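// docs matching filterFieldName "foo" carry the values {"1", "2", "3"}, so the
// expected cardinality is 3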
assertEquals(3.0, card.getValue(), 0);
assertTrue(AggregationInspectionHelper.hasValue(card));
}, fieldType);
