diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index deabf6f5c8c5f..b71b7fa517fd2 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -21,27 +21,41 @@ import java.lang.ref.{ReferenceQueue, WeakReference}
 import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.rdd.RDD
-/** Listener class used for testing when any item has been cleaned by the Cleaner class */
-private[spark] trait CleanerListener {
-  def rddCleaned(rddId: Int)
-  def shuffleCleaned(shuffleId: Int)
-}
+/**
+ * Classes that represent cleaning tasks.
+ */
+private sealed trait CleanupTask
+private case class CleanRDD(rddId: Int) extends CleanupTask
+private case class CleanShuffle(shuffleId: Int) extends CleanupTask
+private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
 /**
- * Cleans RDDs and shuffle data.
+ * A WeakReference associated with a CleanupTask.
+ *
+ * When the referent object becomes only weakly reachable, the corresponding
+ * CleanupTaskWeakReference is automatically added to the given reference queue.
+ */
+private class CleanupTaskWeakReference(
+    val task: CleanupTask,
+    referent: AnyRef,
+    referenceQueue: ReferenceQueue[AnyRef])
+  extends WeakReference(referent, referenceQueue)
+
+/**
+ * An asynchronous cleaner for RDD, shuffle, and broadcast state.
+ *
+ * This maintains a weak reference for each RDD, ShuffleDependency, and Broadcast of interest,
+ * to be processed when the associated object goes out of scope of the application. Actual
+ * cleanup is performed in a separate daemon thread.
  */
 private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
-  /** Classes to represent cleaning tasks */
-  private sealed trait CleanupTask
-  private case class CleanRDD(rddId: Int) extends CleanupTask
-  private case class CleanShuffle(shuffleId: Int) extends CleanupTask
-  // TODO: add CleanBroadcast
+  private val referenceBuffer = new ArrayBuffer[CleanupTaskWeakReference]
+    with SynchronizedBuffer[CleanupTaskWeakReference]
-  private val referenceBuffer = new ArrayBuffer[WeakReferenceWithCleanupTask]
-    with SynchronizedBuffer[WeakReferenceWithCleanupTask]
   private val referenceQueue = new ReferenceQueue[AnyRef]
   private val listeners = new ArrayBuffer[CleanerListener]
@@ -49,77 +63,64 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private val cleaningThread = new Thread() { override def run() { keepCleaning() }}
-  private val REF_QUEUE_POLL_TIMEOUT = 100
-
   @volatile private var stopped = false
-  private class WeakReferenceWithCleanupTask(referent: AnyRef, val task: CleanupTask)
-    extends WeakReference(referent, referenceQueue)
+  /** Attach a listener object to get information of when objects are cleaned. */
+  def attachListener(listener: CleanerListener) {
+    listeners += listener
+  }
-  /** Start the cleaner */
+  /** Start the cleaner. */
   def start() {
     cleaningThread.setDaemon(true)
     cleaningThread.setName("ContextCleaner")
     cleaningThread.start()
   }
-  /** Stop the cleaner */
+  /** Stop the cleaner. */
   def stop() {
     stopped = true
     cleaningThread.interrupt()
   }
-  /**
-   * Register a RDD for cleanup when it is garbage collected.
-   */
+  /** Register a RDD for cleanup when it is garbage collected. */
   def registerRDDForCleanup(rdd: RDD[_]) {
     registerForCleanup(rdd, CleanRDD(rdd.id))
   }
-  /**
-   * Register a shuffle dependency for cleanup when it is garbage collected.
-   */
+  /** Register a ShuffleDependency for cleanup when it is garbage collected. */
   def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
     registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
   }
-  /** Cleanup RDD. */
-  def cleanupRDD(rdd: RDD[_]) {
-    doCleanupRDD(rdd.id)
-  }
-
-  /** Cleanup shuffle. */
-  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
-    doCleanupShuffle(shuffleDependency.shuffleId)
-  }
-
-  /** Attach a listener object to get information of when objects are cleaned. */
-  def attachListener(listener: CleanerListener) {
-    listeners += listener
+  /** Register a Broadcast for cleanup when it is garbage collected. */
+  def registerBroadcastForCleanup[T](broadcast: Broadcast[T]) {
+    registerForCleanup(broadcast, CleanBroadcast(broadcast.id))
   }
   /** Register an object for cleanup. */
   private def registerForCleanup(objectForCleanup: AnyRef, task: CleanupTask) {
-    referenceBuffer += new WeakReferenceWithCleanupTask(objectForCleanup, task)
+    referenceBuffer += new CleanupTaskWeakReference(task, objectForCleanup, referenceQueue)
   }
-  /** Keep cleaning RDDs and shuffle data */
+  /** Keep cleaning RDD, shuffle, and broadcast state. */
   private def keepCleaning() {
-    while (!isStopped) {
+    while (!stopped) {
       try {
-        val reference = Option(referenceQueue.remove(REF_QUEUE_POLL_TIMEOUT))
-          .map(_.asInstanceOf[WeakReferenceWithCleanupTask])
+        val reference = Option(referenceQueue.remove(ContextCleaner.REF_QUEUE_POLL_TIMEOUT))
+          .map(_.asInstanceOf[CleanupTaskWeakReference])
         reference.map(_.task).foreach { task =>
           logDebug("Got cleaning task " + task)
           referenceBuffer -= reference.get
           task match {
             case CleanRDD(rddId) => doCleanupRDD(rddId)
             case CleanShuffle(shuffleId) => doCleanupShuffle(shuffleId)
+            case CleanBroadcast(broadcastId) => doCleanupBroadcast(broadcastId)
           }
         }
       } catch {
         case ie: InterruptedException =>
-          if (!isStopped) logWarning("Cleaning thread interrupted")
+          if (!stopped) logWarning("Cleaning thread interrupted")
         case t: Throwable => logError("Error in cleaning thread", t)
       }
     }
@@ -129,7 +130,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
   private def doCleanupRDD(rddId: Int) {
     try {
       logDebug("Cleaning RDD " + rddId)
-      sc.unpersistRDD(rddId, false)
+      sc.unpersistRDD(rddId, blocking = false)
       listeners.foreach(_.rddCleaned(rddId))
       logInfo("Cleaned RDD " + rddId)
     } catch {
@@ -150,10 +151,46 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
     }
   }
-  private def mapOutputTrackerMaster =
-    sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+  /** Perform broadcast cleanup. */
+  private def doCleanupBroadcast(broadcastId: Long) {
+    try {
+      logDebug("Cleaning broadcast " + broadcastId)
+      broadcastManager.unbroadcast(broadcastId, removeFromDriver = true)
+      listeners.foreach(_.broadcastCleaned(broadcastId))
+      logInfo("Cleaned broadcast " + broadcastId)
+    } catch {
+      case t: Throwable => logError("Error cleaning broadcast " + broadcastId, t)
+    }
+  }
   private def blockManagerMaster = sc.env.blockManager.master
+  private def broadcastManager = sc.env.broadcastManager
+  private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+
+  // Used for testing
+
+  def cleanupRDD(rdd: RDD[_]) {
+    doCleanupRDD(rdd.id)
+  }
+
+  def cleanupShuffle(shuffleDependency: ShuffleDependency[_, _]) {
+    doCleanupShuffle(shuffleDependency.shuffleId)
+  }
-  private def isStopped = stopped
+  def cleanupBroadcast[T](broadcast: Broadcast[T]) {
+    doCleanupBroadcast(broadcast.id)
+  }
+}
+
+private object ContextCleaner {
+  private val REF_QUEUE_POLL_TIMEOUT = 100
+}
+
+/**
+ * Listener class used for testing when any item has been cleaned by the Cleaner class.
+ */
+private[spark] trait CleanerListener {
+  def rddCleaned(rddId: Int)
+  def shuffleCleaned(shuffleId: Int)
+  def broadcastCleaned(broadcastId: Long)
 }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index e1a273593cce5..c45c5c90048f3 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -112,8 +112,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
   }
   /**
-   * Called from executors to get the server URIs and
-   * output sizes of the map outputs of a given shuffle
+   * Called from executors to get the server URIs and output sizes of the map outputs of
+   * a given shuffle.
    */
   def getServerStatuses(shuffleId: Int, reduceId: Int): Array[(BlockManagerId, Long)] = {
     val statuses = mapStatuses.get(shuffleId).orNull
@@ -218,10 +218,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
   private var cacheEpoch = epoch
   /**
-   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses
-   * in the master, so that statuses are dropped only by explicit deregistering or
-   * by TTL-based cleaning (if set). Other than these two
-   * scenarios, nothing should be dropped from this HashMap.
+   * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+   * so that statuses are dropped only by explicit deregistering or by TTL-based cleaning (if set).
+   * Other than these two scenarios, nothing should be dropped from this HashMap.
*/ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 5fb7ff3c2c92d..13fba1e0dfe5d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -35,7 +35,6 @@ import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHad import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary -import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ @@ -230,6 +229,7 @@ class SparkContext( private[spark] val cleaner = new ContextCleaner(this) cleaner.start() + postEnvironmentUpdate() /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */ @@ -643,7 +643,11 @@ class SparkContext( * [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions. * The variable will be sent to each cluster only once. */ - def broadcast[T](value: T): Broadcast[T] = env.broadcastManager.newBroadcast[T](value, isLocal) + def broadcast[T](value: T) = { + val bc = env.broadcastManager.newBroadcast[T](value, isLocal) + cleaner.registerBroadcastForCleanup(bc) + bc + } /** * Add a file to be downloaded with this Spark job on every node. diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 012b096e8801a..9ea123f174b95 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -185,6 +185,7 @@ object SparkEnv extends Logging { } else { new MapOutputTrackerWorker(conf) } + // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself mapOutputTracker.trackerActor = registerOrLookup( diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index e3c3a12d16f2a..b28e15a6840d9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -18,9 +18,8 @@ package org.apache.spark.broadcast import java.io.Serializable -import java.util.concurrent.atomic.AtomicLong -import org.apache.spark._ +import org.apache.spark.SparkException /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable @@ -51,49 +50,37 @@ import org.apache.spark._ * @tparam T Type of the data contained in the broadcast variable. */ abstract class Broadcast[T](val id: Long) extends Serializable { - def value: T - - // We cannot have an abstract readObject here due to some weird issues with - // readObject having to be 'private' in sub-classes. 
- - override def toString = "Broadcast(" + id + ")" -} - -private[spark] -class BroadcastManager(val _isDriver: Boolean, conf: SparkConf, securityManager: SecurityManager) - extends Logging with Serializable { - - private var initialized = false - private var broadcastFactory: BroadcastFactory = null - initialize() + protected var _isValid: Boolean = true - // Called by SparkContext or Executor before using Broadcast - private def initialize() { - synchronized { - if (!initialized) { - val broadcastFactoryClass = conf.get( - "spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + /** + * Whether this Broadcast is actually usable. This should be false once persisted state is + * removed from the driver. + */ + def isValid: Boolean = _isValid - broadcastFactory = - Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject - broadcastFactory.initialize(isDriver, conf, securityManager) + def value: T - initialized = true - } + /** + * Remove all persisted state associated with this broadcast on the executors. The next use + * of this broadcast on the executors will trigger a remote fetch. + */ + def unpersist() + + /** + * Remove all persisted state associated with this broadcast on both the executors and the + * driver. Overriding implementations should set isValid to false. + */ + private[spark] def destroy() + + /** + * If this broadcast is no longer valid, throw an exception. + */ + protected def assertValid() { + if (!_isValid) { + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } } - def stop() { - broadcastFactory.stop() - } - - private val nextBroadcastId = new AtomicLong(0) - - def newBroadcast[T](value_ : T, isLocal: Boolean) = - broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) - - def isDriver = _isDriver + override def toString = "Broadcast(" + id + ")" } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6beecaeced5be..9ff1675e76a5e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -27,7 +27,8 @@ import org.apache.spark.SparkConf * entire Spark job. */ trait BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T] - def stop(): Unit + def unbroadcast(id: Long, removeFromDriver: Boolean) + def stop() } diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala new file mode 100644 index 0000000000000..c3ea16ff9eb5e --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark._ + +private[spark] class BroadcastManager( + val isDriver: Boolean, + conf: SparkConf, + securityManager: SecurityManager) + extends Logging { + + private var initialized = false + private var broadcastFactory: BroadcastFactory = null + + initialize() + + // Called by SparkContext or Executor before using Broadcast + private def initialize() { + synchronized { + if (!initialized) { + val broadcastFactoryClass = + conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") + + broadcastFactory = + Class.forName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] + + // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory.initialize(isDriver, conf, securityManager) + + initialized = true + } + } + } + + def stop() { + broadcastFactory.stop() + } + + private val nextBroadcastId = new AtomicLong(0) + + def newBroadcast[T](value_ : T, isLocal: Boolean) = { + broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement()) + } + + def unbroadcast(id: Long, removeFromDriver: Boolean) { + broadcastFactory.unbroadcast(id, removeFromDriver) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index e8eb04bb10469..ec5acf5f23f5f 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -17,14 +17,13 @@ package org.apache.spark.broadcast -import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} -import java.net.{URL, URLConnection, URI} +import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} +import java.net.{URI, URL, URLConnection} import java.util.concurrent.TimeUnit -import it.unimi.dsi.fastutil.io.FastBufferedInputStream -import it.unimi.dsi.fastutil.io.FastBufferedOutputStream +import it.unimi.dsi.fastutil.io.{FastBufferedInputStream, FastBufferedOutputStream} -import org.apache.spark.{SparkConf, HttpServer, Logging, SecurityManager, SparkEnv} +import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.io.CompressionCodec import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} @@ -32,19 +31,45 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def value: T = { + assertValid() + value_ + } - def blockId = BroadcastBlockId(id) + val blockId = BroadcastBlockId(id) HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster 
= false) } if (!isLocal) { HttpBroadcast.write(id, value_) } - // Called by JVM when deserializing an object + /** + * Remove all persisted state associated with this HTTP broadcast on the executors. + */ + def unpersist() { + HttpBroadcast.unpersist(id, removeFromDriver = false) + } + + /** + * Remove all persisted state associated with this HTTP Broadcast on both the executors + * and the driver. + */ + private[spark] def destroy() { + _isValid = false + HttpBroadcast.unpersist(id, removeFromDriver = true) + } + + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() HttpBroadcast.synchronized { @@ -54,7 +79,8 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea logInfo("Started reading broadcast variable " + id) val start = System.nanoTime value_ = HttpBroadcast.read[T](id) - SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) val time = (System.nanoTime - start) / 1e9 logInfo("Reading broadcast variable " + id + " took " + time + " s") } @@ -63,21 +89,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea } } -/** - * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. - */ -class HttpBroadcastFactory extends BroadcastFactory { - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new HttpBroadcast[T](value_, isLocal, id) - - def stop() { HttpBroadcast.stop() } -} - -private object HttpBroadcast extends Logging { +private[spark] object HttpBroadcast extends Logging { private var initialized = false private var broadcastDir: File = null @@ -136,8 +148,10 @@ private object HttpBroadcast extends Logging { logInfo("Broadcast server started at " + serverUri) } + def getFile(id: Long) = new File(broadcastDir, BroadcastBlockId(id).name) + def write(id: Long, value: Any) { - val file = new File(broadcastDir, BroadcastBlockId(id).name) + val file = getFile(id) val out: OutputStream = { if (compress) { compressionCodec.compressedOutputStream(new FileOutputStream(file)) @@ -160,7 +174,7 @@ private object HttpBroadcast extends Logging { if (securityManager.isAuthenticationEnabled()) { logDebug("broadcast security enabled") val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL().openConnection() + uc = newuri.toURL.openConnection() uc.setAllowUserInteraction(false) } else { logDebug("broadcast not using security") @@ -169,7 +183,7 @@ private object HttpBroadcast extends Logging { val in = { uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream(); + val inputStream = uc.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { @@ -183,20 +197,48 @@ private object HttpBroadcast extends Logging { obj } - def cleanup(cleanupTime: Long) { + /** + * Remove all persisted blocks associated with this HTTP broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver + * and delete the associated broadcast file. 
+ */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + if (removeFromDriver) { + val file = getFile(id) + files.remove(file.toString) + deleteBroadcastFile(file) + } + } + + /** + * Periodically clean up old broadcasts by removing the associated map entries and + * deleting the associated files. + */ + private def cleanup(cleanupTime: Long) { val iterator = files.internalMap.entrySet().iterator() while(iterator.hasNext) { val entry = iterator.next() val (file, time) = (entry.getKey, entry.getValue) if (time < cleanupTime) { - try { - iterator.remove() - new File(file.toString).delete() - logInfo("Deleted broadcast file '" + file + "'") - } catch { - case e: Exception => logWarning("Could not delete broadcast file '" + file + "'", e) + iterator.remove() + deleteBroadcastFile(new File(file.toString)) + } + } + } + + private def deleteBroadcastFile(file: File) { + try { + if (file.exists) { + if (file.delete()) { + logInfo("Deleted broadcast file: %s".format(file)) + } else { + logWarning("Could not delete broadcast file: %s".format(file)) } } + } catch { + case e: Exception => + logError("Exception while deleting broadcast file: %s".format(file), e) } } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala new file mode 100644 index 0000000000000..4affa922156c9 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] implementation that uses a HTTP server as the broadcast medium. + */ +class HttpBroadcastFactory extends BroadcastFactory { + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + HttpBroadcast.initialize(isDriver, conf, securityMgr) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new HttpBroadcast[T](value_, isLocal, id) + + def stop() { HttpBroadcast.stop() } + + /** + * Remove all persisted state associated with the HTTP broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. 
+ */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + HttpBroadcast.unpersist(id, removeFromDriver) + } +} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 2595c15104e87..590caa9699dd3 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -17,24 +17,28 @@ package org.apache.spark.broadcast -import java.io._ +import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream} import scala.math import scala.util.Random -import org.apache.spark._ -import org.apache.spark.storage.{BroadcastBlockId, BroadcastHelperBlockId, StorageLevel} +import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException} +import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} import org.apache.spark.util.Utils private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) -extends Broadcast[T](id) with Logging with Serializable { + extends Broadcast[T](id) with Logging with Serializable { - def value = value_ + def value = { + assertValid() + value_ + } - def broadcastId = BroadcastBlockId(id) + val broadcastId = BroadcastBlockId(id) TorrentBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle(broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + SparkEnv.get.blockManager.putSingle( + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) } @transient var arrayOfBlocks: Array[TorrentBlock] = null @@ -46,32 +50,53 @@ extends Broadcast[T](id) with Logging with Serializable { sendBroadcast() } - def sendBroadcast() { - var tInfo = TorrentBroadcast.blockifyObject(value_) + /** + * Remove all persisted state associated with this Torrent broadcast on the executors. + */ + def unpersist() { + TorrentBroadcast.unpersist(id, removeFromDriver = false) + } + /** + * Remove all persisted state associated with this Torrent broadcast on both the executors + * and the driver. 
+ */ + private[spark] def destroy() { + _isValid = false + TorrentBroadcast.unpersist(id, removeFromDriver = true) + } + + private def sendBroadcast() { + val tInfo = TorrentBroadcast.blockifyObject(value_) totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes hasBlocks = tInfo.totalBlocks // Store meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") val metaInfo = TorrentInfo(null, totalBlocks, totalBytes) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, true) + metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true) } // Store individual pieces for (i <- 0 until totalBlocks) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + i) + val pieceId = BroadcastBlockId(id, "piece" + i) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.putSingle( - pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, true) + pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true) } } } - // Called by JVM when deserializing an object + // Used by the JVM when serializing this object + private def writeObject(out: ObjectOutputStream) { + assertValid() + out.defaultWriteObject() + } + + // Used by the JVM when deserializing this object private def readObject(in: ObjectInputStream) { in.defaultReadObject() TorrentBroadcast.synchronized { @@ -86,18 +111,18 @@ extends Broadcast[T](id) with Logging with Serializable { // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() - if (receiveBroadcast(id)) { + if (receiveBroadcast()) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - // Store the merged copy in cache so that the next worker doesn't need to rebuild it. - // This creates a tradeoff between memory usage and latency. - // Storing copy doubles the memory footprint; not storing doubles deserialization cost. + /* Store the merged copy in cache so that the next worker doesn't need to rebuild it. + * This creates a trade-off between memory usage and latency. Storing copy doubles + * the memory footprint; not storing doubles deserialization cost. 
*/ SparkEnv.get.blockManager.putSingle( - broadcastId, value_, StorageLevel.MEMORY_AND_DISK, false) + broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) // Remove arrayOfBlocks from memory once value_ is on local cache resetWorkerVariables() - } else { + } else { logError("Reading broadcast variable " + id + " failed") } @@ -114,9 +139,9 @@ extends Broadcast[T](id) with Logging with Serializable { hasBlocks = 0 } - def receiveBroadcast(variableID: Long): Boolean = { + def receiveBroadcast(): Boolean = { // Receive meta-info - val metaId = BroadcastHelperBlockId(broadcastId, "meta") + val metaId = BroadcastBlockId(id, "meta") var attemptId = 10 while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { @@ -141,14 +166,14 @@ extends Broadcast[T](id) with Logging with Serializable { // Receive actual blocks val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList) for (pid <- recvOrder) { - val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) + val pieceId = BroadcastBlockId(id, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { case Some(x) => arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( - pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) + pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true) case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) @@ -156,16 +181,16 @@ extends Broadcast[T](id) with Logging with Serializable { } } - (hasBlocks == totalBlocks) + hasBlocks == totalBlocks } } -private object TorrentBroadcast -extends Logging { - +private[spark] object TorrentBroadcast extends Logging { + private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 private var initialized = false private var conf: SparkConf = null + def initialize(_isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests synchronized { @@ -179,39 +204,37 @@ extends Logging { initialized = false } - lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024 - def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) - var blockNum = (byteArray.length / BLOCK_SIZE) + var blockNum = byteArray.length / BLOCK_SIZE if (byteArray.length % BLOCK_SIZE != 0) { blockNum += 1 } - var retVal = new Array[TorrentBlock](blockNum) - var blockID = 0 + val blocks = new Array[TorrentBlock](blockNum) + var blockId = 0 for (i <- 0 until (byteArray.length, BLOCK_SIZE)) { val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i) - var tempByteArray = new Array[Byte](thisBlockSize) - val hasRead = bais.read(tempByteArray, 0, thisBlockSize) + val tempByteArray = new Array[Byte](thisBlockSize) + bais.read(tempByteArray, 0, thisBlockSize) - retVal(blockID) = new TorrentBlock(blockID, tempByteArray) - blockID += 1 + blocks(blockId) = new TorrentBlock(blockId, tempByteArray) + blockId += 1 } bais.close() - val tInfo = TorrentInfo(retVal, blockNum, byteArray.length) - tInfo.hasBlocks = blockNum - - tInfo + val info = TorrentInfo(blocks, blockNum, byteArray.length) + info.hasBlocks = blockNum + info } - def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, - totalBlocks: Int): T = { + def unBlockifyObject[T]( + arrayOfBlocks: Array[TorrentBlock], + totalBytes: Int, + totalBlocks: Int): T = { val 
retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray, @@ -220,6 +243,13 @@ extends Logging { Utils.deserialize[T](retByteArray, Thread.currentThread.getContextClassLoader) } + /** + * Remove all persisted blocks associated with this torrent broadcast on the executors. + * If removeFromDriver is true, also remove these persisted blocks on the driver. + */ + def unpersist(id: Long, removeFromDriver: Boolean) = synchronized { + SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver) + } } private[spark] case class TorrentBlock( @@ -228,25 +258,10 @@ private[spark] case class TorrentBlock( extends Serializable private[spark] case class TorrentInfo( - @transient arrayOfBlocks : Array[TorrentBlock], + @transient arrayOfBlocks: Array[TorrentBlock], totalBlocks: Int, totalBytes: Int) extends Serializable { @transient var hasBlocks = 0 } - -/** - * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. - */ -class TorrentBroadcastFactory extends BroadcastFactory { - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - TorrentBroadcast.initialize(isDriver, conf) - } - - def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = - new TorrentBroadcast[T](value_, isLocal, id) - - def stop() { TorrentBroadcast.stop() } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala new file mode 100644 index 0000000000000..eabe792b550bb --- /dev/null +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.broadcast + +import org.apache.spark.{SecurityManager, SparkConf} + +/** + * A [[BroadcastFactory]] that creates a torrent-based implementation of broadcast. + */ +class TorrentBroadcastFactory extends BroadcastFactory { + + def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { + TorrentBroadcast.initialize(isDriver, conf) + } + + def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = + new TorrentBroadcast[T](value_, isLocal, id) + + def stop() { TorrentBroadcast.stop() } + + /** + * Remove all persisted state associated with the torrent broadcast with the given ID. + * @param removeFromDriver Whether to remove state from the driver. 
+ */ + def unbroadcast(id: Long, removeFromDriver: Boolean) { + TorrentBroadcast.unpersist(id, removeFromDriver) + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 6b0a972f0bbe0..bdf586351ac14 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -17,7 +17,6 @@ package org.apache.spark.network -import java.net._ import java.nio._ import java.nio.channels._ import java.nio.channels.spi._ diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e5638d0132e88..ea22ad29bc885 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -158,7 +158,7 @@ abstract class RDD[T: ClassTag]( */ def unpersist(blocking: Boolean = true): RDD[T] = { logInfo("Removing RDD " + id + " from persistence list") - sc.unpersistRDD(this.id, blocking) + sc.unpersistRDD(id, blocking) storageLevel = StorageLevel.NONE this } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d74f04a9f8625..88b06e7a2f84e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1090,7 +1090,6 @@ class DAGScheduler( eventProcessActor ! StopDAGScheduler } taskScheduler.stop() - listenerBus.stop() } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 301d784b350a3..cffea28fbf794 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -34,7 +34,7 @@ private[spark] sealed abstract class BlockId { def asRDDId = if (isRDD) Some(asInstanceOf[RDDBlockId]) else None def isRDD = isInstanceOf[RDDBlockId] def isShuffle = isInstanceOf[ShuffleBlockId] - def isBroadcast = isInstanceOf[BroadcastBlockId] || isInstanceOf[BroadcastHelperBlockId] + def isBroadcast = isInstanceOf[BroadcastBlockId] override def toString = name override def hashCode = name.hashCode @@ -48,18 +48,13 @@ private[spark] case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockI def name = "rdd_" + rddId + "_" + splitIndex } -private[spark] -case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId { +private[spark] case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) + extends BlockId { def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId } -private[spark] case class BroadcastBlockId(broadcastId: Long) extends BlockId { - def name = "broadcast_" + broadcastId -} - -private[spark] -case class BroadcastHelperBlockId(broadcastId: BroadcastBlockId, hType: String) extends BlockId { - def name = broadcastId.name + "_" + hType +private[spark] case class BroadcastBlockId(broadcastId: Long, field: String = "") extends BlockId { + def name = "broadcast_" + broadcastId + (if (field == "") "" else "_" + field) } private[spark] case class TaskResultBlockId(taskId: Long) extends BlockId { @@ -83,8 +78,7 @@ private[spark] case class TestBlockId(id: String) extends BlockId { private[spark] object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r - val BROADCAST = 
"broadcast_([0-9]+)".r - val BROADCAST_HELPER = "broadcast_([0-9]+)_([A-Za-z0-9]+)".r + val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r val TEST = "test_(.*)".r @@ -95,10 +89,8 @@ private[spark] object BlockId { RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => ShuffleBlockId(shuffleId.toInt, mapId.toInt, reduceId.toInt) - case BROADCAST(broadcastId) => - BroadcastBlockId(broadcastId.toLong) - case BROADCAST_HELPER(broadcastId, hType) => - BroadcastHelperBlockId(BroadcastBlockId(broadcastId.toLong), hType) + case BROADCAST(broadcastId, field) => + BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_")) case TASKRESULT(taskId) => TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala index c8f397609a0b4..ef924123a3b11 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfo.scala @@ -79,3 +79,5 @@ private object BlockInfo { private val BLOCK_PENDING: Long = -1L private val BLOCK_FAILED: Long = -2L } + +private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ca23513c4dc64..925cee1eb6be7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props} import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} import sun.nio.ch.DirectBuffer -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException, MapOutputTracker} +import org.apache.spark._ import org.apache.spark.io.CompressionCodec import org.apache.spark.network._ import org.apache.spark.serializer.Serializer @@ -49,8 +49,8 @@ private[spark] class BlockManager( maxMemory: Long, val conf: SparkConf, securityManager: SecurityManager, - mapOutputTracker: MapOutputTracker - ) extends Logging { + mapOutputTracker: MapOutputTracker) + extends Logging { val shuffleBlockManager = new ShuffleBlockManager(this) val diskBlockManager = new DiskBlockManager(shuffleBlockManager, @@ -58,7 +58,7 @@ private[spark] class BlockManager( private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] - private[storage] val memoryStore: BlockStore = new MemoryStore(this, maxMemory) + private[storage] val memoryStore = new MemoryStore(this, maxMemory) private[storage] val diskStore = new DiskStore(this, diskBlockManager) // If we use Netty for shuffle, start a new Netty-based shuffle sender service. @@ -209,10 +209,14 @@ private[spark] class BlockManager( } } - /** - * Get storage level of local block. If no info exists for the block, then returns null. - */ - def getLevel(blockId: BlockId): StorageLevel = blockInfo.get(blockId).map(_.level).orNull + /** Get the BlockStatus for the block identified by the given ID, if it exists. 
*/ + def getStatus(blockId: BlockId): Option[BlockStatus] = { + blockInfo.get(blockId).map { info => + val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L + val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L + BlockStatus(info.level, memSize, diskSize) + } + } /** * Tell the master about the current storage status of a block. This will send a block update @@ -496,9 +500,8 @@ private[spark] class BlockManager( /** * A short circuited method to get a block writer that can write data directly to disk. - * The Block will be appended to the File specified by filename. - * This is currently used for writing shuffle files out. Callers should handle error - * cases. + * The Block will be appended to the File specified by filename. This is currently used for + * writing shuffle files out. Callers should handle error cases. */ def getDiskWriter( blockId: BlockId, @@ -816,14 +819,24 @@ private[spark] class BlockManager( * @return The number of blocks removed. */ def removeRdd(rddId: Int): Int = { - // TODO: Instead of doing a linear scan on the blockInfo map, create another map that maps - // from RDD.id to blocks. + // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo("Removing RDD " + rddId) val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) - blocksToRemove.foreach(blockId => removeBlock(blockId, tellMaster = false)) + blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } + /** + * Remove all blocks belonging to the given broadcast. + */ + def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) { + logInfo("Removing broadcast " + broadcastId) + val blocksToRemove = blockInfo.keys.collect { + case bid @ BroadcastBlockId(`broadcastId`, _) => bid + } + blocksToRemove.foreach { blockId => removeBlock(blockId, removeFromDriver) } + } + /** * Remove a block from both memory and disk. */ @@ -860,7 +873,7 @@ private[spark] class BlockManager( } private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)) { - val iterator = blockInfo.internalMap.entrySet().iterator() + val iterator = blockInfo.getEntrySet.iterator while (iterator.hasNext) { val entry = iterator.next() val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) @@ -884,7 +897,7 @@ private[spark] class BlockManager( def shouldCompress(blockId: BlockId): Boolean = blockId match { case ShuffleBlockId(_, _, _) => compressShuffle - case BroadcastBlockId(_) => compressBroadcast + case BroadcastBlockId(_, _) => compressBroadcast case RDDBlockId(_, _) => compressRdds case TempBlockId(_) => compressShuffleSpill case _ => false diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index ff3f22b3b092a..4e45bb8452fd8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -82,7 +82,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log /** * Check if block manager master has a block. Note that this can be used to check for only - * those blocks that are expected to be reported to block manager master. + * those blocks that are reported to block manager master. 
*/ def contains(blockId: BlockId) = { !getLocations(blockId).isEmpty @@ -106,9 +106,7 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply(RemoveBlock(blockId)) } - /** - * Remove all blocks belonging to the given RDD. - */ + /** Remove all blocks belonging to the given RDD. */ def removeRdd(rddId: Int, blocking: Boolean) { val future = askDriverWithReply[Future[Seq[Int]]](RemoveRdd(rddId)) future onFailure { @@ -119,13 +117,16 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log } } - /** - * Remove all blocks belonging to the given shuffle. - */ + /** Remove all blocks belonging to the given shuffle. */ def removeShuffle(shuffleId: Int) { askDriverWithReply(RemoveShuffle(shuffleId)) } + /** Remove all blocks belonging to the given broadcast. */ + def removeBroadcast(broadcastId: Long, removeFromMaster: Boolean) { + askDriverWithReply(RemoveBroadcast(broadcastId, removeFromMaster)) + } + /** * Return the memory status for each block manager, in the form of a map from * the block manager's id to two long values. The first value is the maximum @@ -140,6 +141,34 @@ class BlockManagerMaster(var driverActor: ActorRef, conf: SparkConf) extends Log askDriverWithReply[Array[StorageStatus]](GetStorageStatus) } + /** + * Return the block's status on all block managers, if any. + * + * If askSlaves is true, this invokes the master to query each block manager for the most + * updated block statuses. This is useful when the master is not informed of the given block + * by all block managers. + */ + def getBlockStatus( + blockId: BlockId, + askSlaves: Boolean = true): Map[BlockManagerId, BlockStatus] = { + val msg = GetBlockStatus(blockId, askSlaves) + /* + * To avoid potential deadlocks, the use of Futures is necessary, because the master actor + * should not block on waiting for a block manager, which can in turn be waiting for the + * master actor for a response to a prior message. + */ + val response = askDriverWithReply[Map[BlockManagerId, Future[Option[BlockStatus]]]](msg) + val (blockManagerIds, futures) = response.unzip + val result = Await.result(Future.sequence(futures), timeout) + if (result == null) { + throw new SparkException("BlockManager returned null for BlockStatus query: " + blockId) + } + val blockStatus = result.asInstanceOf[Iterable[Option[BlockStatus]]] + blockManagerIds.zip(blockStatus).flatMap { case (blockManagerId, status) => + status.map { s => (blockManagerId, s) } + }.toMap + } + /** Stop the driver actor, called only on the Spark driver node */ def stop() { if (driverActor != null) { diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 646ccb7fa74f6..4159fc733a566 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -93,6 +93,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus case GetStorageStatus => sender ! storageStatus + case GetBlockStatus(blockId, askSlaves) => + sender ! blockStatus(blockId, askSlaves) + case RemoveRdd(rddId) => sender ! removeRdd(rddId) @@ -100,6 +103,10 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus removeShuffle(shuffleId) sender ! true + case RemoveBroadcast(broadcastId, removeFromDriver) => + removeBroadcast(broadcastId, removeFromDriver) + sender ! 
true
+
     case RemoveBlock(blockId) =>
       removeBlockFromWorkers(blockId)
       sender ! true
@@ -151,9 +158,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
   private def removeShuffle(shuffleId: Int) {
     // Nothing to do in the BlockManagerMasterActor data structures
     val removeMsg = RemoveShuffle(shuffleId)
-    blockManagerInfo.values.foreach { bm =>
-      bm.slaveActor ! removeMsg
-    }
+    blockManagerInfo.values.foreach { bm => bm.slaveActor ! removeMsg }
+  }
+
+  /**
+   * Delegate RemoveBroadcast messages to each BlockManager because the master may not be notified
+   * of all broadcast blocks. If removeFromDriver is false, broadcast blocks are only removed
+   * from the executors, but not from the driver.
+   */
+  private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean) {
+    // TODO: Consolidate usages of
+    val removeMsg = RemoveBroadcast(broadcastId)
+    blockManagerInfo.values
+      .filter { info => removeFromDriver || info.blockManagerId.executorId != "" }
+      .foreach { bm => bm.slaveActor ! removeMsg }
   }
   private def removeBlockManager(blockManagerId: BlockManagerId) {
@@ -236,6 +254,34 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
     }.toArray
   }
+  /**
+   * Return the block's status for all block managers, if any.
+   *
+   * If askSlaves is true, the master queries each block manager for the most updated block
+   * statuses. This is useful when the master is not informed of the given block by all block
+   * managers.
+   */
+  private def blockStatus(
+      blockId: BlockId,
+      askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = {
+    import context.dispatcher
+    val getBlockStatus = GetBlockStatus(blockId)
+    /*
+     * Rather than blocking on the block status query, master actor should simply return
+     * Futures to avoid potential deadlocks. This can arise if there exists a block manager
+     * that is also waiting for this master actor's response to a previous message.
+     */
+    blockManagerInfo.values.map { info =>
+      val blockStatusFuture =
+        if (askSlaves) {
+          info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]]
+        } else {
+          Future { info.getStatus(blockId) }
+        }
+      (info.blockManagerId, blockStatusFuture)
+    }.toMap
+  }
+
   private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
     if (!blockManagerInfo.contains(id)) {
       blockManagerIdByExecutor.get(id.executorId) match {
@@ -321,9 +367,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
   }
 }
-
-private[spark] case class BlockStatus(storageLevel: StorageLevel, memSize: Long, diskSize: Long)
-
 private[spark] class BlockManagerInfo(
     val blockManagerId: BlockManagerId,
     timeMs: Long,
@@ -340,6 +383,8 @@ private[spark] class BlockManagerInfo(
   logInfo("Registering block manager %s with %s RAM".format(
     blockManagerId.hostPort, Utils.bytesToString(maxMem)))
+  def getStatus(blockId: BlockId) = Option(_blocks.get(blockId))
+
   def updateLastSeenMs() {
     _lastSeenMs = System.currentTimeMillis()
   }
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 50ea4e31ce509..afb2c6a12ce67 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -35,7 +35,11 @@ private[storage] object BlockManagerMessages {
   case class RemoveRdd(rddId: Int) extends ToBlockManagerSlave
   // Remove all blocks belonging to a specific shuffle.
- case class RemoveShuffle(shuffleId: Int) + case class RemoveShuffle(shuffleId: Int) extends ToBlockManagerSlave + + // Remove all blocks belonging to a specific broadcast. + case class RemoveBroadcast(broadcastId: Long, removeFromDriver: Boolean = true) + extends ToBlockManagerSlave ////////////////////////////////////////////////////////////////////////////////// @@ -57,8 +61,7 @@ private[storage] object BlockManagerMessages { var storageLevel: StorageLevel, var memSize: Long, var diskSize: Long) - extends ToBlockManagerMaster - with Externalizable { + extends ToBlockManagerMaster with Externalizable { def this() = this(null, null, null, 0, 0) // For deserialization only @@ -80,7 +83,8 @@ private[storage] object BlockManagerMessages { } object UpdateBlockInfo { - def apply(blockManagerId: BlockManagerId, + def apply( + blockManagerId: BlockManagerId, blockId: BlockId, storageLevel: StorageLevel, memSize: Long, @@ -106,7 +110,10 @@ private[storage] object BlockManagerMessages { case object GetMemoryStatus extends ToBlockManagerMaster - case object ExpireDeadHosts extends ToBlockManagerMaster - case object GetStorageStatus extends ToBlockManagerMaster + + case class GetBlockStatus(blockId: BlockId, askSlaves: Boolean = true) + extends ToBlockManagerMaster + + case object ExpireDeadHosts extends ToBlockManagerMaster } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index a6ff147c1d3e6..016ade428c68f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -29,8 +29,9 @@ import org.apache.spark.storage.BlockManagerMessages._ private[storage] class BlockManagerSlaveActor( blockManager: BlockManager, - mapOutputTracker: MapOutputTracker - ) extends Actor { + mapOutputTracker: MapOutputTracker) + extends Actor { + override def receive = { case RemoveBlock(blockId) => @@ -45,5 +46,11 @@ class BlockManagerSlaveActor( if (mapOutputTracker != null) { mapOutputTracker.unregisterShuffle(shuffleId) } + + case RemoveBroadcast(broadcastId, removeFromDriver) => + blockManager.removeBroadcast(broadcastId, removeFromDriver) + + case GetBlockStatus(blockId, _) => + sender ! blockManager.getStatus(blockId) } } diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index a57e6f710305a..fcad84669c79a 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -90,7 +90,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD def getFile(blockId: BlockId): File = getFile(blockId.name) - /** Check if disk block manager has a block */ + /** Check if disk block manager has a block. 
*/ def containsBlock(blockId: BlockId): Boolean = { getBlockLocation(blockId).file.exists() } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index cf83a60ffb9e8..06233153c56d4 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -169,13 +169,13 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging { throw new IllegalStateException("Failed to find shuffle block: " + id) } - /** Remove all the blocks / files and metadata related to a particular shuffle */ + /** Remove all the blocks / files and metadata related to a particular shuffle. */ def removeShuffle(shuffleId: ShuffleId) { removeShuffleBlocks(shuffleId) shuffleStates.remove(shuffleId) } - /** Remove all the blocks / files related to a particular shuffle */ + /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId) { shuffleStates.get(shuffleId) match { case Some(state) => diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 7b75215846a9a..a107c5182b3be 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -48,7 +48,7 @@ private[spark] object ThreadingTest { val block = (1 to blockSize).map(_ => Random.nextInt()) val level = randomLevel() val startTime = System.currentTimeMillis() - manager.put(blockId, block.iterator, level, true) + manager.put(blockId, block.iterator, level, tellMaster = true) println("Pushed block " + blockId + " in " + (System.currentTimeMillis - startTime) + " ms") queue.add((blockId, block)) } diff --git a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala b/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala deleted file mode 100644 index 888a06b2408c9..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/BoundedHashMap.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.{ArrayBuffer, SynchronizedMap} - -import java.util.{Collections, LinkedHashMap} -import java.util.Map.{Entry => JMapEntry} -import scala.reflect.ClassTag - -/** - * A map that upper bounds the number of key-value pairs present in it. It can be configured to - * drop the least recently user pair or the earliest inserted pair. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. 
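As background for the deletion of BoundedHashMap that begins here: its bound and LRU behavior came entirely from java.util.LinkedHashMap's access-order constructor plus a removeEldestEntry override, as in this self-contained sketch (illustrative only, not part of the patch):

object BoundedLruSketch {
  import java.util.{LinkedHashMap => JLinkedHashMap, Map => JMap}

  // accessOrder = true makes iteration (and eviction) order follow access recency, i.e. LRU;
  // removeEldestEntry is consulted after every put and enforces the bound.
  class BoundedLruMap[K, V](bound: Int) extends JLinkedHashMap[K, V](16, 0.75f, true) {
    override protected def removeEldestEntry(eldest: JMap.Entry[K, V]): Boolean =
      size() > bound
  }

  def main(args: Array[String]) {
    val m = new BoundedLruMap[String, Int](2)
    m.put("a", 1); m.put("b", 2)
    m.get("a")      // touch "a" so "b" becomes the eldest entry
    m.put("c", 3)   // exceeds the bound, so "b" is evicted
    assert(!m.containsKey("b") && m.containsKey("a") && m.containsKey("c"))
  }
}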
- * - * Internally, a Java LinkedHashMap is used to get insert-order or access-order behavior. - * Note that the LinkedHashMap is not thread-safe and hence, it is wrapped in a - * Collections.synchronizedMap. However, getting the Java HashMap's iterator and - * using it can still lead to ConcurrentModificationExceptions. Hence, the iterator() - * function is overridden to copy the all pairs into an ArrayBuffer and then return the - * iterator to the ArrayBuffer. Also, the class apply the trait SynchronizedMap which - * ensures that all calls to the Scala Map API are synchronized. This together ensures - * that ConcurrentModificationException is never thrown. - * - * @param bound max number of key-value pairs - * @param useLRU true = least recently used/accessed will be dropped when bound is reached, - * false = earliest inserted will be dropped - */ -private[spark] class BoundedHashMap[A, B](bound: Int, useLRU: Boolean) - extends WrappedJavaHashMap[A, B, A, B] with SynchronizedMap[A, B] { - - private[util] val internalJavaMap = Collections.synchronizedMap(new LinkedHashMap[A, B]( - bound / 8, (0.75).toFloat, useLRU) { - override protected def removeEldestEntry(eldest: JMapEntry[A, B]): Boolean = { - size() > bound - } - }) - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new BoundedHashMap[K1, V1](bound, useLRU) - } - - /** - * Overriding iterator to make sure that the internal Java HashMap's iterator - * is not concurrently modified. This can be a performance issue and this should be overridden - * if it is known that this map will not be used in a multi-threaded environment. - */ - override def iterator: Iterator[(A, B)] = { - (new ArrayBuffer[(A, B)] ++= super.iterator).iterator - } -} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 346f2b7856791..c23b6b3944ba0 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -295,10 +295,8 @@ private[spark] object JsonProtocol { ("Map ID" -> shuffleBlockId.mapId) ~ ("Reduce ID" -> shuffleBlockId.reduceId) case broadcastBlockId: BroadcastBlockId => - "Broadcast ID" -> broadcastBlockId.broadcastId - case broadcastHelperBlockId: BroadcastHelperBlockId => - ("Broadcast Block ID" -> blockIdToJson(broadcastHelperBlockId.broadcastId)) ~ - ("Helper Type" -> broadcastHelperBlockId.hType) + ("Broadcast ID" -> broadcastBlockId.broadcastId) ~ + ("Field" -> broadcastBlockId.field) case taskResultBlockId: TaskResultBlockId => "Task ID" -> taskResultBlockId.taskId case streamBlockId: StreamBlockId => @@ -620,7 +618,6 @@ private[spark] object JsonProtocol { val rddBlockId = Utils.getFormattedClassName(RDDBlockId) val shuffleBlockId = Utils.getFormattedClassName(ShuffleBlockId) val broadcastBlockId = Utils.getFormattedClassName(BroadcastBlockId) - val broadcastHelperBlockId = Utils.getFormattedClassName(BroadcastHelperBlockId) val taskResultBlockId = Utils.getFormattedClassName(TaskResultBlockId) val streamBlockId = Utils.getFormattedClassName(StreamBlockId) val tempBlockId = Utils.getFormattedClassName(TempBlockId) @@ -638,12 +635,8 @@ private[spark] object JsonProtocol { new ShuffleBlockId(shuffleId, mapId, reduceId) case `broadcastBlockId` => val broadcastId = (json \ "Broadcast ID").extract[Long] - new BroadcastBlockId(broadcastId) - case `broadcastHelperBlockId` => - val broadcastBlockId = - blockIdFromJson(json \ "Broadcast Block 
ID").asInstanceOf[BroadcastBlockId] - val hType = (json \ "Helper Type").extract[String] - new BroadcastHelperBlockId(broadcastBlockId, hType) + val field = (json \ "Field").extract[String] + new BroadcastBlockId(broadcastId, field) case `taskResultBlockId` => val taskId = (json \ "Task ID").extract[Long] new TaskResultBlockId(taskId) diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 2ef853710a554..7ebed5105b9fd 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -78,15 +78,16 @@ private[spark] object MetadataCleaner { conf.getInt("spark.cleaner.ttl", -1) } - def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = - { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) - .toInt + def getDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { + conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt } - def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) - { + def setDelaySeconds( + conf: SparkConf, + cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index c4d770fecdf74..5c239329588d8 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,80 +17,136 @@ package org.apache.spark.util +import java.util.Set +import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap +import scala.collection.{JavaConversions, mutable} + import org.apache.spark.Logging -private[util] case class TimeStampedValue[T](timestamp: Long, value: T) +private[spark] case class TimeStampedValue[V](value: V, timestamp: Long) /** - * A map that stores the timestamp of when a key was inserted along with the value. If specified, - * the timestamp of each pair can be updated every time it is accessed. - * Key-value pairs whose timestamps are older than a particular - * threshold time can then be removed using the clearOldValues method. It exposes a - * scala.collection.mutable.Map interface to allow it to be a drop-in replacement for Scala - * HashMaps. - * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * This is a custom implementation of scala.collection.mutable.Map which stores the insertion + * timestamp along with each key-value pair. If specified, the timestamp of each pair can be + * updated every time it is accessed. Key-value pairs whose timestamp are older than a particular + * threshold time can then be removed using the clearOldValues method. This is intended to + * be a drop-in replacement of scala.collection.mutable.HashMap. 
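Since JsonProtocol now serializes the BroadcastBlockId field in place of the removed BroadcastHelperBlockId, here is a quick round-trip sketch of the intended behavior, assuming blockIdToJson/blockIdFromJson are reachable from test code in this package (as the updated JsonProtocolSuite below suggests):

val ids = Seq(BroadcastBlockId(1L), BroadcastBlockId(1L, "meta"), BroadcastBlockId(1L, "piece0"))
ids.foreach { id =>
  val json = JsonProtocol.blockIdToJson(id)         // now carries "Broadcast ID" and "Field"
  assert(JsonProtocol.blockIdFromJson(json) == id)  // case-class equality on both fields
}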
* - * @param updateTimeStampOnGet When enabled, the timestamp of a pair will be - * updated when it is accessed + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed */ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends WrappedJavaHashMap[A, B, A, TimeStampedValue[B]] with Logging { + extends mutable.Map[A, B]() with Logging { - private[util] val internalJavaMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() + private val internalMap = new ConcurrentHashMap[A, TimeStampedValue[B]]() + + def get(key: A): Option[B] = { + val value = internalMap.get(key) + if (value != null && updateTimeStampOnGet) { + internalMap.replace(key, value, TimeStampedValue(value.value, currentTime)) + } + Option(value).map(_.value) + } - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedHashMap[K1, V1]() + def iterator: Iterator[(A, B)] = { + val jIterator = getEntrySet.iterator + JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - def internalMap = internalJavaMap + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet - override def get(key: A): Option[B] = { - val timeStampedValue = internalMap.get(key) - if (updateTimeStampOnGet && timeStampedValue != null) { - internalJavaMap.replace(key, timeStampedValue, - TimeStampedValue(currentTime, timeStampedValue.value)) - } - Option(timeStampedValue).map(_.value) + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedHashMap[A, B1] + val oldInternalMap = this.internalMap.asInstanceOf[ConcurrentHashMap[A, TimeStampedValue[B1]]] + newMap.internalMap.putAll(oldInternalMap) + kv match { case (a, b) => newMap.internalMap.put(a, TimeStampedValue(b, currentTime)) } + newMap + } + + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedHashMap[A, B] + newMap.internalMap.putAll(this.internalMap) + newMap.internalMap.remove(key) + newMap } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedValue[B] = { - new TimeStampedValue(currentTime, v) + + override def += (kv: (A, B)): this.type = { + kv match { case (a, b) => internalMap.put(a, TimeStampedValue(b, currentTime)) } + this + } + + override def -= (key: A): this.type = { + internalMap.remove(key) + this + } + + override def update(key: A, value: B) { + this += ((key, value)) + } + + override def apply(key: A): B = { + get(key).getOrElse { throw new NoSuchElementException() } } - @inline override protected def internalValueToExternalValue(iv: TimeStampedValue[B]): B = { - iv.value + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { + JavaConversions.mapAsScalaConcurrentMap(internalMap) + .map { case (k, TimeStampedValue(v, t)) => (k, v) } + .filter(p) } - /** Atomically put if a key is absent. This exposes the existing API of ConcurrentHashMap. */ + override def empty: mutable.Map[A, B] = new TimeStampedHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) { + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() + val kv = (entry.getKey, entry.getValue.value) + f(kv) + } + } + + // Should we return previous value directly or as Option? 
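On the question in the comment above: the implementation just below settles on wrapping the previous value in an Option, mirroring ConcurrentHashMap.putIfAbsent. A short sketch of the resulting semantics:

val m = new TimeStampedHashMap[String, Int]()
assert(m.putIfAbsent("k", 1) == None)      // no previous value: the pair is inserted
assert(m.putIfAbsent("k", 2) == Some(1))   // key present: existing value returned, not replaced
assert(m("k") == 1)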
def putIfAbsent(key: A, value: B): Option[B] = { - val prev = internalJavaMap.putIfAbsent(key, TimeStampedValue(currentTime, value)) + val prev = internalMap.putIfAbsent(key, TimeStampedValue(value, currentTime)) Option(prev).map(_.value) } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap + def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() if (entry.getValue.timestamp < threshTime) { f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime` - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } - private def currentTime: Long = System.currentTimeMillis() + private def currentTime: Long = System.currentTimeMillis + + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index 09a6faf33ec60..b65017d6806c6 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -17,114 +17,154 @@ package org.apache.spark.util -import scala.collection.{JavaConversions, immutable} - -import java.util import java.lang.ref.WeakReference -import java.util.concurrent.ConcurrentHashMap - -import org.apache.spark.Logging import java.util.concurrent.atomic.AtomicInteger -private[util] case class TimeStampedWeakValue[T](timestamp: Long, weakValue: WeakReference[T]) { - def this(timestamp: Long, value: T) = this(timestamp, new WeakReference[T](value)) -} +import scala.collection.mutable + +import org.apache.spark.Logging /** - * A map that stores the timestamp of when a key was inserted along with the value, - * while ensuring that the values are weakly referenced. If the value is garbage collected and - * the weak reference is null, get() operation returns the key be non-existent. However, - * the key is actually not removed in the current implementation. Key-value pairs whose - * timestamps are older than a particular threshold time can then be removed using the - * clearOldValues method. It exposes a scala.collection.mutable.Map interface to allow it to be a - * drop-in replacement for Scala HashMaps. + * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. + * + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. 
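A small usage sketch of timestamp-based clearing in the rewritten TimeStampedHashMap, mirroring what the new TimeStampedHashMapSuite at the end of this patch asserts:

val map = new TimeStampedHashMap[String, Int](updateTimeStampOnGet = false)
map("a") = 1
Thread.sleep(10)
val threshold = System.currentTimeMillis
map("b") = 2                    // inserted after the threshold
map.clearOldValues(threshold)   // drops "a" (older timestamp), keeps "b"
assert(map.get("a") == None && map.get("b") == Some(2))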
+ * + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. */ +private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) + extends mutable.Map[A, B]() with Logging { + + import TimeStampedWeakValueHashMap._ + + private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) -private[spark] class TimeStampedWeakValueHashMap[A, B]() - extends WrappedJavaHashMap[A, B, A, TimeStampedWeakValue[B]] with Logging { + /** Return a map consisting only of entries whose values are still strongly reachable. */ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } - /** Number of inserts after which keys whose weak ref values are null will be cleaned */ - private val CLEANUP_INTERVAL = 1000 + def get(key: A): Option[B] = internalMap.get(key) - /** Counter for counting the number of inserts */ - private val insertCounts = new AtomicInteger(0) + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator - private[util] val internalJavaMap: util.Map[A, TimeStampedWeakValue[B]] = { - new ConcurrentHashMap[A, TimeStampedWeakValue[B]]() + override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { + val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) + newMap.internalMap += kv + newMap } - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TimeStampedWeakValueHashMap[K1, V1]() + override def - (key: A): mutable.Map[A, B] = { + val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) + newMap.internalMap -= key + newMap } - override def +=(kv: (A, B)): this.type = { - // Cleanup null value at certain intervals - if (insertCounts.incrementAndGet() % CLEANUP_INTERVAL == 0) { - cleanNullValues() + override def += (kv: (A, B)): this.type = { + internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() } - super.+=(kv) + this } - override def get(key: A): Option[B] = { - Option(internalJavaMap.get(key)).flatMap { weakValue => - val value = weakValue.weakValue.get - if (value == null) { - internalJavaMap.remove(key) + override def -= (key: A): this.type = { + internalMap -= key + this + } + + override def update(key: A, value: B) = this += ((key, value)) + + override def apply(key: A): B = internalMap.apply(key) + + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) + + override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() + + override def size: Int = internalMap.size + + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) + + def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) + + def toMap: Map[A, B] = iterator.toMap + + /** Remove old key-value pairs with timestamps earlier than `threshTime`. 
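A sketch of the weak-reference semantics this wrapper is after, in line with the "clearing weak references" test added in TimeStampedHashMapSuite below (GC is best-effort, so a real test retries):

val map = new TimeStampedWeakValueHashMap[String, Object]()
var strong: Object = new Object
map("k") = strong
assert(map.get("k").isDefined)   // value still strongly reachable

strong = null                    // drop the only strong reference to the value
System.gc()                      // best effort; the referent may not be collected immediately

// Once the referent is collected, get("k") returns None even though the entry still
// exists internally; clearNullValues(), or the periodic sweep triggered every
// CLEAR_NULL_VALUES_INTERVAL inserts, then removes the stale entry.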
*/ + def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() } - Option(value) } } - @inline override protected def externalValueToInternalValue(v: B): TimeStampedWeakValue[B] = { - new TimeStampedWeakValue(currentTime, v) + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) } - @inline override protected def internalValueToExternalValue(iv: TimeStampedWeakValue[B]): B = { - iv.weakValue.get + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) } +} + +/** + * Helper methods for converting to and from WeakReferences. + */ +private object TimeStampedWeakValueHashMap { + + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ - override def iterator: Iterator[(A, B)] = { - val iterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(iterator).flatMap(kv => { - val (key, value) = (kv.getKey, kv.getValue.weakValue.get) - if (value != null) Seq((key, value)) else Seq.empty - }) + implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) + + implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { + kv match { case (k, v) => (k, toWeakReference(v)) } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`, - * calling the supplied function on each such entry before removing. - */ - def clearOldValues(threshTime: Long, f: (A, B) => Unit = null) { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.timestamp < threshTime) { - val value = entry.getValue.weakValue.get - if (f != null && value != null) { - f(entry.getKey, value) - } - logDebug("Removing key " + entry.getKey) - iterator.remove() - } - } + implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { + (kv: (K, WeakReference[V])) => p(kv) } - /** - * Removes keys whose weak referenced values have become null. - */ - private def cleanNullValues() { - val iterator = internalJavaMap.entrySet().iterator() - while (iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue.weakValue.get == null) { - logDebug("Removing key " + entry.getKey) - iterator.remove() - } + /* Implicit conversion methods from WeakReferences. 
*/ + + implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get + + implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None } } - private def currentTime = System.currentTimeMillis() + implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { + kv match { case (k, v) => (k, fromWeakReference(v)) } + } + + implicit def fromWeakReferenceIterator[K, V]( + it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { + it.map(fromWeakReferenceTuple) + } + + implicit def fromWeakReferenceMap[K, V]( + map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { + mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 62ee704d580c2..0b07bdcf63b97 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -461,10 +461,10 @@ private[spark] object Utils extends Logging { private val hostPortParseResults = new ConcurrentHashMap[String, (String, Int)]() def parseHostPort(hostPort: String): (String, Int) = { - { - // Check cache first. - val cached = hostPortParseResults.get(hostPort) - if (cached != null) return cached + // Check cache first. + val cached = hostPortParseResults.get(hostPort) + if (cached != null) { + return cached } val indx: Int = hostPort.lastIndexOf(':') diff --git a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala b/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala deleted file mode 100644 index 6cc3007f5d7ac..0000000000000 --- a/core/src/main/scala/org/apache/spark/util/WrappedJavaHashMap.scala +++ /dev/null @@ -1,152 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.Map -import java.util.{Map => JMap} -import java.util.Map.{Entry => JMapEntry} -import scala.collection.{immutable, JavaConversions} -import scala.reflect.ClassTag - -/** - * Convenient wrapper class for exposing Java HashMaps as Scala Maps even if the - * exposed key-value type is different from the internal type. This allows these - * implementations of WrappedJavaHashMap to be drop-in replacements for Scala HashMaps. - * - * While Java <-> Scala conversion methods exists, its hard to understand the performance - * implications and thread safety of the Scala wrapper. This class allows you to convert - * between types and applying the necessary overridden methods to take care of performance. 
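The conversions in this companion object are what let the wrapper store WeakReference[B] internally while exposing a plain mutable.Map[A, B]. Within org.apache.spark.util (the object is package-private), wrapping and unwrapping is implicit at the call site, roughly:

import java.lang.ref.WeakReference
import TimeStampedWeakValueHashMap._        // only visible from within org.apache.spark.util

val ref: WeakReference[String] = "hello"    // toWeakReference wraps on the way in
val back: String = ref                      // fromWeakReference unwraps (null once collected)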
- * - * Note that the threading behavior of an implementation of WrappedJavaHashMap is tied to that of - * the internal Java HashMap used in the implementation. Each implementation must use - * necessary traits (e.g, scala.collection.mutable.SynchronizedMap), etc. to achieve the - * desired thread safety. - * - * @tparam K External key type - * @tparam V External value type - * @tparam IK Internal key type - * @tparam IV Internal value type - */ -private[spark] abstract class WrappedJavaHashMap[K, V, IK, IV] extends Map[K, V] { - - /* Methods that must be defined. */ - - /** - * Internal Java HashMap that is being wrapped. - * Scoped private[util] so that rest of Spark code cannot - * directly access the internal map. - */ - private[util] val internalJavaMap: JMap[IK, IV] - - /** Method to get a new instance of the internal Java HashMap. */ - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] - - /* - Methods that convert between internal and external types. These implementations - optimistically assume that the internal types are same as external types. These must - be overridden if the internal and external types are different. Otherwise there will be - runtime exceptions. - */ - - @inline protected def externalKeyToInternalKey(k: K): IK = { - k.asInstanceOf[IK] // works only if K is same or subclass of K - } - - @inline protected def externalValueToInternalValue(v: V): IV = { - v.asInstanceOf[IV] // works only if V is same or subclass of - } - - @inline protected def internalKeyToExternalKey(ik: IK): K = { - ik.asInstanceOf[K] - } - - @inline protected def internalValueToExternalValue(iv: IV): V = { - iv.asInstanceOf[V] - } - - @inline protected def internalPairToExternalPair(ip: JMapEntry[IK, IV]): (K, V) = { - (internalKeyToExternalKey(ip.getKey), internalValueToExternalValue(ip.getValue) ) - } - - /* Implicit methods to convert the types. */ - - @inline implicit private def convExtKeyToIntKey(k: K) = externalKeyToInternalKey(k) - - @inline implicit private def convExtValueToIntValue(v: V) = externalValueToInternalValue(v) - - @inline implicit private def convIntKeyToExtKey(ia: IK) = internalKeyToExternalKey(ia) - - @inline implicit private def convIntValueToExtValue(ib: IV) = internalValueToExternalValue(ib) - - @inline implicit private def convIntPairToExtPair(ip: JMapEntry[IK, IV]) = { - internalPairToExternalPair(ip) - } - - /* Methods that must be implemented for a scala.collection.mutable.Map */ - - def get(key: K): Option[V] = { - Option(internalJavaMap.get(key)) - } - - def iterator: Iterator[(K, V)] = { - val jIterator = internalJavaMap.entrySet().iterator() - JavaConversions.asScalaIterator(jIterator).map(kv => convIntPairToExtPair(kv)) - } - - /* Other methods that are implemented to ensure performance. 
*/ - - def +=(kv: (K, V)): this.type = { - internalJavaMap.put(kv._1, kv._2) - this - } - - def -=(key: K): this.type = { - internalJavaMap.remove(key) - this - } - - override def + [V1 >: V](kv: (K, V1)): Map[K, V1] = { - val newMap = newInstance[K, V1]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap += kv - newMap - } - - override def - (key: K): Map[K, V] = { - val newMap = newInstance[K, V]() - newMap.internalJavaMap.asInstanceOf[JMap[IK, IV]].putAll(this.internalJavaMap) - newMap -= key - } - - override def foreach[U](f: ((K, V)) => U) { - val jIterator = internalJavaMap.entrySet().iterator() - while(jIterator.hasNext) { - f(jIterator.next()) - } - } - - override def empty: Map[K, V] = newInstance[K, V]() - - override def size: Int = internalJavaMap.size - - override def filter(p: ((K, V)) => Boolean): Map[K, V] = { - newInstance[K, V]() ++= iterator.filter(p) - } - - def toMap: immutable.Map[K, V] = iterator.toMap -} diff --git a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala index 96ba3929c1685..f1bfb6666ddda 100644 --- a/core/src/test/scala/org/apache/spark/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/BroadcastSuite.scala @@ -19,68 +19,256 @@ package org.apache.spark import org.scalatest.FunSuite -class BroadcastSuite extends FunSuite with LocalSparkContext { +import org.apache.spark.storage._ +import org.apache.spark.broadcast.HttpBroadcast +import org.apache.spark.storage.BroadcastBlockId +class BroadcastSuite extends FunSuite with LocalSparkContext { - override def afterEach() { - super.afterEach() - System.clearProperty("spark.broadcast.factory") - } + private val httpConf = broadcastConf("HttpBroadcastFactory") + private val torrentConf = broadcastConf("TorrentBroadcastFactory") test("Using HttpBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing HttpBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing HttpBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.HttpBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = 
sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", httpConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) } test("Using TorrentBroadcast locally") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === Set((1, 10), (2, 10))) + sc = new SparkContext("local", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === Set((1, 10), (2, 10))) } test("Accessing TorrentBroadcast variables from multiple threads") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - sc = new SparkContext("local[10]", "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to 10).map(x => (x, 10)).toSet) + sc = new SparkContext("local[10]", "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) } test("Accessing TorrentBroadcast variables in a local cluster") { - System.setProperty("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") val numSlaves = 4 - sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test") - val list = List(1, 2, 3, 4) - val listBroadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, listBroadcast.value.sum)) - assert(results.collect.toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", torrentConf) + val list = List[Int](1, 2, 3, 4) + val broadcast = sc.broadcast(list) + val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + + test("Unpersisting HttpBroadcast on executors only") { + testUnpersistHttpBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting HttpBroadcast on executors and driver") { + testUnpersistHttpBroadcast(2, removeFromDriver = true) + } + + test("Unpersisting TorrentBroadcast on executors only") { + testUnpersistTorrentBroadcast(2, removeFromDriver = false) + } + + test("Unpersisting TorrentBroadcast on executors and driver") { + testUnpersistTorrentBroadcast(2, removeFromDriver = true) } + /** + * Verify the persistence of state associated with an HttpBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks and the broadcast file + * are present only on the expected nodes. 
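For orientation while reading these new tests: the user-facing distinction they exercise is between unpersist() and destroy(), which correspond to the removeFromDriver flag. A condensed sketch of the behavior asserted by testUnpersistBroadcast further down:

val bc = sc.broadcast(List(1, 2, 3, 4))

bc.unpersist()   // removeFromDriver = false: blocks are dropped on the executors only,
                 // and the suite shows the broadcast is still usable afterwards

bc.destroy()     // removeFromDriver = true: driver-side state is dropped as well, and any
                 // subsequent access to bc.value throws a SparkException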
+ */ + private def testUnpersistHttpBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id)) + + // Verify that the broadcast file is created, and blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!") + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. In the latter case, also verify that the broadcast file is deleted on the driver. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + assert(blockIds.size === 1) + val statuses = bmm.getBlockStatus(blockIds.head) + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + assert(removeFromDriver === !HttpBroadcast.getFile(blockIds.head.broadcastId).exists, + "Broadcast file should%s be deleted".format(possiblyNot)) + } + + testUnpersistBroadcast(numSlaves, httpConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * Verify the persistence of state associated with an TorrentBroadcast in a local-cluster. + * + * This test creates a broadcast variable, uses it on all executors, and then unpersists it. + * In between each step, this test verifies that the broadcast blocks are present only on the + * expected nodes. 
+ */ + private def testUnpersistTorrentBroadcast(numSlaves: Int, removeFromDriver: Boolean) { + def getBlockIds(id: Long) = { + val broadcastBlockId = BroadcastBlockId(id) + val metaBlockId = BroadcastBlockId(id, "meta") + // Assume broadcast value is small enough to fit into 1 piece + val pieceBlockId = BroadcastBlockId(id, "piece0") + Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId) + } + + // Verify that blocks are persisted only on the driver + def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockIds.head) + assert(statuses.size === 1) + statuses.head match { case (bm, status) => + assert(bm.executorId === "", "Block should only be on the driver") + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store on the driver") + assert(status.diskSize === 0, "Block should not be in disk store on the driver") + } + } + } + + // Verify that blocks are persisted in both the executors and the driver + def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId) + if (blockId.field == "meta") { + // Meta data is only on the driver + assert(statuses.size === 1) + statuses.head match { case (bm, _) => assert(bm.executorId === "") } + } else { + // Other blocks are on both the executors and the driver + assert(statuses.size === numSlaves + 1) + statuses.foreach { case (_, status) => + assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) + assert(status.memSize > 0, "Block should be in memory store") + assert(status.diskSize === 0, "Block should not be in disk store") + } + } + } + } + + // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver + // is true. + def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) { + val expectedNumBlocks = if (removeFromDriver) 0 else 1 + val possiblyNot = if (removeFromDriver) "" else " not" + blockIds.foreach { blockId => + val statuses = bmm.getBlockStatus(blockId) + assert(statuses.size === expectedNumBlocks, + "Block should%s be unpersisted on the driver".format(possiblyNot)) + } + } + + testUnpersistBroadcast(numSlaves, torrentConf, getBlockIds, afterCreation, + afterUsingBroadcast, afterUnpersist, removeFromDriver) + } + + /** + * This test runs in 4 steps: + * + * 1) Create broadcast variable, and verify that all state is persisted on the driver. + * 2) Use the broadcast variable on all executors, and verify that all state is persisted + * on both the driver and the executors. + * 3) Unpersist the broadcast, and verify that all state is removed where they should be. + * 4) [Optional] If removeFromDriver is false, we verify that the broadcast is re-usable. 
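Side by side, the block ids the two broadcast implementations are expected to leave behind per broadcast, taken directly from the getBlockIds helpers above (with the single-piece assumption made explicit):

def httpBlockIds(id: Long): Seq[BroadcastBlockId] =
  Seq(BroadcastBlockId(id))                // a single block, plus a file on the driver

def torrentBlockIds(id: Long): Seq[BroadcastBlockId] = Seq(
  BroadcastBlockId(id),                    // the deserialized value itself
  BroadcastBlockId(id, "meta"),            // torrent metadata, kept on the driver only
  BroadcastBlockId(id, "piece0"))          // the pieces, assuming the value fits in one piece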
+ */ + private def testUnpersistBroadcast( + numSlaves: Int, + broadcastConf: SparkConf, + getBlockIds: Long => Seq[BroadcastBlockId], + afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit, + removeFromDriver: Boolean) { + + sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", broadcastConf) + val blockManagerMaster = sc.env.blockManager.master + val list = List[Int](1, 2, 3, 4) + + // Create broadcast variable + val broadcast = sc.broadcast(list) + val blocks = getBlockIds(broadcast.id) + afterCreation(blocks, blockManagerMaster) + + // Use broadcast variable on all executors + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + afterUsingBroadcast(blocks, blockManagerMaster) + + // Unpersist broadcast + if (removeFromDriver) { + broadcast.destroy() + } else { + broadcast.unpersist() + } + afterUnpersist(blocks, blockManagerMaster) + + // If the broadcast is removed from driver, all subsequent uses of the broadcast variable + // should throw SparkExceptions. Otherwise, the result should be the same as before. + if (removeFromDriver) { + // Using this variable on the executors crashes them, which hangs the test. + // Instead, crash the driver by directly accessing the broadcast value. + intercept[SparkException] { broadcast.value } + } else { + val results = sc.parallelize(1 to numSlaves, numSlaves).map(x => (x, broadcast.value.sum)) + assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) + } + } + + /** Helper method to create a SparkConf that uses the given broadcast factory. 
*/ + private def broadcastConf(factoryName: String): SparkConf = { + val conf = new SparkConf + conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) + conf + } } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index b07f8817b7974..3d95547b20fc1 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark -import scala.collection.mutable.{ArrayBuffer, HashSet, SynchronizedSet} +import java.lang.ref.WeakReference + +import scala.collection.mutable.{HashSet, SynchronizedSet} +import scala.util.Random import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually @@ -25,10 +28,8 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkContext._ -import org.apache.spark.storage.{RDDBlockId, ShuffleBlockId} -import org.apache.spark.rdd.{ShuffleCoGroupSplitDep, RDD} -import scala.util.Random -import java.lang.ref.WeakReference +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.{BroadcastBlockId, RDDBlockId, ShuffleBlockId} class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { @@ -45,9 +46,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup cleaner.cleanupRDD(rdd) - tester.assertCleanup + tester.assertCleanup() - // verify that RDDs can be re-executed after cleaning up + // Verify that RDDs can be re-executed after cleaning up assert(rdd.collect().toList === collected) } @@ -58,87 +59,101 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo // Explicit cleanup shuffleDeps.foreach(s => cleaner.cleanupShuffle(s)) - tester.assertCleanup + tester.assertCleanup() // Verify that shuffles can be re-executed after cleaning up assert(rdd.collect().toList === collected) } + test("cleanup broadcast") { + val broadcast = newBroadcast + val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + + // Explicit cleanup + cleaner.cleanupBroadcast(broadcast) + tester.assertCleanup() + } + test("automatically cleanup RDD") { var rdd = newRDD.persist() rdd.count() - - // test that GC does not cause RDD cleanup due to a strong reference + + // Test that GC does not cause RDD cleanup due to a strong reference val preGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes RDD cleanup after dereferencing the RDD + // Test that GC causes RDD cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, rddIds = Seq(rdd.id)) - rdd = null // make RDD out of scope + rdd = null // Make RDD out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } test("automatically cleanup shuffle") { var rdd = newShuffleRDD rdd.count() - // test that GC does not cause shuffle cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) + // Test that GC does not cause shuffle cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes 
shuffle cleanup after dereferencing the RDD + // Test that GC causes shuffle cleanup after dereferencing the RDD val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - rdd = null // make RDD out of scope, so that corresponding shuffle goes out of scope + rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } - test("automatically cleanup RDD + shuffle") { + test("automatically cleanup broadcast") { + var broadcast = newBroadcast - def randomRDD: RDD[_] = { - val rdd: RDD[_] = Random.nextInt(3) match { - case 0 => newRDD - case 1 => newShuffleRDD - case 2 => newPairRDD.join(newPairRDD) - } - if (Random.nextBoolean()) rdd.persist() - rdd.count() - rdd + // Test that GC does not cause broadcast cleanup due to a strong reference + val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + runGC() + intercept[Exception] { + preGCTester.assertCleanup()(timeout(1000 millis)) } - val buffer = new ArrayBuffer[RDD[_]] - for (i <- 1 to 500) { - buffer += randomRDD - } + // Test that GC causes broadcast cleanup after dereferencing the broadcast variable + val postGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id)) + broadcast = null // Make broadcast variable out of scope + runGC() + postGCTester.assertCleanup() + } + test("automatically cleanup RDD + shuffle + broadcast") { + val numRdds = 100 + val numBroadcasts = 4 // Broadcasts are more costly + val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer + val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer val rddIds = sc.persistentRdds.keys.toSeq val shuffleIds = 0 until sc.newShuffleId + val broadcastIds = 0L until numBroadcasts - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds) + val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) runGC() intercept[Exception] { - preGCTester.assertCleanup(timeout(1000 millis)) + preGCTester.assertCleanup()(timeout(1000 millis)) } - // test that GC causes shuffle cleanup after dereferencing the RDD - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds) - buffer.clear() + + // Test that GC triggers the cleanup of all variables after the dereferencing them + val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) + broadcastBuffer.clear() + rddBuffer.clear() runGC() - postGCTester.assertCleanup + postGCTester.assertCleanup() } def newRDD = sc.makeRDD(1 to 10) - def newPairRDD = newRDD.map(_ -> 1) - def newShuffleRDD = newPairRDD.reduceByKey(_ + _) - + def newBroadcast = sc.broadcast(1 to 100) def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = { def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = { rdd.dependencies ++ rdd.dependencies.flatMap { dep => @@ -148,11 +163,27 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo val rdd = newShuffleRDD // Get all the shuffle dependencies - val shuffleDeps = getAllDependencies(rdd).filter(_.isInstanceOf[ShuffleDependency[_, _]]) + val shuffleDeps = getAllDependencies(rdd) + .filter(_.isInstanceOf[ShuffleDependency[_, _]]) .map(_.asInstanceOf[ShuffleDependency[_, _]]) (rdd, shuffleDeps) } + def randomRdd = { + val rdd: RDD[_] = Random.nextInt(3) match { + case 0 => newRDD + case 1 => newShuffleRDD + case 2 => newPairRDD.join(newPairRDD) + } + if (Random.nextBoolean()) rdd.persist() + rdd.count() + rdd + } + + def randomBroadcast = { + sc.broadcast(Random.nextInt(Int.MaxValue)) + } + /** 
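The runGC() helper defined just below is central to these automatic-cleanup tests: hold only a WeakReference to a sentinel object and spin System.gc() until the sentinel disappears, which proves a collection cycle actually ran. A standalone sketch of that idiom (the timeout is illustrative, not the suite's value):

import java.lang.ref.WeakReference

def forceGC(timeoutMs: Long = 10000) {
  val sentinel = new WeakReference(new Object)
  val deadline = System.currentTimeMillis + timeoutMs
  while (sentinel.get != null) {
    if (System.currentTimeMillis > deadline) {
      sys.error("GC did not run within " + timeoutMs + " ms")
    }
    System.gc()
    Thread.sleep(100)
  }
}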
Run GC and make sure it actually has run */ def runGC() { val weakRef = new WeakReference(new Object()) @@ -171,11 +202,16 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo /** Class to test whether RDDs, shuffles, etc. have been successfully cleaned. */ -class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[Int] = Nil) +class CleanerTester( + sc: SparkContext, + rddIds: Seq[Int] = Seq.empty, + shuffleIds: Seq[Int] = Seq.empty, + broadcastIds: Seq[Long] = Seq.empty) extends Logging { val toBeCleanedRDDIds = new HashSet[Int] with SynchronizedSet[Int] ++= rddIds val toBeCleanedShuffleIds = new HashSet[Int] with SynchronizedSet[Int] ++= shuffleIds + val toBeCleanedBroadcstIds = new HashSet[Long] with SynchronizedSet[Long] ++= broadcastIds val cleanerListener = new CleanerListener { def rddCleaned(rddId: Int): Unit = { @@ -187,6 +223,11 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In toBeCleanedShuffleIds -= shuffleId logInfo("Shuffle " + shuffleId + " cleaned") } + + def broadcastCleaned(broadcastId: Long): Unit = { + toBeCleanedBroadcstIds -= broadcastId + logInfo("Broadcast" + broadcastId + " cleaned") + } } val MAX_VALIDATION_ATTEMPTS = 10 @@ -197,7 +238,7 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In sc.cleaner.attachListener(cleanerListener) /** Assert that all the stuff has been cleaned up */ - def assertCleanup(implicit waitTimeout: Eventually.Timeout) { + def assertCleanup()(implicit waitTimeout: Eventually.Timeout) { try { eventually(waitTimeout, interval(10 millis)) { assert(isAllCleanedUp) @@ -211,7 +252,7 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In /** Verify that RDDs, shuffles, etc. 
occupy resources */ private def preCleanupValidate() { - assert(rddIds.nonEmpty || shuffleIds.nonEmpty, "Nothing to cleanup") + assert(rddIds.nonEmpty || shuffleIds.nonEmpty || broadcastIds.nonEmpty, "Nothing to cleanup") // Verify the RDDs have been persisted and blocks are present assert(rddIds.forall(sc.persistentRdds.contains), @@ -222,8 +263,12 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In // Verify the shuffle ids are registered and blocks are present assert(shuffleIds.forall(mapOutputTrackerMaster.containsShuffle), "One or more shuffles have not been registered cannot start cleaner test") - assert(shuffleIds.forall(shuffleId => diskBlockManager.containsBlock(shuffleBlockId(shuffleId))), + assert(shuffleIds.forall(sid => diskBlockManager.containsBlock(shuffleBlockId(sid))), "One or more shuffles' blocks cannot be found in disk manager, cannot start cleaner test") + + // Verify that the broadcast is in the driver's block manager + assert(broadcastIds.forall(bid => blockManager.getStatus(broadcastBlockId(bid)).isDefined), + "One ore more broadcasts have not been persisted in the driver's block manager") } /** @@ -236,14 +281,19 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In attempts += 1 logInfo("Attempt: " + attempts) try { - // Verify all the RDDs have been unpersisted + // Verify all RDDs have been unpersisted assert(rddIds.forall(!sc.persistentRdds.contains(_))) assert(rddIds.forall(rddId => !blockManager.master.contains(rddBlockId(rddId)))) - // Verify all the shuffle have been deregistered and cleaned up + // Verify all shuffles have been deregistered and cleaned up assert(shuffleIds.forall(!mapOutputTrackerMaster.containsShuffle(_))) - assert(shuffleIds.forall(shuffleId => - !diskBlockManager.containsBlock(shuffleBlockId(shuffleId)))) + assert(shuffleIds.forall(sid => !diskBlockManager.containsBlock(shuffleBlockId(sid)))) + + // Verify all broadcasts have been unpersisted + assert(broadcastIds.forall { bid => + blockManager.master.getBlockStatus(broadcastBlockId(bid)).isEmpty + }) + return } catch { case t: Throwable => @@ -260,18 +310,20 @@ class CleanerTester(sc: SparkContext, rddIds: Seq[Int] = Nil, shuffleIds: Seq[In s""" |\tRDDs = ${toBeCleanedRDDIds.mkString("[", ", ", "]")} |\tShuffles = ${toBeCleanedShuffleIds.mkString("[", ", ", "]")} + |\tBroadcasts = ${toBeCleanedBroadcstIds.mkString("[", ", ", "]")} """.stripMargin } - private def isAllCleanedUp = toBeCleanedRDDIds.isEmpty && toBeCleanedShuffleIds.isEmpty - - private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def isAllCleanedUp = + toBeCleanedRDDIds.isEmpty && + toBeCleanedShuffleIds.isEmpty && + toBeCleanedBroadcstIds.isEmpty private def rddBlockId(rddId: Int) = RDDBlockId(rddId, 0) + private def shuffleBlockId(shuffleId: Int) = ShuffleBlockId(shuffleId, 0, 0) + private def broadcastBlockId(broadcastId: Long) = BroadcastBlockId(broadcastId) private def blockManager = sc.env.blockManager - private def diskBlockManager = blockManager.diskBlockManager - private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] -} \ No newline at end of file +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b83033c35f6b7..6b2571cd9295e 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ 
-96,7 +96,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { assert(tracker.getServerStatuses(10, 0).isEmpty) } - test("master register shuffle and unregister mapoutput and fetch") { + test("master register shuffle and unregister map output and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster(conf) tracker.trackerActor = diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 04e64ee7a45b3..b47de5eab95a4 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -28,8 +28,7 @@ import org.scalatest.concurrent.Timeouts._ import org.scalatest.matchers.ShouldMatchers._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf, SparkContext} -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, Utils} @@ -746,6 +745,46 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT assert(!store.get("list5").isDefined, "list5 was in store") } + test("query block statuses") { + store = new BlockManager("", actorSystem, master, serializer, 1200, conf, + securityMgr, mapOutputTracker) + val list = List.fill(2)(new Array[Byte](200)) + + // Tell master. By LRU, only list2 and list3 remains. + store.put("list1", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + store.put("list2", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = true) + store.put("list3", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = true) + + // getLocations and getBlockStatus should yield the same locations + assert(store.master.getLocations("list1").size === 0) + assert(store.master.getLocations("list2").size === 1) + assert(store.master.getLocations("list3").size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = false).size === 1) + assert(store.master.getBlockStatus("list1", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list2", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list3", askSlaves = true).size === 1) + + // This time don't tell master and see what happens. By LRU, only list5 and list6 remains. 
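The tellMaster = false half that follows is the interesting part of this test; using the suite's string-to-BlockId shorthand seen in the puts above and a hypothetical block name, the behavior it pins down is essentially:

store.put("hidden", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false)

assert(store.master.getLocations("hidden").isEmpty)                        // master was never told
assert(store.master.getBlockStatus("hidden", askSlaves = false).isEmpty)   // master-side view only
assert(store.master.getBlockStatus("hidden", askSlaves = true).nonEmpty)   // slave actors queried
// (assuming the block has not been evicted in the meantime, unlike list4 below)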
+ store.put("list4", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + store.put("list5", list.iterator, StorageLevel.MEMORY_AND_DISK, tellMaster = false) + store.put("list6", list.iterator, StorageLevel.MEMORY_ONLY, tellMaster = false) + + // getLocations should return nothing because the master is not informed + // getBlockStatus without asking slaves should have the same result + // getBlockStatus with asking slaves, however, should return the actual block statuses + assert(store.master.getLocations("list4").size === 0) + assert(store.master.getLocations("list5").size === 0) + assert(store.master.getLocations("list6").size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list6", askSlaves = false).size === 0) + assert(store.master.getBlockStatus("list4", askSlaves = true).size === 0) + assert(store.master.getBlockStatus("list5", askSlaves = true).size === 1) + assert(store.master.getBlockStatus("list6", askSlaves = true).size === 1) + } + test("SPARK-1194 regression: fix the same-RDD rule for cache replacement") { store = new BlockManager("", actorSystem, master, serializer, 1200, conf, securityMgr, mapOutputTracker) diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 67c0a434c9b52..6bc8bcc036cb3 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -108,8 +108,7 @@ class JsonProtocolSuite extends FunSuite { // BlockId testBlockId(RDDBlockId(1, 2)) testBlockId(ShuffleBlockId(1, 2, 3)) - testBlockId(BroadcastBlockId(1L)) - testBlockId(BroadcastHelperBlockId(BroadcastBlockId(2L), "Spark")) + testBlockId(BroadcastBlockId(1L, "")) testBlockId(TaskResultBlockId(1L)) testBlockId(StreamBlockId(1, 2L)) testBlockId(TempBlockId(UUID.randomUUID())) @@ -556,4 +555,4 @@ class JsonProtocolSuite extends FunSuite { {"Event":"SparkListenerUnpersistRDD","RDD ID":12345} """ - } +} diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. 
+ System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. */ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. 
*/ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala deleted file mode 100644 index e446c7f75dc0b..0000000000000 --- a/core/src/test/scala/org/apache/spark/util/WrappedJavaHashMapSuite.scala +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import java.util -import java.lang.ref.WeakReference - -import scala.collection.mutable.{ArrayBuffer, HashMap, Map} -import scala.util.Random - -import org.scalatest.FunSuite - -class WrappedJavaHashMapSuite extends FunSuite { - - // Test the testMap function - a Scala HashMap should obviously pass - testMap(new HashMap[String, String]()) - - // Test a simple WrappedJavaHashMap - testMap(new TestMap[String, String]()) - - // Test BoundedHashMap - testMap(new BoundedHashMap[String, String](100, true)) - - testMapThreadSafety(new BoundedHashMap[String, String](100, true)) - - // Test TimeStampedHashMap - testMap(new TimeStampedHashMap[String, String]) - - testMapThreadSafety(new TimeStampedHashMap[String, String]) - - test("TimeStampedHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedHashMap[String, String](false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedHashMap[String, String](true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis() - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.internalMap.get("k1").timestamp < threshTime1) - assert(map1.internalMap.get("k2").timestamp >= threshTime1) - map1.clearOldValues(threshTime1) //should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - // Test TimeStampedHashMap - testMap(new TimeStampedWeakValueHashMap[String, String]) - - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]) - - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String]() - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis() - assert(map.internalJavaMap.get("k1").timestamp < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - } - - - test("TimeStampedWeakValueHashMap - get not returning null when weak reference is cleared") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - - map("k1") = strongRef - assert(map("k1") === strongRef) - - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. 
- while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.internalJavaMap.get("k1").weakValue.get == null) - assert(map.get("k1") === None) - - // TODO (TD): Test clearing of null-value pairs - } - - def testMap(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - - test(name + " - basic test") { - val testMap1 = newMap() - - // put and get - testMap1 += (("k1", "v1")) - assert(testMap1.get("k1").get === "v1") - testMap1("k2") = "v2" - assert(testMap1.get("k2").get === "v2") - assert(testMap1("k2") === "v2") - - // remove - testMap1.remove("k1") - assert(testMap1.get("k1").isEmpty) - testMap1.remove("k2") - intercept[Exception] { - testMap1("k2") // Map.apply() causes exception - } - - // multi put - val keys = (1 to 100).map(_.toString) - val pairs = keys.map(x => (x, x * 2)) - val testMap2 = newMap() - assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) - testMap2 ++= pairs - - // iterator - assert(testMap2.iterator.toSet === pairs.toSet) - testMap2("k1") = "v1" - - // foreach - val buffer = new ArrayBuffer[(String, String)] - testMap2.foreach(x => buffer += x) - assert(testMap2.toSet === buffer.toSet) - - // multi remove - testMap2 --= keys - assert(testMap2.size === 1) - assert(testMap2.iterator.toSeq.head === ("k1", "v1")) - } - } - - def testMapThreadSafety(hashMapConstructor: => Map[String, String]) { - def newMap() = hashMapConstructor - - val name = newMap().getClass.getSimpleName - val testMap = newMap() - @volatile var error = false - - def getRandomKey(m: Map[String, String]): Option[String] = { - val keys = testMap.keysIterator.toSeq - if (keys.nonEmpty) { - Some(keys(Random.nextInt(keys.size))) - } else { - None - } - } - - val threads = (1 to 100).map(i => new Thread() { - override def run() { - try { - for (j <- 1 to 1000) { - Random.nextInt(3) match { - case 0 => - testMap(Random.nextString(10)) = Random.nextDouble.toString // put - case 1 => - getRandomKey(testMap).map(testMap.get) // get - case 2 => - getRandomKey(testMap).map(testMap.remove) // remove - } - } - } catch { - case t : Throwable => - error = true - throw t - } - } - }) - - test(name + " - threading safety test") { - threads.map(_.start) - threads.map(_.join) - assert(!error) - } - } -} - -class TestMap[A, B] extends WrappedJavaHashMap[A, B, A, B] { - private[util] val internalJavaMap: util.Map[A, B] = new util.HashMap[A, B]() - - private[util] def newInstance[K1, V1](): WrappedJavaHashMap[K1, V1, _, _] = { - new TestMap[K1, V1] - } -}
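Note on the weak-reference assertions in the new TimeStampedHashMapSuite above: System.gc() is only a request, so the suite polls until the referent is actually collected or a timeout expires. The following is a minimal, self-contained sketch of that idiom, not part of the patch; the object and helper names (WeakRefTestIdiom, waitForWeakReferenceToClear) are illustrative only.

import java.lang.ref.WeakReference

object WeakRefTestIdiom {
  // Poll until the weak reference is cleared or the timeout expires.
  // Returns true if the referent was collected within the timeout.
  def waitForWeakReferenceToClear[T <: AnyRef](ref: WeakReference[T], timeoutMs: Long = 10000L): Boolean = {
    val start = System.currentTimeMillis
    while (System.currentTimeMillis - start < timeoutMs && ref.get != null) {
      System.gc()                // best effort: a request to the JVM, not a guarantee
      System.runFinalization()   // best effort: run pending finalizers on collected objects
      Thread.sleep(100)
    }
    ref.get == null
  }

  def main(args: Array[String]): Unit = {
    var strongRef: Object = new Object
    val weakRef = new WeakReference[Object](strongRef)
    strongRef = null // drop the only strong reference so the referent becomes weakly reachable
    // Usually prints true, but collection within the timeout is still not strictly guaranteed.
    println("cleared = " + waitForWeakReferenceToClear(weakRef))
  }
}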