diff --git a/docs/changelog/115048.yaml b/docs/changelog/115048.yaml
new file mode 100644
index 0000000000000..10844b83c6d01
--- /dev/null
+++ b/docs/changelog/115048.yaml
@@ -0,0 +1,5 @@
+pr: 115048
+summary: Add timeout and cancellation check to rescore phase
+area: Ranking
+type: enhancement
+issues: []
diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java
index 025d224923dc0..6043688b7670a 100644
--- a/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java
+++ b/server/src/internalClusterTest/java/org/elasticsearch/search/functionscore/QueryRescorerIT.java
@@ -9,19 +9,30 @@
 
 package org.elasticsearch.search.functionscore;
 
+import org.apache.lucene.index.LeafReaderContext;
 import org.apache.lucene.search.Explanation;
 import org.apache.lucene.tests.util.English;
+import org.elasticsearch.TransportVersion;
 import org.elasticsearch.action.index.IndexRequestBuilder;
 import org.elasticsearch.action.search.SearchRequestBuilder;
 import org.elasticsearch.action.search.SearchResponse;
 import org.elasticsearch.action.search.SearchType;
+import org.elasticsearch.common.io.stream.StreamInput;
+import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.lucene.search.function.CombineFunction;
+import org.elasticsearch.common.lucene.search.function.LeafScoreFunction;
+import org.elasticsearch.common.lucene.search.function.ScoreFunction;
 import org.elasticsearch.common.settings.Settings;
 import org.elasticsearch.common.settings.Settings.Builder;
+import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.index.query.Operator;
 import org.elasticsearch.index.query.QueryBuilder;
 import org.elasticsearch.index.query.QueryBuilders;
+import org.elasticsearch.index.query.SearchExecutionContext;
+import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilder;
 import org.elasticsearch.index.query.functionscore.ScoreFunctionBuilders;
+import org.elasticsearch.plugins.Plugin;
+import org.elasticsearch.plugins.SearchPlugin;
 import org.elasticsearch.search.SearchHit;
 import org.elasticsearch.search.SearchHits;
 import org.elasticsearch.search.collapse.CollapseBuilder;
@@ -29,11 +40,14 @@
 import org.elasticsearch.search.rescore.QueryRescorerBuilder;
 import org.elasticsearch.search.sort.SortBuilders;
 import org.elasticsearch.test.ESIntegTestCase;
+import org.elasticsearch.xcontent.ParseField;
 import org.elasticsearch.xcontent.XContentBuilder;
 import org.elasticsearch.xcontent.XContentFactory;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Collection;
 import java.util.Comparator;
 import java.util.List;
 
@@ -979,9 +993,122 @@ public void testRescoreAfterCollapseRandom() throws Exception {
         });
     }
 
+    /** Checks that a search whose rescore query sleeps on every scored document is reported as timed out. */
+    public void testRescoreWithTimeout() throws Exception {
+        // no dummy docs since merges can change scores while we run queries.
+        int numDocs = indexRandomNumbers("whitespace", -1, false);
+
+        String intToEnglish = English.intToEnglish(between(0, numDocs - 1));
+        String query = intToEnglish.split(" ")[0];
+        assertResponse(
+            prepareSearch().setSearchType(SearchType.QUERY_THEN_FETCH)
+                .setQuery(QueryBuilders.matchQuery("field1", query).operator(Operator.OR))
+                .setSize(10)
+                .addRescorer(new QueryRescorerBuilder(functionScoreQuery(new TestTimedScoreFunctionBuilder())).windowSize(100))
+                .setTimeout(TimeValue.timeValueMillis(10)),
+            r -> assertTrue(r.isTimedOut())
+        );
+    }
+
+    @Override
+    protected Collection<Class<? extends Plugin>> nodePlugins() {
+        return List.of(TestTimedQueryPlugin.class);
+    }
+
     private QueryBuilder fieldValueScoreQuery(String scoreField) {
         return functionScoreQuery(termQuery("shouldFilter", false), ScoreFunctionBuilders.fieldValueFactorFunction(scoreField)).boostMode(
             CombineFunction.REPLACE
         );
     }
+
+    /** Plugin registering the artificially slow {@code timed} score function used by the timeout test. */
+    public static class TestTimedQueryPlugin extends Plugin implements SearchPlugin {
+        @Override
+        public List<ScoreFunctionSpec<?>> getScoreFunctions() {
+            return List.of(
+                new ScoreFunctionSpec<>(
+                    new ParseField("timed"),
+                    TestTimedScoreFunctionBuilder::new,
+                    p -> new TestTimedScoreFunctionBuilder()
+                )
+            );
+        }
+    }
+
+    /** Score function builder whose function sleeps for {@code time} milliseconds on every {@code score} call. */
+    static class TestTimedScoreFunctionBuilder extends ScoreFunctionBuilder<TestTimedScoreFunctionBuilder> {
+        private final long time = 500;
+
+        TestTimedScoreFunctionBuilder() {}
+
+        TestTimedScoreFunctionBuilder(StreamInput in) throws IOException {
+            super(in);
+        }
+
+        @Override
+        protected void doWriteTo(StreamOutput out) {}
+
+        @Override
+        public String getName() {
+            return "timed";
+        }
+
+        @Override
+        protected void doXContent(XContentBuilder builder, Params params) {}
+
+        @Override
+        protected boolean doEquals(TestTimedScoreFunctionBuilder functionBuilder) {
+            return false;
+        }
+
+        @Override
+        protected int doHashCode() {
+            return 0;
+        }
+
+        @Override
+        protected ScoreFunction doToFunction(SearchExecutionContext context) throws IOException {
+            return new ScoreFunction(REPLACE) {
+                @Override
+                public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOException {
+                    return new LeafScoreFunction() {
+                        @Override
+                        public double score(int docId, float subQueryScore) {
+                            try {
+                                Thread.sleep(time);
+                            } catch (InterruptedException e) {
+                                Thread.currentThread().interrupt();
+                            }
+                            return time;
+                        }
+
+                        @Override
+                        public Explanation explainScore(int docId, Explanation subQueryScore) {
+                            return null;
+                        }
+                    };
+                }
+
+                @Override
+                public boolean needsScores() {
+                    return true;
+                }
+
+                @Override
+                protected boolean doEquals(ScoreFunction other) {
+                    return false;
+                }
+
+                @Override
+                protected int doHashCode() {
+                    return 0;
+                }
+            };
+        }
+
+        @Override
+        public TransportVersion getMinimalSupportedVersion() {
+            return TransportVersion.current();
+        }
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
index 18de4b81cbf8c..da5d2d093fbd8 100644
--- a/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
+++ b/server/src/main/java/org/elasticsearch/search/internal/ContextIndexSearcher.java
@@ -407,7 +407,8 @@ public void throwTimeExceededException() {
         }
     }
 
-    private static class TimeExceededException extends RuntimeException {
+    /** Signals that the search timeout elapsed; public so that code outside this class can catch it. */
+    public static class TimeExceededException extends RuntimeException {
         // This exception should never be re-thrown, but we fill in the stacktrace to be able to trace where it does not get properly caught
     }
 
diff --git a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java
index 5cd947a1cc73b..cb9169dbeb5e5 100644
--- a/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java
+++ b/server/src/main/java/org/elasticsearch/search/rescore/QueryRescorer.java
@@ -26,6 +26,8 @@
 
 public final class QueryRescorer implements Rescorer {
 
+    /** Number of {@code combine} calls between two consecutive cancellation checks. */
+    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
     public static final Rescorer INSTANCE = new QueryRescorer();
 
     @Override
@@ -39,9 +41,14 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r
         final QueryRescoreContext rescore = (QueryRescoreContext) rescoreContext;
 
         org.apache.lucene.search.Rescorer rescorer = new org.apache.lucene.search.QueryRescorer(rescore.parsedQuery().query()) {
+            int count = 0;
 
             @Override
             protected float combine(float firstPassScore, boolean secondPassMatches, float secondPassScore) {
+                if (count % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                    rescore.checkCancellation();
+                }
+                count++;
                 if (secondPassMatches) {
                     return rescore.scoreMode.combine(
                         firstPassScore * rescore.queryWeight(),
diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java b/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java
index 297b197a6d0c1..0ae6c326ddcdc 100644
--- a/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java
+++ b/server/src/main/java/org/elasticsearch/search/rescore/RescoreContext.java
@@ -24,6 +24,7 @@ public class RescoreContext {
     private final int windowSize;
     private final Rescorer rescorer;
     private Set<Integer> rescoredDocs; // doc Ids for which rescoring was applied
+    private Runnable isCancelled;
 
     /**
      * Build the context.
@@ -34,6 +35,18 @@ public RescoreContext(int windowSize, Rescorer rescorer) {
         this.rescorer = rescorer;
     }
 
+    /** Installs the check that {@link #checkCancellation()} runs; passing {@code null} removes it. */
+    public void setCancellationChecker(Runnable isCancelled) {
+        this.isCancelled = isCancelled;
+    }
+
+    /** Runs the installed cancellation check, or does nothing when none is set. */
+    public void checkCancellation() {
+        if (isCancelled != null) {
+            isCancelled.run();
+        }
+    }
+
     /**
      * The rescorer to actually apply.
      */
diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java
index be961b8ef942b..1227db5d8e1db 100644
--- a/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java
+++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorePhase.java
@@ -14,12 +14,18 @@
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocs;
 import org.elasticsearch.ElasticsearchException;
+import org.elasticsearch.action.search.SearchShardTask;
 import org.elasticsearch.common.lucene.search.TopDocsAndMaxScore;
 import org.elasticsearch.common.util.Maps;
 import org.elasticsearch.lucene.grouping.TopFieldGroups;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
 import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.search.query.QueryPhase;
+import org.elasticsearch.search.query.SearchTimeoutException;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 
 /**
@@ -44,11 +50,14 @@ public static void execute(SearchContext context) {
             topGroups = topFieldGroups;
         }
         try {
+            Runnable cancellationCheck = getCancellationChecks(context);
             for (RescoreContext ctx : context.rescore()) {
+                ctx.setCancellationChecker(cancellationCheck);
                 topDocs = ctx.rescorer().rescore(topDocs, context.searcher(), ctx);
                 // It is the responsibility of the rescorer to sort the resulted top docs,
                 // here we only assert that this condition is met.
                 assert context.sort() == null && topDocsSortedByScore(topDocs) : "topdocs should be sorted after rescore";
+                ctx.setCancellationChecker(null);
             }
             if (topGroups != null) {
                 assert context.collapse() != null;
@@ -63,6 +72,12 @@ public static void execute(SearchContext context) {
                 .topDocs(new TopDocsAndMaxScore(topDocs, topDocs.scoreDocs[0].score), context.queryResult().sortValueFormats());
         } catch (IOException e) {
             throw new ElasticsearchException("Rescore Phase Failed", e);
+        } catch (ContextIndexSearcher.TimeExceededException e) {
+            // Timed out mid-rescore: no rescored top docs are written; the shard either fails or is flagged as timed out.
+            if (context.request().allowPartialSearchResults() == false) {
+                throw new SearchTimeoutException(context.shardTarget(), "Time exceeded");
+            }
+            context.queryResult().searchTimedOut(true);
         }
     }
 
@@ -106,4 +121,31 @@ private static boolean topDocsSortedByScore(TopDocs topDocs) {
         }
         return true;
     }
+
+    /**
+     * Builds the check handed to each {@link RescoreContext}: the task cancellation check (only when low-level
+     * cancellation is enabled) followed by the timeout check from {@link QueryPhase#getTimeoutCheck}, when there is one.
+     */
+    static Runnable getCancellationChecks(SearchContext context) {
+        List<Runnable> cancellationChecks = new ArrayList<>();
+        if (context.lowLevelCancellation()) {
+            cancellationChecks.add(() -> {
+                final SearchShardTask task = context.getTask();
+                if (task != null) {
+                    task.ensureNotCancelled();
+                }
+            });
+        }
+
+        final Runnable timeoutRunnable = QueryPhase.getTimeoutCheck(context);
+        if (timeoutRunnable != null) {
+            cancellationChecks.add(timeoutRunnable);
+        }
+
+        return () -> {
+            for (var check : cancellationChecks) {
+                check.run();
+            }
+        };
+    }
 }
diff --git a/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java b/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java
new file mode 100644
index 0000000000000..5a1c4b789b460
--- /dev/null
+++ b/server/src/test/java/org/elasticsearch/search/rescore/RescorePhaseTests.java
@@ -0,0 +1,128 @@
+/*
+ * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+ * or more contributor license agreements. Licensed under the "Elastic License
+ * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
+ * Public License v 1"; you may not use this file except in compliance with, at
+ * your election, the "Elastic License 2.0", the "GNU Affero General Public
+ * License v3.0 only", or the "Server Side Public License, v 1".
+ */
+
+package org.elasticsearch.search.rescore;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.DirectoryReader;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.IndexWriterConfig;
+import org.apache.lucene.index.NoMergePolicy;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.MatchAllDocsQuery;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryCachingPolicy;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.tests.index.RandomIndexWriter;
+import org.elasticsearch.action.search.SearchShardTask;
+import org.elasticsearch.index.query.ParsedQuery;
+import org.elasticsearch.index.shard.IndexShard;
+import org.elasticsearch.index.shard.IndexShardTestCase;
+import org.elasticsearch.search.fetch.subphase.FetchDocValuesContext;
+import org.elasticsearch.search.fetch.subphase.FetchFieldsContext;
+import org.elasticsearch.search.internal.ContextIndexSearcher;
+import org.elasticsearch.search.internal.FilteredSearchContext;
+import org.elasticsearch.search.internal.SearchContext;
+import org.elasticsearch.tasks.TaskCancelHelper;
+import org.elasticsearch.tasks.TaskCancelledException;
+import org.elasticsearch.test.TestSearchContext;
+
+import java.io.IOException;
+import java.util.Collections;
+
+public class RescorePhaseTests extends IndexShardTestCase {
+
+    /** Checks that the rescore cancellation check, and a rescorer using it, throw once the search task is cancelled. */
+    public void testRescorePhaseCancellation() throws IOException {
+        IndexWriterConfig iwc = newIndexWriterConfig().setMergePolicy(NoMergePolicy.INSTANCE);
+        try (Directory dir = newDirectory()) {
+            try (RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc)) {
+                final int numDocs = scaledRandomIntBetween(100, 200);
+                for (int i = 0; i < numDocs; ++i) {
+                    Document doc = new Document();
+                    w.addDocument(doc);
+                }
+            }
+            try (IndexReader reader = DirectoryReader.open(dir)) {
+                ContextIndexSearcher s = new ContextIndexSearcher(
+                    reader,
+                    IndexSearcher.getDefaultSimilarity(),
+                    IndexSearcher.getDefaultQueryCache(),
+                    new QueryCachingPolicy() {
+                        @Override
+                        public void onUse(Query query) {}
+
+                        @Override
+                        public boolean shouldCache(Query query) {
+                            return false;
+                        }
+                    },
+                    true
+                );
+                IndexShard shard = newShard(true);
+                try (TestSearchContext context = new TestSearchContext(null, shard, s)) {
+                    context.parsedQuery(new ParsedQuery(new MatchAllDocsQuery()));
+                    SearchShardTask task = new SearchShardTask(123L, "", "", "", null, Collections.emptyMap());
+                    context.setTask(task);
+                    SearchContext wrapped = new FilteredSearchContext(context) {
+                        @Override
+                        public boolean lowLevelCancellation() {
+                            return true;
+                        }
+
+                        @Override
+                        public FetchDocValuesContext docValuesContext() {
+                            return context.docValuesContext();
+                        }
+
+                        @Override
+                        public SearchContext docValuesContext(FetchDocValuesContext docValuesContext) {
+                            return context.docValuesContext(docValuesContext);
+                        }
+
+                        @Override
+                        public FetchFieldsContext fetchFieldsContext() {
+                            return context.fetchFieldsContext();
+                        }
+
+                        @Override
+                        public SearchContext fetchFieldsContext(FetchFieldsContext fetchFieldsContext) {
+                            return context.fetchFieldsContext(fetchFieldsContext);
+                        }
+                    };
+                    try (wrapped) {
+                        Runnable cancellationChecks = RescorePhase.getCancellationChecks(wrapped);
+                        assertNotNull(cancellationChecks);
+                        TaskCancelHelper.cancel(task, "test cancellation");
+                        assertTrue(wrapped.isCancelled());
+                        expectThrows(TaskCancelledException.class, cancellationChecks::run);
+                        QueryRescorer.QueryRescoreContext rescoreContext = new QueryRescorer.QueryRescoreContext(10);
+                        rescoreContext.setQuery(new ParsedQuery(new MatchAllDocsQuery()));
+                        rescoreContext.setCancellationChecker(cancellationChecks);
+                        expectThrows(
+                            TaskCancelledException.class,
+                            () -> new QueryRescorer().rescore(
+                                new TopDocs(
+                                    new TotalHits(10, TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO),
+                                    new ScoreDoc[] { new ScoreDoc(0, 1.0f) }
+                                ),
+                                context.searcher(),
+                                rescoreContext
+                            )
+                        );
+                    }
+                }
+                closeShards(shard);
+            }
+        }
+    }
+}
diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java
index 70d0b980bb3bf..54a9fe908fa87 100644
--- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java
+++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorer.java
@@ -35,6 +35,8 @@
 
 public class LearningToRankRescorer implements Rescorer {
 
+    /** Number of loop iterations between two consecutive cancellation checks. */
+    private static final int MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK = 10;
     public static final LearningToRankRescorer INSTANCE = new LearningToRankRescorer();
     private static final Logger logger = LogManager.getLogger(LearningToRankRescorer.class);
 
@@ -78,7 +80,12 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r
         List<FeatureExtractor> featureExtractors = ltrRescoreContext.buildFeatureExtractors(searcher);
         List<Map<String, Object>> docFeatures = new ArrayList<>(topDocIDs.size());
         int featureSize = featureExtractors.stream().mapToInt(fe -> fe.featureNames().size()).sum();
+        int count = 0;
         while (hitUpto < hitsToRescore.length) {
+            if (count % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                rescoreContext.checkCancellation();
+            }
+            count++;
             final ScoreDoc hit = hitsToRescore[hitUpto];
             final int docID = hit.doc;
             while (docID >= endDoc) {
@@ -106,6 +113,9 @@ public TopDocs rescore(TopDocs topDocs, IndexSearcher searcher, RescoreContext r
             hitUpto++;
         }
         for (int i = 0; i < hitsToRescore.length; i++) {
+            if (i % MAX_CALLS_BEFORE_QUERY_TIMEOUT_CHECK == 0) {
+                rescoreContext.checkCancellation();
+            }
             Map<String, Object> features = docFeatures.get(i);
             try {
                 InferenceResults results = definition.inferLtr(features, ltrRescoreContext.learningToRankConfig);