From 235919b2d9e9a2fd4fa8a0af245c21202283b753 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Wed, 3 Jun 2015 15:49:07 -0700 Subject: [PATCH] [SPARK-6980] Resolved conflicts after master merge --- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 +- .../spark/storage/BlockManagerMaster.scala | 31 ++++--- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 80 +++++++++++++++++++ 3 files changed, 104 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 34ea6103e4abb..8bba16874ca2c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -28,9 +28,11 @@ import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Add import akka.event.Logging.Error import akka.pattern.{ask => akkaAsk} import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} +import com.google.common.util.concurrent.MoreExecutors + import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} /** * A RpcEnv implementation based on Akka. @@ -293,8 +295,8 @@ private[akka] class AkkaRpcEndpointRef( } override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { + // The function will run in the calling thread, so it should be short and never block. case msg @ AkkaMessage(message, reply) => if (reply) { logError(s"Receive $msg but the sender cannot reply") @@ -304,7 +306,7 @@ private[akka] class AkkaRpcEndpointRef( } case AkkaFailure(e) => Future.failed(e) - }.mapTo[T] + }(ThreadUtils.sameThread).mapTo[T] } override def toString: String = s"${getClass.getSimpleName}($actorRef)" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 12be1beccde1b..214c6e8f96bf6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,13 +17,14 @@ package org.apache.spark.storage -import scala.concurrent.Future -import scala.concurrent.ExecutionContext.Implicits.global +import scala.collection.Iterable +import scala.collection.generic.CanBuildFrom +import scala.concurrent.{Await, Future} import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.RpcUtils +import org.apache.spark.util.{ThreadUtils, RpcUtils} private[spark] class BlockManagerMaster( @@ -102,8 +103,8 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Int]]](RemoveRdd(rddId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}") - } + logWarning(s"Failed to remove RDD $rddId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } @@ -114,8 +115,8 @@ class BlockManagerMaster( val future = driverEndpoint.askWithRetry[Future[Seq[Boolean]]](RemoveShuffle(shuffleId)) future.onFailure { case e: Exception => - logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}") - } + logWarning(s"Failed to remove shuffle $shuffleId - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } @@ -128,8 +129,8 @@ class BlockManagerMaster( future.onFailure { case e: Exception => logWarning(s"Failed to remove broadcast $broadcastId" + - s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}") - } + s" with removeFromMaster = $removeFromMaster - ${e.getMessage}}", e) + }(ThreadUtils.sameThread) if (blocking) { timeout.awaitResult(future) } @@ -169,11 +170,17 @@ class BlockManagerMaster( val response = driverEndpoint. askWithRetry[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) val (blockManagerIds, futures) = response.unzip - val result = timeout.awaitResult(Future.sequence(futures)) - if (result == null) { + implicit val sameThread = ThreadUtils.sameThread + val cbf = + implicitly[ + CanBuildFrom[Iterable[Future[Option[BlockStatus]]], + Option[BlockStatus], + Iterable[Option[BlockStatus]]]] + val blockStatus = timeout.awaitResult( + Future.sequence[Option[BlockStatus], Iterable](futures)(cbf, ThreadUtils.sameThread)) + if (blockStatus == null) { throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) } - val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => status.map { s => (blockManagerId, s) } }.toMap diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a33a83db7bc9e..caa069a3f37ea 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -17,9 +17,21 @@ package org.apache.spark.rpc.akka +import java.util.concurrent.TimeoutException + +import scala.concurrent.Await +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.duration._ +import scala.util.{Success, Failure} +import scala.language.postfixOps + +import akka.actor.{ActorSystem, Actor, ActorRef, Props, Address} +import akka.pattern.ask + import org.apache.spark.rpc._ import org.apache.spark.{SecurityManager, SparkConf} + class AkkaRpcEnvSuite extends RpcEnvSuite { override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { @@ -47,4 +59,72 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } } + test("Future failure with RpcTimeout") { + + class EchoActor extends Actor { + def receive: Receive = { + case msg => + Thread.sleep(500) + sender() ! msg + } + } + + val system = ActorSystem("EchoSystem") + val echoActor = system.actorOf(Props(new EchoActor), name = "echoA") + + val timeout = new RpcTimeout(50 millis, "spark.rpc.short.timeout") + + val fut = echoActor.ask("hello")(1000 millis).mapTo[String].recover { + case te: TimeoutException => throw timeout.amend(te) + } + + fut.onFailure { + case te: TimeoutException => println("failed with timeout exception") + } + + fut.onComplete { + case Success(str) => println("future success") + case Failure(ex) => println("future failure") + } + + println("sleeping") + Thread.sleep(50) + println("Future complete: " + fut.isCompleted.toString() + ", " + fut.value.toString()) + + println("Caught TimeoutException: " + + intercept[TimeoutException] { + //timeout.awaitResult(fut) // prints RpcTimeout description twice + Await.result(fut, 10 millis) + }.getMessage() + ) + + /* + val ref = env.setupEndpoint("test_future", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_future") + val akkaActorRef = newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef + + val timeout = new RpcTimeout(1 millis, "spark.rpc.short.timeout") + val fut = akkaActorRef.ask("hello")(timeout.duration).mapTo[String] + + Thread.sleep(500) + println("Future complete: " + fut.isCompleted.toString() + ", " + fut.value.toString()) + + } finally { + newRpcEnv.shutdown() + } + */ + + + } + }