Skip to content

Commit

Permalink
Add timeout and cancellation check to rescore phase (#115048) (#115131)
Browse files Browse the repository at this point in the history
This adds cancellation checks to the rescore phase. These checks cover both
cancellation of the parent task and search timeouts.

The assumption is that rescore is always significantly more expensive
than a regular query, so we check for timeout as frequently as the most
frequent check in ExitableDirectoryReader.

For LTR, we check during hit inference. Maybe we should also check during
per-feature extraction?

For QueryRescorer, we check in the combine method.

closes: #114955
  • Loading branch information
benwtrent authored Oct 18, 2024
1 parent bf70b46 commit a2a84a1
Show file tree
Hide file tree
Showing 8 changed files with 320 additions and 1 deletion.
5 changes: 5 additions & 0 deletions docs/changelog/115048.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 115048
summary: Add timeout and cancellation check to rescore phase
area: Ranking
type: enhancement
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,45 @@

package org.elasticsearch.search.functionscore;

import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.tests.util.English;
import org.elasticsearch.TransportVersion;
import org.elasticsearch.action.index.IndexRequestBuilder;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.action.search.SearchType;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.lucene.search.function.CombineFunction;
import org.elasticsearch.common.lucene.search.function.LeafScoreFunction;
import org.elasticsearch.common.lucene.search.function.ScoreFunction;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.settings.Settings.Builder;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.query.Operator;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.SearchExecutionContext;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.SearchHit;
import org.elasticsearch.search.SearchHits;
import org.elasticsearch.search.collapse.CollapseBuilder;
import org.elasticsearch.search.rescore.QueryRescoreMode;
import org.elasticsearch.search.rescore.QueryRescorerBuilder;
import org.elasticsearch.search.sort.SortBuilders;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Comparator;
import java.util.List;

Expand Down Expand Up @@ -979,9 +993,119 @@ public void testRescoreAfterCollapseRandom() throws Exception {
});
}

/**
 * Runs a rescored search with a 10ms timeout against a rescorer whose score function
 * sleeps on every call, and asserts that the response is flagged as timed out.
 */
public void testRescoreWithTimeout() throws Exception {
    // no dummy docs since merges can change scores while we run queries.
    final int docCount = indexRandomNumbers("whitespace", -1, false);

    // Query term: the first word of the English spelling of a random number below the doc count.
    final int pickedNumber = between(0, docCount - 1);
    final String firstWord = English.intToEnglish(pickedNumber).split(" ")[0];

    // Each score() call of the timed function sleeps far longer than the request timeout below.
    final QueryRescorerBuilder slowRescorer = new QueryRescorerBuilder(functionScoreQuery(new TestTimedScoreFunctionBuilder()))
        .windowSize(100);
    final TimeValue requestTimeout = TimeValue.timeValueMillis(10);

    assertResponse(
        prepareSearch().setSearchType(SearchType.QUERY_THEN_FETCH)
            .setQuery(QueryBuilders.matchQuery("field1", firstWord).operator(Operator.OR))
            .setSize(10)
            .addRescorer(slowRescorer)
            .setTimeout(requestTimeout),
        response -> assertTrue(response.isTimedOut())
    );
}

/**
 * Installs the plugin that registers the "timed" score function on the test nodes.
 */
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
    final List<Class<? extends Plugin>> plugins = List.of(TestTimedQueryPlugin.class);
    return plugins;
}

/**
 * Builds a function score query over documents whose {@code shouldFilter} term is {@code false},
 * scoring them by the field value factor of {@code scoreField} with the boost mode set to REPLACE.
 *
 * @param scoreField name of the field passed to the field value factor function
 * @return the configured function score query
 */
private QueryBuilder fieldValueScoreQuery(String scoreField) {
    final QueryBuilder notFiltered = termQuery("shouldFilter", false);
    return functionScoreQuery(notFiltered, ScoreFunctionBuilders.fieldValueFactorFunction(scoreField))
        .boostMode(CombineFunction.REPLACE);
}

/**
 * Test plugin that registers {@link TestTimedScoreFunctionBuilder} under the name {@code "timed"}.
 */
public static class TestTimedQueryPlugin extends Plugin implements SearchPlugin {
    @Override
    public List<ScoreFunctionSpec<?>> getScoreFunctions() {
        // Stream reader is the StreamInput constructor; the parser ignores its input and returns a fresh builder.
        final ScoreFunctionSpec<TestTimedScoreFunctionBuilder> timedSpec = new ScoreFunctionSpec<>(
            new ParseField("timed"),
            TestTimedScoreFunctionBuilder::new,
            parser -> new TestTimedScoreFunctionBuilder()
        );
        return List.of(timedSpec);
    }
}

/**
 * Score function builder for tests: the function it produces sleeps for a fixed number of
 * milliseconds on every {@code score} call and returns that number as the score.
 * Serialization and XContent output write nothing, and instances never compare equal via
 * {@code doEquals}.
 */
static class TestTimedScoreFunctionBuilder extends ScoreFunctionBuilder<TestTimedScoreFunctionBuilder> {
    // Milliseconds slept per score() call; the same value is returned as the score.
    private final long sleepMillis = 500;

    TestTimedScoreFunctionBuilder() {}

    TestTimedScoreFunctionBuilder(StreamInput in) throws IOException {
        super(in);
    }

    @Override
    public String getName() {
        return "timed";
    }

    @Override
    public TransportVersion getMinimalSupportedVersion() {
        return TransportVersion.current();
    }

    // Nothing to serialize.
    @Override
    protected void doWriteTo(StreamOutput out) {}

    // Nothing to render.
    @Override
    protected void doXContent(XContentBuilder builder, Params params) {}

    @Override
    protected boolean doEquals(TestTimedScoreFunctionBuilder functionBuilder) {
        return false;
    }

    @Override
    protected int doHashCode() {
        return 0;
    }

    @Override
    protected ScoreFunction doToFunction(SearchExecutionContext context) throws IOException {
        return new ScoreFunction(REPLACE) {
            @Override
            public boolean needsScores() {
                return true;
            }

            @Override
            public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
                return new LeafScoreFunction() {
                    @Override
                    public double score(int docId, float subQueryScore) {
                        // Simulate an expensive scorer; on interruption, restore the flag and still return a score.
                        try {
                            Thread.sleep(sleepMillis);
                        } catch (InterruptedException interrupted) {
                            Thread.currentThread().interrupt();
                        }
                        return sleepMillis;
                    }

                    // No explanation is produced.
                    @Override
                    public Explanation explainScore(int docId, Explanation subQueryScore) {
                        return null;
                    }
                };
            }

            @Override
            protected boolean doEquals(ScoreFunction other) {
                return false;
            }

            @Override
            protected int doHashCode() {
                return 0;
            }
        };
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ public void throwTimeExceededException() {
}
}

private static class TimeExceededException extends RuntimeException {
/**
 * Unchecked exception that, going by its name, signals an exceeded search time budget.
 * It is declared public so that code outside this class can name it in a catch clause,
 * e.g. {@code catch (ContextIndexSearcher.TimeExceededException e)}.
 */
public static class TimeExceededException extends RuntimeException {
// This exception should never be re-thrown, but we fill in the stacktrace to be able to trace where it does not get properly caught
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

public final class QueryRescorer implements Rescorer {

private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
public static final Rescorer INSTANCE = new QueryRescorer();

@Override
Expand All @@ -39,9 +40,14 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r
final QueryRescoreContext rescore = (QueryRescoreContext) rescoreContext;

org.apache.lucene.search.Rescorer rescorer = new org.apache.lucene.search.QueryRescorer(rescore.parsedQuery().query()) {
int count = 0;

@Override
protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
if (count % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
rescore.checkCancellation();
}
count++;
if (secondPassMatches) {
return rescore.scoreMode.combine(
firstPassScore * rescore.queryWeight(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ public class RescoreContext {
private final int windowSize;
private final Rescorer rescorer;
private Set<Integer> rescoredDocs; // doc Ids for which rescoring was applied
private Runnable isCancelled;

/**
* Build the context.
Expand All @@ -34,6 +35,16 @@ public RescoreContext(int windowSize, Rescorer rescorer) {
this.rescorer = rescorer;
}

/**
 * Installs the runnable that {@link #checkCancellation()} will invoke.
 *
 * @param isCancelled the check to run; may be {@code null}, in which case
 *                    {@link #checkCancellation()} does nothing
 */
public void setCancellationChecker(Runnable isCancelled) {
this.isCancelled = isCancelled;
}

/**
 * Runs the currently installed cancellation checker, if any. When no checker has been
 * set this is a no-op; anything the checker throws propagates to the caller.
 */
public void checkCancellation() {
    if (isCancelled == null) {
        return;
    }
    isCancelled.run();
}

/**
* The rescorer to actually apply.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,18 @@
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.search.SearchShardTask;
import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
import org.elasticsearch.common.util.Maps;
import org.elasticsearch.lucene.grouping.TopFieldGroups;
import org.elasticsearch.search.internal.ContextIndexSearcher;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.query.QueryPhase;
import org.elasticsearch.search.query.SearchTimeoutException;

import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

/**
Expand All @@ -44,11 +50,14 @@ public static void execute(SearchContext context) {
topGroups = topFieldGroups;
}
try {
Runnable cancellationCheck = getCancellationChecks(context);
for (RescoreContext ctx : context.rescore()) {
ctx.setCancellationChecker(cancellationCheck);
topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
// It is the responsibility of the rescorer to sort the resulted top docs,
// here we only assert that this condition is met.
assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
ctx.setCancellationChecker(null);
}
if (topGroups != null) {
assert context.collapse() != null;
Expand All @@ -63,6 +72,11 @@ public static void execute(SearchContext context) {
.topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
} catch (IOException e) {
throw new ElasticsearchException("Rescore Phase Failed", e);
} catch (ContextIndexSearcher.TimeExceededException e) {
if (context.request().allowPartialSearchResults() == false) {
throw new SearchTimeoutException(context.shardTarget(), "Time exceeded");
}
context.queryResult().searchTimedOut(true);
}
}

Expand Down Expand Up @@ -106,4 +120,27 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) {
}
return true;
}

/**
 * Builds a single runnable that executes every check applicable to the given search context.
 * <p>
 * A task-cancellation check is included when {@code context.lowLevelCancellation()} is true,
 * and the check returned by {@link QueryPhase#getTimeoutCheck} is included when it is non-null.
 * If neither applies, the returned runnable does nothing.
 *
 * @param context the search context the checks are derived from
 * @return a runnable that runs the collected checks in the order they were added
 */
static Runnable getCancellationChecks(SearchContext context) {
    final List<Runnable> checks = new ArrayList<>();

    if (context.lowLevelCancellation()) {
        checks.add(() -> {
            // The task is looked up on every invocation and skipped when absent.
            final SearchShardTask shardTask = context.getTask();
            if (shardTask == null) {
                return;
            }
            shardTask.ensureNotCancelled();
        });
    }

    final Runnable timeoutCheck = QueryPhase.getTimeoutCheck(context);
    if (timeoutCheck != null) {
        checks.add(timeoutCheck);
    }

    return () -> checks.forEach(Runnable::run);
}
}
Loading

0 comments on commit a2a84a1

Please sign in to comment.