diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 4515e3a5c2821..892de9916124f 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -25,6 +25,7 @@ import java.util.concurrent.TimeUnit; import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; import org.slf4j.Logger; @@ -87,7 +88,16 @@ void createAndStart(String[] blockIds, BlockTransferListener listener) /** Number of times we've attempted to retry so far. */ private int retryCount = 0; - private boolean saslTimeoutSeen; + // Number of times SASL timeout has been retried without success. + // If we see maxRetries consecutive failures, the request is failed. + // On the other hand, if sasl succeeds and we are able to send other requests subsequently, + // we reduce the SASL failures from retryCount (since SASL failures were part of + // connection bootstrap - which ended up being successful). + // spark.network.auth.rpcTimeout is much lower than spark.network.timeout and others - + // and so sasl is more susceptible to failures when remote service + // (like external shuffle service) is under load: but once it succeeds, we do not want to + // include it as part of request retries. + private int saslRetryCount = 0; /** * Set of all block ids which have not been transferred successfully or with a non-IO Exception. @@ -123,7 +133,7 @@ public RetryingBlockTransferor( this.currentListener = new RetryingBlockTransferListener(); this.errorHandler = errorHandler; this.enableSaslRetries = conf.enableSaslRetries(); - this.saslTimeoutSeen = false; + this.saslRetryCount = 0; } public RetryingBlockTransferor( @@ -167,7 +177,7 @@ private void transferAllOutstanding() { numRetries > 0 ? "(after " + numRetries + " retries)" : ""), e); if (shouldRetry(e)) { - initiateRetry(); + initiateRetry(e); } else { for (String bid : blockIdsToTransfer) { listener.onBlockTransferFailure(bid, e); @@ -180,7 +190,10 @@ private void transferAllOutstanding() { * Lightweight method which initiates a retry in a different thread. The retry will involve * calling transferAllOutstanding() after a configured wait time. */ - private synchronized void initiateRetry() { + private synchronized void initiateRetry(Throwable e) { + if (enableSaslRetries && e instanceof SaslTimeoutException) { + saslRetryCount += 1; + } retryCount += 1; currentListener = new RetryingBlockTransferListener(); @@ -203,16 +216,17 @@ private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException; - if (!isSaslTimeout && saslTimeoutSeen) { - retryCount = 0; - saslTimeoutSeen = false; + // If this is a non SASL request failure, reduce earlier SASL failures from retryCount + // since some subsequent SASL attempt was successful + if (!isSaslTimeout && saslRetryCount > 0) { + Preconditions.checkState(retryCount >= saslRetryCount, + "retryCount must be greater than or equal to saslRetryCount"); + retryCount -= saslRetryCount; + saslRetryCount = 0; } boolean hasRemainingRetries = retryCount < maxRetries; boolean shouldRetry = (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); - if (shouldRetry && isSaslTimeout) { - this.saslTimeoutSeen = true; - } return shouldRetry; } @@ -236,9 +250,13 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { outstandingBlocksIds.remove(blockId); shouldForwardSuccess = true; - if (saslTimeoutSeen) { - retryCount = 0; - saslTimeoutSeen = false; + // If there were SASL failures earlier, remove them from retryCount, as there was + // a SASL success (and some other request post bootstrap was also successful). + if (saslRetryCount > 0) { + Preconditions.checkState(retryCount >= saslRetryCount, + "retryCount must be greater than or equal to saslRetryCount"); + retryCount -= saslRetryCount; + saslRetryCount = 0; } } } @@ -256,7 +274,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) { synchronized (RetryingBlockTransferor.this) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { if (shouldRetry(exception)) { - initiateRetry(); + initiateRetry(exception); } else { if (errorHandler.shouldLogError(exception)) { logger.error( diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 117a9ba08dfa5..31fe61841669c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -58,10 +58,12 @@ public class RetryingBlockTransferorSuite { private static Map configMap; private static RetryingBlockTransferor _retryingBlockTransferor; + private static final int MAX_RETRIES = 2; + @Before public void initMap() { configMap = new HashMap() {{ - put("spark.shuffle.io.maxRetries", "2"); + put("spark.shuffle.io.maxRetries", Integer.toString(MAX_RETRIES)); put("spark.shuffle.io.retryWait", "0"); }}; } @@ -309,7 +311,7 @@ public void testRepeatedSaslRetryFailures() throws IOException, InterruptedExcep verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslTimeoutException); verify(listener, times(3)).getTransferType(); verifyNoMoreInteractions(listener); - assert(_retryingBlockTransferor.getRetryCount() == 2); + assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES); } @Test @@ -341,6 +343,35 @@ public void testBlockTransferFailureAfterSasl() throws IOException, InterruptedE assert(_retryingBlockTransferor.getRetryCount() == 1); } + @Test + public void testIOExceptionFailsConnectionEvenWithSaslException() + throws IOException, InterruptedException { + BlockFetchingListener listener = mock(BlockFetchingListener.class); + + SaslTimeoutException saslExceptionInitial = new SaslTimeoutException("initial", + new TimeoutException()); + SaslTimeoutException saslExceptionFinal = new SaslTimeoutException("final", + new TimeoutException()); + IOException ioException = new IOException(); + List> interactions = Arrays.asList( + ImmutableMap.of("b0", saslExceptionInitial), + ImmutableMap.of("b0", ioException), + ImmutableMap.of("b0", saslExceptionInitial), + ImmutableMap.of("b0", ioException), + ImmutableMap.of("b0", saslExceptionFinal), + // will not get invoked because the connection fails + ImmutableMap.of("b0", ioException), + // will not get invoked + ImmutableMap.of("b0", block0) + ); + configMap.put("spark.shuffle.sasl.enableRetries", "true"); + performInteractions(interactions, listener); + verify(listener, timeout(5000)).onBlockTransferFailure("b0", saslExceptionFinal); + verify(listener, atLeastOnce()).getTransferType(); + verifyNoMoreInteractions(listener); + assert(_retryingBlockTransferor.getRetryCount() == MAX_RETRIES); + } + /** * Performs a set of interactions in response to block requests from a RetryingBlockFetcher. * Each interaction is a Map from BlockId to either ManagedBuffer or Exception. This interaction