diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java index 9462cbac82b3f..dd2fdb08ee5bf 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClient.java @@ -294,15 +294,6 @@ public void onFailure(Throwable e) { } } - /** - * Exception thrown when sasl request times out. - */ - public static class SaslTimeoutException extends RuntimeException { - public SaslTimeoutException(Throwable cause) { - super((cause)); - } - } - /** * Sends an opaque message to the RpcHandler on the server-side. No reply is expected for the * message, and no delivery guarantees are made. diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java index 334a849e83b91..69baaca8a2614 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslClientBootstrap.java @@ -17,7 +17,6 @@ package org.apache.spark.network.sasl; -import com.google.common.base.Throwables; import java.io.IOException; import java.nio.ByteBuffer; import java.util.concurrent.TimeoutException; @@ -74,7 +73,7 @@ public void doBootstrap(TransportClient client, Channel channel) { } catch (RuntimeException ex) { // We know it is a Sasl timeout here if it is a TimeoutException. if (ex.getCause() instanceof TimeoutException) { - throw Throwables.propagate(new TransportClient.SaslTimeoutException(ex.getCause())); + throw new SaslTimeoutException(ex.getCause()); } else { throw ex; } diff --git a/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java new file mode 100644 index 0000000000000..ecdd764d41af9 --- /dev/null +++ b/common/network-common/src/main/java/org/apache/spark/network/sasl/SaslTimeoutException.java @@ -0,0 +1,15 @@ +package org.apache.spark.network.sasl; + +public class SaslTimeoutException extends RuntimeException { + public SaslTimeoutException(Throwable cause) { + super(cause); + } + + public SaslTimeoutException(String message) { + super(message); + } + + public SaslTimeoutException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java index 7e23d9aa8be98..d41b7ef1da6b3 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RetryingBlockTransferor.java @@ -19,14 +19,16 @@ import java.io.IOException; import java.util.Collections; +import java.util.HashMap; import java.util.LinkedHashSet; +import java.util.Map; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import com.google.common.collect.Sets; import com.google.common.util.concurrent.Uninterruptibles; -import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SaslTimeoutException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -85,6 +87,12 @@ void createAndStart(String[] blockIds, BlockTransferListener listener) // while inside a synchronized block. /** Number of times we've attempted to retry so far. */ private int retryCount = 0; + /** + * Map to track blockId to exception that the block is being retried for. + * This is mainly used in the case of SASL retries, because we need to set + * `retryCount` back to 0 in those cases. + */ + private Map blockIdToException; /** * Set of all block ids which have not been transferred successfully or with a non-IO Exception. @@ -120,6 +128,7 @@ public RetryingBlockTransferor( this.currentListener = new RetryingBlockTransferListener(); this.errorHandler = errorHandler; this.enableSaslRetries = conf.enableSaslRetries(); + this.blockIdToException = new HashMap(); } public RetryingBlockTransferor( @@ -197,9 +206,7 @@ private synchronized void initiateRetry() { private synchronized boolean shouldRetry(Throwable e) { boolean isIOException = e instanceof IOException || e.getCause() instanceof IOException; - boolean isSaslTimeout = enableSaslRetries && - (e instanceof TransportClient.SaslTimeoutException || - (e.getCause() != null && e.getCause() instanceof TransportClient.SaslTimeoutException)); + boolean isSaslTimeout = enableSaslRetries && e instanceof SaslTimeoutException; boolean hasRemainingRetries = retryCount < maxRetries; return (isSaslTimeout || isIOException) && hasRemainingRetries && errorHandler.shouldRetryError(e); @@ -220,6 +227,10 @@ private void handleBlockTransferSuccess(String blockId, ManagedBuffer data) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { outstandingBlocksIds.remove(blockId); shouldForwardSuccess = true; + if (blockIdToException.containsKey(blockId) && + blockIdToException.get(blockId) instanceof SaslTimeoutException) { + retryCount = 0; + } } } @@ -236,6 +247,7 @@ private void handleBlockTransferFailure(String blockId, Throwable exception) { synchronized (RetryingBlockTransferor.this) { if (this == currentListener && outstandingBlocksIds.contains(blockId)) { if (shouldRetry(exception)) { + blockIdToException.putIfAbsent(blockId, exception); initiateRetry(); } else { if (errorHandler.shouldLogError(exception)) { diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java index 815d06ccad188..0fea1aef1b8a2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockTransferorSuite.java @@ -29,7 +29,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Sets; import java.util.concurrent.TimeoutException; -import org.apache.spark.network.client.TransportClient; +import org.apache.spark.network.sasl.SaslTimeoutException; import org.junit.Before; import org.junit.Test; import org.mockito.stubbing.Answer; @@ -247,8 +247,8 @@ public void testRetryAndUnrecoverable() throws IOException, InterruptedException public void testSaslTimeoutFailure() throws IOException, InterruptedException { BlockFetchingListener listener = mock(BlockFetchingListener.class); TimeoutException timeoutException = new TimeoutException(); - TransportClient.SaslTimeoutException saslTimeoutException = - new TransportClient.SaslTimeoutException(timeoutException); + SaslTimeoutException saslTimeoutException = + new SaslTimeoutException(timeoutException); List> interactions = Arrays.asList( ImmutableMap.builder() .put("b0", saslTimeoutException) @@ -272,7 +272,7 @@ public void testRetryOnSaslTimeout() throws IOException, InterruptedException { List> interactions = Arrays.asList( // SaslTimeout will cause a retry. Since b0 fails, we will retry both. ImmutableMap.builder() - .put("b0", new TransportClient.SaslTimeoutException(new TimeoutException())) + .put("b0", new SaslTimeoutException(new TimeoutException())) .build(), ImmutableMap.builder() .put("b0", block0)