Skip to content

Commit

Permalink
Remove LeafSimScorer abstraction. (#13957)
Browse files (browse the repository at this point in the history)
`LeafSimScorer` is a specialization of a `SimScorer` for a given segment. It
doesn't add much value, but benchmarks suggest that it adds measurable overhead
to queries sorted by score.
  • Loading branch information
jpountz committed Oct 26, 2024
1 parent 0e918ce commit 054c60e
Show file tree
Hide file tree
Showing 23 changed files with 203 additions and 205 deletions.
2 changes: 2 additions & 0 deletions lucene/CHANGES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ API Changes

* GITHUB#13950: Make BooleanQuery#getClauses public and add #add(Collection<BooleanClause>) to BQ builder. (Shubham Chaudhary)

* GITHUB#13957: Removed the LeafSimScorer class to avoid its overhead. Scorers now
compute scores directly from a SimScorer, postings and norms. (Adrien Grand)

New Features
---------------------
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafSimScorer;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.QueryVisitor;
import org.apache.lucene.search.ScoreMode;
Expand Down Expand Up @@ -120,7 +119,6 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio

@Override
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
final Weight thisWeight = this;
Terms terms = Terms.getTerms(context.reader(), fieldName);
TermsEnum termsEnum = terms.iterator();
if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
Expand All @@ -135,10 +133,8 @@ public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOExcepti
@Override
public Scorer get(long leadCost) throws IOException {
final SimScorer scorer = function.scorer(boost);
final LeafSimScorer simScorer =
new LeafSimScorer(scorer, context.reader(), fieldName, false);
final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS);
return new TermScorer(thisWeight, impacts, simScorer, topLevelScoringClause);
return new TermScorer(impacts, scorer, null, topLevelScoringClause);
}

@Override
Expand Down
72 changes: 0 additions & 72 deletions lucene/core/src/java/org/apache/lucene/search/LeafSimScorer.java

This file was deleted.

21 changes: 17 additions & 4 deletions lucene/core/src/java/org/apache/lucene/search/PhraseScorer.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
package org.apache.lucene.search;

import java.io.IOException;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.similarities.Similarity.SimScorer;

class PhraseScorer extends Scorer {

Expand All @@ -26,16 +28,19 @@ class PhraseScorer extends Scorer {
final MaxScoreCache maxScoreCache;
final PhraseMatcher matcher;
final ScoreMode scoreMode;
private final LeafSimScorer simScorer;
private final SimScorer simScorer;
private final NumericDocValues norms;
final float matchCost;

private float minCompetitiveScore = 0;
private float freq = 0;

PhraseScorer(PhraseMatcher matcher, ScoreMode scoreMode, LeafSimScorer simScorer) {
PhraseScorer(
PhraseMatcher matcher, ScoreMode scoreMode, SimScorer simScorer, NumericDocValues norms) {
this.matcher = matcher;
this.scoreMode = scoreMode;
this.simScorer = simScorer;
this.norms = norms;
this.matchCost = matcher.getMatchCost();
this.approximation = matcher.approximation();
this.impactsApproximation = matcher.impactsApproximation();
Expand All @@ -50,7 +55,11 @@ public boolean matches() throws IOException {
matcher.reset();
if (scoreMode == ScoreMode.TOP_SCORES && minCompetitiveScore > 0) {
float maxFreq = matcher.maxFreq();
if (simScorer.score(docID(), maxFreq) < minCompetitiveScore) {
long norm = 1L;
if (norms != null && norms.advanceExact(docID())) {
norm = norms.longValue();
}
if (simScorer.score(maxFreq, norm) < minCompetitiveScore) {
// The maximum score we could get is less than the min competitive score
return false;
}
Expand Down Expand Up @@ -79,7 +88,11 @@ public float score() throws IOException {
freq += matcher.sloppyWeight();
}
}
return simScorer.score(docID(), freq);
long norm = 1L;
if (norms != null && norms.advanceExact(docID())) {
norm = norms.longValue();
}
return simScorer.score(freq, norm);
}

@Override
Expand Down
15 changes: 9 additions & 6 deletions lucene/core/src/java/org/apache/lucene/search/PhraseWeight.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import java.io.IOException;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;

Expand Down Expand Up @@ -63,9 +64,8 @@ protected abstract PhraseMatcher getPhraseMatcher(
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
PhraseMatcher matcher = getPhraseMatcher(context, stats, false);
if (matcher == null) return null;
LeafSimScorer simScorer =
new LeafSimScorer(stats, context.reader(), field, scoreMode.needsScores());
final var scorer = new PhraseScorer(matcher, scoreMode, simScorer);
NumericDocValues norms = scoreMode.needsScores() ? context.reader().getNormValues(field) : null;
final var scorer = new PhraseScorer(matcher, scoreMode, stats, norms);
return new DefaultScorerSupplier(scorer);
}

Expand All @@ -83,10 +83,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
while (matcher.nextMatch()) {
freq += matcher.sloppyWeight();
}
LeafSimScorer docScorer =
new LeafSimScorer(stats, context.reader(), field, scoreMode.needsScores());
Explanation freqExplanation = Explanation.match(freq, "phraseFreq=" + freq);
Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
NumericDocValues norms = scoreMode.needsScores() ? context.reader().getNormValues(field) : null;
long norm = 1L;
if (norms != null && norms.advanceExact(doc)) {
norm = norms.longValue();
}
Explanation scoreExplanation = stats.explain(freqExplanation, norm);
return Explanation.match(
scoreExplanation.getValue(),
"weight("
Expand Down
56 changes: 37 additions & 19 deletions lucene/core/src/java/org/apache/lucene/search/SynonymQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import org.apache.lucene.index.ImpactsEnum;
import org.apache.lucene.index.ImpactsSource;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.index.NumericDocValues;
import org.apache.lucene.index.PostingsEnum;
import org.apache.lucene.index.SlowImpactsEnum;
import org.apache.lucene.index.Term;
Expand All @@ -38,6 +39,7 @@
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.similarities.Similarity;
import org.apache.lucene.search.similarities.Similarity.SimScorer;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.IOSupplier;
import org.apache.lucene.util.PriorityQueue;
Expand Down Expand Up @@ -259,9 +261,13 @@ public Explanation explain(LeafReaderContext context, int doc) throws IOExceptio
assert scorer instanceof TermScorer;
freq = ((TermScorer) scorer).freq();
}
LeafSimScorer docScorer = new LeafSimScorer(simWeight, context.reader(), field, true);
Explanation freqExplanation = Explanation.match(freq, "termFreq=" + freq);
Explanation scoreExplanation = docScorer.explain(doc, freqExplanation);
NumericDocValues norms = context.reader().getNormValues(field);
long norm = 1L;
if (norms != null && norms.advanceExact(doc)) {
norm = norms.longValue();
}
Explanation scoreExplanation = simWeight.explain(freqExplanation, norm);
return Explanation.match(
scoreExplanation.getValue(),
"weight("
Expand Down Expand Up @@ -334,27 +340,27 @@ public Scorer get(long leadCost) throws IOException {
return new ConstantScoreScorer(0f, scoreMode, DocIdSetIterator.empty());
}

LeafSimScorer simScorer = new LeafSimScorer(simWeight, context.reader(), field, true);
NumericDocValues norms = context.reader().getNormValues(field);

// we must optimize this case (term not in segment), disjunctions require >= 2 subs
if (iterators.size() == 1) {
final TermScorer scorer;
if (scoreMode == ScoreMode.TOP_SCORES) {
scorer = new TermScorer(impacts.get(0), simScorer);
scorer = new TermScorer(impacts.get(0), simWeight, norms);
} else {
scorer = new TermScorer(iterators.get(0), simScorer);
scorer = new TermScorer(iterators.get(0), simWeight, norms);
}
float boost = termBoosts.get(0);
return scoreMode == ScoreMode.COMPLETE_NO_SCORES || boost == 1f
? scorer
: new FreqBoostTermScorer(boost, scorer, simScorer);
: new FreqBoostTermScorer(boost, scorer, simWeight, norms);
} else {

// we use termscorers + disjunction as an impl detail
DisiPriorityQueue queue = new DisiPriorityQueue(iterators.size());
for (int i = 0; i < iterators.size(); i++) {
PostingsEnum postings = iterators.get(i);
final TermScorer termScorer = new TermScorer(postings, simScorer);
final TermScorer termScorer = new TermScorer(postings, simWeight, norms);
float boost = termBoosts.get(i);
final DisiWrapperFreq wrapper = new DisiWrapperFreq(termScorer, boost);
queue.add(wrapper);
Expand All @@ -368,8 +374,7 @@ public Scorer get(long leadCost) throws IOException {
boosts[i] = termBoosts.get(i);
}
ImpactsSource impactsSource = mergeImpacts(impacts.toArray(new ImpactsEnum[0]), boosts);
MaxScoreCache maxScoreCache =
new MaxScoreCache(impactsSource, simScorer.getSimScorer());
MaxScoreCache maxScoreCache = new MaxScoreCache(impactsSource, simWeight);
ImpactsDISI impactsDisi = new ImpactsDISI(iterator, maxScoreCache);

if (scoreMode == ScoreMode.TOP_SCORES) {
Expand All @@ -379,7 +384,7 @@ public Scorer get(long leadCost) throws IOException {
iterator = impactsDisi;
}

return new SynonymScorer(queue, iterator, impactsDisi, simScorer);
return new SynonymScorer(queue, iterator, impactsDisi, simWeight, norms);
}
}

Expand Down Expand Up @@ -575,18 +580,21 @@ private static class SynonymScorer extends Scorer {
private final DocIdSetIterator iterator;
private final MaxScoreCache maxScoreCache;
private final ImpactsDISI impactsDisi;
private final LeafSimScorer simScorer;
private final SimScorer scorer;
private final NumericDocValues norms;

SynonymScorer(
DisiPriorityQueue queue,
DocIdSetIterator iterator,
ImpactsDISI impactsDisi,
LeafSimScorer simScorer) {
SimScorer scorer,
NumericDocValues norms) {
this.queue = queue;
this.iterator = iterator;
this.maxScoreCache = impactsDisi.getMaxScoreCache();
this.impactsDisi = impactsDisi;
this.simScorer = simScorer;
this.scorer = scorer;
this.norms = norms;
}

@Override
Expand All @@ -605,7 +613,11 @@ float freq() throws IOException {

@Override
public float score() throws IOException {
return simScorer.score(iterator.docID(), freq());
long norm = 1L;
if (norms != null && norms.advanceExact(iterator.docID())) {
norm = norms.longValue();
}
return scorer.score(freq(), norm);
}

@Override
Expand Down Expand Up @@ -647,17 +659,20 @@ float freq() throws IOException {
private static class FreqBoostTermScorer extends FilterScorer {
final float boost;
final TermScorer in;
final LeafSimScorer docScorer;
final SimScorer scorer;
final NumericDocValues norms;

public FreqBoostTermScorer(float boost, TermScorer in, LeafSimScorer docScorer) {
public FreqBoostTermScorer(
float boost, TermScorer in, SimScorer scorer, NumericDocValues norms) {
super(in);
if (Float.isNaN(boost) || Float.compare(boost, 0f) < 0 || Float.compare(boost, 1f) > 0) {
throw new IllegalArgumentException(
"boost must be a positive float between 0 (exclusive) and 1 (inclusive)");
}
this.boost = boost;
this.in = in;
this.docScorer = docScorer;
this.scorer = scorer;
this.norms = norms;
}

float freq() throws IOException {
Expand All @@ -666,8 +681,11 @@ float freq() throws IOException {

@Override
public float score() throws IOException {
assert docID() != DocIdSetIterator.NO_MORE_DOCS;
return docScorer.score(in.docID(), freq());
long norm = 1L;
if (norms != null && norms.advanceExact(in.docID())) {
norm = norms.longValue();
}
return scorer.score(freq(), norm);
}

@Override
Expand Down
Loading

0 comments on commit 054c60e

Please sign in to comment.