diff --git a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java index 1597b31e89871..0876bf93a557b 100644 --- a/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/opensearch/action/search/AbstractSearchAsyncAction.java @@ -302,7 +302,16 @@ public void onFailure(Exception t) { * It is possible to run into connection exceptions here because we are getting the connection early and might * run into nodes that are not connected. In this case, on shard failure will move us to the next shard copy. */ - fork(() -> onShardFailure(shardIndex, shard, shardIt, e)); + fork(() -> { + // It only happens when onPhaseDone() is called and executePhaseOnShard() fails hard with an exception. + // In this case calling onShardFailure() would overflow the operations counter, so the best we could do + // here is to fail the phase and move on to the next one. + if (totalOps.get() == expectedTotalOps) { + onPhaseFailure(this, "The phase has failed", e); + } else { + onShardFailure(shardIndex, shard, shardIt, e); + } + }); } finally { executeNext(pendingExecutions, thread); } diff --git a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java index b44b59b8a4ad5..ad2657517df9a 100644 --- a/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java +++ b/server/src/test/java/org/opensearch/action/search/AbstractSearchAsyncActionTests.java @@ -477,6 +477,57 @@ public void onFailure(Exception e) { assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length)); } + public void testExecutePhaseOnShardFailure() throws InterruptedException { + final Index index = new Index("test", UUID.randomUUID().toString()); + + final SearchShardIterator[] shards = IntStream.range(0, 2 + randomInt(3)) + .mapToObj(i -> new SearchShardIterator(null, new ShardId(index, i), List.of("n1", "n2", "n3"), null, null, null)) + .toArray(SearchShardIterator[]::new); + + final AtomicBoolean fail = new AtomicBoolean(true); + final CountDownLatch latch = new CountDownLatch(1); + SearchRequest searchRequest = new SearchRequest().allowPartialSearchResults(true); + searchRequest.setMaxConcurrentShardRequests(5); + + final ArraySearchPhaseResults queryResult = new ArraySearchPhaseResults<>(shards.length); + AbstractSearchAsyncAction action = createAction( + searchRequest, + queryResult, + new ActionListener() { + @Override + public void onResponse(SearchResponse response) {} + + @Override + public void onFailure(Exception e) { + try { + // We end up here only when onPhaseDone() is called (causing NPE) and + // ending up in the onPhaseFailure() callback + if (fail.compareAndExchange(true, false)) { + assertThat(e, instanceOf(SearchPhaseExecutionException.class)); + throw new RuntimeException("Simulated exception"); + } + } finally { + executor.submit(() -> latch.countDown()); + } + } + }, + false, + false, + new AtomicLong(), + shards + ); + action.run(); + assertTrue(latch.await(1, TimeUnit.SECONDS)); + + InternalSearchResponse internalSearchResponse = InternalSearchResponse.empty(); + SearchResponse searchResponse = action.buildSearchResponse(internalSearchResponse, action.buildShardFailures(), null, null); + assertSame(searchResponse.getAggregations(), internalSearchResponse.aggregations()); + assertSame(searchResponse.getSuggest(), internalSearchResponse.suggest()); + assertSame(searchResponse.getProfileResults(), internalSearchResponse.profile()); + assertSame(searchResponse.getHits(), internalSearchResponse.hits()); + assertThat(searchResponse.getSuccessfulShards(), equalTo(shards.length)); + } + private static final class PhaseResult extends SearchPhaseResult { PhaseResult(ShardSearchContextId contextId) { this.contextId = contextId;