From 7d2c47db0f3bd25c7386fbbf7a6793c527796146 Mon Sep 17 00:00:00 2001
From: onebox-li
Date: Wed, 1 Nov 2023 10:23:50 +0800
Subject: [PATCH] [CELEBORN-1059] Fix callback not update if push worker excluded during retry

### What changes were proposed in this pull request?

When push data is retried and revive succeeds in ShuffleClientImpl#submitRetryPushData,
but the new location is excluded, the callback's `latest` location has not yet been
updated when wrappedCallback.onFailure is called in
ShuffleClientImpl#isPushTargetWorkerExcluded. Subsequent revives may therefore operate
on a stale location. This patch updates the callback's `latest` location as soon as the
revived location is known, before the excluded-worker check.

### Why are the changes needed?

Ditto

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Manual test.

Closes #2005 from onebox-li/improve-push-exclude.

Authored-by: onebox-li
Signed-off-by: zky.zhoukeyong
(cherry picked from commit cd8acf89c968aad47ae3edcd5b63edc1a76721c7)
Signed-off-by: zky.zhoukeyong
---
 .../celeborn/client/ShuffleClientImpl.java     | 38 +++++++++----------
 .../client/ChangePartitionManager.scala        |  4 +-
 .../tests/spark/RetryReviveTest.scala          |  2 +-
 3 files changed, 21 insertions(+), 23 deletions(-)

diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 85379a53572..c532d4f2af9 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -272,6 +272,7 @@ private void submitRetryPushData(
         partitionId,
         batchId,
         newLoc);
+    pushDataRpcResponseCallback.updateLatestPartition(newLoc);
     try {
       if (!isPushTargetWorkerExcluded(newLoc, pushDataRpcResponseCallback)) {
         if (!testRetryRevive || remainReviveTimes < 1) {
@@ -281,7 +282,6 @@ private void submitRetryPushData(
           String shuffleKey = Utils.makeShuffleKey(appUniqueId, shuffleId);
           PushData newPushData =
               new PushData(PRIMARY_MODE, shuffleKey, newLoc.getUniqueId(), newBuffer);
-          pushDataRpcResponseCallback.updateLatestPartition(newLoc);
           client.pushData(newPushData, pushDataTimeout, pushDataRpcResponseCallback);
         } else {
           throw new RuntimeException(
@@ -633,18 +633,17 @@ boolean newerPartitionLocationExists(

   void excludeWorkerByCause(StatusCode cause, PartitionLocation oldLocation) {
     if (pushExcludeWorkerOnFailureEnabled && oldLocation != null) {
-      if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_PRIMARY) {
-        pushExcludedWorkers.add(oldLocation.hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA) {
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_CONNECTION_EXCEPTION_REPLICA) {
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
-      } else if (cause == StatusCode.PUSH_DATA_TIMEOUT_REPLICA) {
-        pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
+      switch (cause) {
+        case PUSH_DATA_CREATE_CONNECTION_FAIL_PRIMARY:
+        case PUSH_DATA_CONNECTION_EXCEPTION_PRIMARY:
+        case PUSH_DATA_TIMEOUT_PRIMARY:
+          pushExcludedWorkers.add(oldLocation.hostAndPushPort());
+          break;
+        case PUSH_DATA_CREATE_CONNECTION_FAIL_REPLICA:
+        case PUSH_DATA_CONNECTION_EXCEPTION_REPLICA:
+        case PUSH_DATA_TIMEOUT_REPLICA:
+          pushExcludedWorkers.add(oldLocation.getPeer().hostAndPushPort());
+          break;
       }
     }
   }
@@ -905,10 +904,10 @@ public void onFailure(Throwable e) {
          PartitionLocation latest = loc;

          @Override
-          public void updateLatestPartition(PartitionLocation latest) {
-            pushState.addBatch(nextBatchId, latest.hostAndPushPort());
+          public void updateLatestPartition(PartitionLocation newloc) {
+            pushState.addBatch(nextBatchId, newloc.hostAndPushPort());
             pushState.removeBatch(nextBatchId, this.latest.hostAndPushPort());
-            this.latest = latest;
+            this.latest = newloc;
           }

          @Override
@@ -1003,12 +1002,10 @@ public void onSuccess(ByteBuffer response) {

        @Override
        public void onFailure(Throwable e) {
-          StatusCode cause = getPushDataFailCause(e.getMessage());
-
          if (pushState.exception.get() != null) {
            return;
          }
-
+          StatusCode cause = getPushDataFailCause(e.getMessage());
          if (remainReviveTimes <= 0) {
            if (e instanceof CelebornIOException) {
              callback.onFailure(e);
@@ -1383,11 +1380,10 @@ public void onSuccess(ByteBuffer response) {

        @Override
        public void onFailure(Throwable e) {
-          StatusCode cause = getPushDataFailCause(e.getMessage());
-
          if (pushState.exception.get() != null) {
            return;
          }
+          StatusCode cause = getPushDataFailCause(e.getMessage());
          if (remainReviveTimes <= 0) {
            if (e instanceof CelebornIOException) {
              callback.onFailure(e);
diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index e7b1a45b2de..4edda41da55 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -67,6 +67,8 @@ class ChangePartitionManager(

   private var batchHandleChangePartition: Option[ScheduledFuture[_]] = _

+  private val testRetryRevive = conf.testRetryRevive
+
   def start(): Unit = {
     batchHandleChangePartition = batchHandleChangePartitionSchedulerThread.map {
       // noinspection ConvertExpressionToSAM
@@ -204,7 +206,7 @@ class ChangePartitionManager(
      logWarning(s"Batch handle change partition for $changes")

      // Exclude all failed workers
-      if (changePartitions.exists(_.causes.isDefined)) {
+      if (changePartitions.exists(_.causes.isDefined) && !testRetryRevive) {
        changePartitions.filter(_.causes.isDefined).foreach { changePartition =>
          lifecycleManager.workerStatusTracker.excludeWorkerFromPartition(
            shuffleId,
diff --git a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
index 2ddc895d860..4d0b42ded61 100644
--- a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
+++ b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/RetryReviveTest.scala
@@ -51,7 +51,7 @@ class RetryReviveTest extends AnyFunSuite
      .config(updateSparkConf(sparkConf, ShuffleMode.HASH))
      .getOrCreate()
    val result = ss.sparkContext.parallelize(1 to 1000, 2)
-      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(16).collect()
+      .map { i => (i, Range(1, 1000).mkString(",")) }.groupByKey(4).collect()
    assert(result.size == 1000)
    ss.stop()
  }
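
For readers following the first two hunks of `ShuffleClientImpl.java`, the following is a minimal,
self-contained sketch of the ordering being fixed: the callback's latest location must be recorded
before the excluded-worker check, because that check may itself invoke `onFailure` and drive the
next revive. The class and member names used here (`PushCallback`, `submitRetry`,
`EXCLUDED_WORKERS`) are hypothetical stand-ins for illustration only, not Celeborn's actual API.

```java
import java.util.HashSet;
import java.util.Set;

// Simplified sketch of the retry ordering fixed by this patch; all names are illustrative.
public class RetryOrderingSketch {

  // Stand-in for the push RPC callback that remembers the most recent target location.
  static class PushCallback {
    String latest;

    PushCallback(String initial) {
      this.latest = initial;
    }

    void updateLatestPartition(String newLoc) {
      this.latest = newLoc;
    }

    // On failure, the next revive is driven by `latest`; if it still points at the
    // pre-revive location, the follow-up revive targets the wrong worker.
    void onFailure(String reason) {
      System.out.println("push failed (" + reason + "), will revive from " + latest);
    }
  }

  static final Set<String> EXCLUDED_WORKERS = new HashSet<>();

  // Mirrors the corrected ordering in submitRetryPushData: record the revived location
  // on the callback first, then run the excluded-worker check, so an onFailure raised
  // by that check already sees the new location.
  static void submitRetry(PushCallback callback, String newLoc) {
    callback.updateLatestPartition(newLoc); // moved before the exclusion check
    if (EXCLUDED_WORKERS.contains(newLoc)) {
      callback.onFailure("worker " + newLoc + " is excluded");
      return;
    }
    System.out.println("pushing to " + newLoc);
  }

  public static void main(String[] args) {
    EXCLUDED_WORKERS.add("workerB:9092");
    PushCallback callback = new PushCallback("workerA:9092");
    // Revive returned workerB, which is excluded; because `latest` was updated first,
    // onFailure reports workerB rather than the stale workerA.
    submitRetry(callback, "workerB:9092");
  }
}
```

With the previous ordering (update after the check), the onFailure path would still see the
pre-revive location, which is the stale-location problem described in the PR body.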