diff --git a/assembly/pom.xml b/assembly/pom.xml
index b2a9d0780ee2b..594fa0c779e1b 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -142,8 +142,10 @@
com/google/common/base/Absent*
+ com/google/common/base/Function
com/google/common/base/Optional*
com/google/common/base/Present*
+ com/google/common/base/Supplier
diff --git a/bin/spark-class b/bin/spark-class
index 1b945461fabc8..2f0441bb3c1c2 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -120,8 +121,8 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
diff --git a/build/mvn b/build/mvn
index 43471f83e904c..f91e2b4bdcc02 100755
--- a/build/mvn
+++ b/build/mvn
@@ -68,10 +68,10 @@ install_app() {
# Install maven under the build/ folder
install_mvn() {
install_app \
- "http://apache.claz.org/maven/maven-3/3.2.3/binaries" \
- "apache-maven-3.2.3-bin.tar.gz" \
- "apache-maven-3.2.3/bin/mvn"
- MVN_BIN="${_DIR}/apache-maven-3.2.3/bin/mvn"
+ "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
+ "apache-maven-3.2.5-bin.tar.gz" \
+ "apache-maven-3.2.5/bin/mvn"
+ MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
}
# Install zinc under the build/ folder
diff --git a/core/pom.xml b/core/pom.xml
index d9a49c9e08afc..1984682b9c099 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -372,8 +372,10 @@
com.google.guava:guava
com/google/common/base/Absent*
+ com/google/common/base/Function
com/google/common/base/Optional*
com/google/common/base/Present*
+ com/google/common/base/Supplier
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index a1f7133f897ee..f23ba9dba167f 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -190,6 +190,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
-.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time {
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
+.getting_result_time {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
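
Side note on the Logging change: the point of the reordering is to avoid touching LogManager (which can itself trigger log4j initialization) unless the log4j 1.2 binding is actually in use. A standalone sketch of the same check, assuming log4j 1.2 and its slf4j binding are on the classpath (the object name is hypothetical, not part of the patch):

    import java.net.URL

    import org.apache.log4j.{LogManager, PropertyConfigurator}
    import org.slf4j.LoggerFactory

    object Log4jDefaults {
      // Only probe log4j state when the log4j 1.2 slf4j binding is the one in use.
      def maybeConfigure(defaultProps: URL): Unit = {
        val binderClass = LoggerFactory.getILoggerFactory.getClass.getName
        val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory" == binderClass
        if (usingLog4j12) {
          val alreadyConfigured = LogManager.getRootLogger.getAllAppenders.hasMoreElements
          if (!alreadyConfigured) {
            PropertyConfigurator.configure(defaultProps)
          }
        }
      }
    }
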
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index f9d4aa4240e9d..cd91c8f87547b 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,9 +17,11 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.concurrent.TrieMap
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
/**
@@ -47,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new TrieMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ set(k, v)
}
}
@@ -64,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -130,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -164,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -225,11 +227,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -241,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -266,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -346,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
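
For reference, a minimal sketch of the pattern the SparkConf change adopts: a ConcurrentHashMap-backed store where the null-returning read is wrapped in an Option exactly once, and every other accessor goes through that wrapper. The ConfigStore class is a hypothetical illustration, not part of the patch:

    import java.util.concurrent.ConcurrentHashMap
    import scala.collection.JavaConverters._

    // Thread-safe key/value store; callers never see the nulls that
    // ConcurrentHashMap.get returns for missing keys.
    class ConfigStore {
      private val settings = new ConcurrentHashMap[String, String]()

      def set(key: String, value: String): this.type = {
        require(key != null && value != null, "null keys/values are not allowed")
        settings.put(key, value)
        this
      }

      // Wrap the possibly-null result once, here.
      def getOption(key: String): Option[String] = Option(settings.get(key))

      def get(key: String, default: String): String = getOption(key).getOrElse(default)

      def getAll: Array[(String, String)] =
        settings.entrySet().asScala.map(e => (e.getKey, e.getValue)).toArray
    }
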
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 6a354ed4d1486..4c4ee04cc515e 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -85,6 +85,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -525,6 +533,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -540,6 +549,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -549,6 +559,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -582,6 +593,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -627,6 +639,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -651,6 +664,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -684,6 +698,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -703,6 +718,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -782,6 +798,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -802,6 +819,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -817,6 +835,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -828,9 +847,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -858,6 +878,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -879,6 +900,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -954,6 +976,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -1046,6 +1075,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1058,6 +1088,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1075,6 +1106,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1084,6 +1116,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1094,6 +1127,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1101,6 +1135,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1206,16 +1241,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1289,8 +1322,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1377,6 +1410,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1399,6 +1433,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1417,11 +1452,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1468,13 +1505,20 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /** Default min number of partitions for Hadoop RDDs when not given by user */
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
+ * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718
+ */
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
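
The assertNotStopped guard threaded through SparkContext above follows a simple volatile-flag pattern: one flag set exactly once in stop(), checked at the top of every public entry point. A hedged sketch with a hypothetical Service class:

    // Illustration of the stopped-flag guard used above.
    class Service {
      @volatile private var stopped: Boolean = false

      private def assertNotStopped(): Unit = {
        if (stopped) {
          throw new IllegalStateException("Cannot call methods on a stopped service")
        }
      }

      def doWork(): Int = {
        assertNotStopped()
        42
      }

      def stop(): Unit = synchronized {
        if (!stopped) {
          stopped = true
          // release resources exactly once here
        }
      }
    }
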
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 4d418037bd33f..1264a8126153b 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -326,6 +326,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2b084a2d73b78..0ae45f4ad9130 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -203,7 +203,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (!logInfos.isEmpty) {
val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
def addIfAbsent(info: FsApplicationHistoryInfo) = {
- if (!newApps.contains(info.id)) {
+ if (!newApps.contains(info.id) ||
+ newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) &&
+ !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) {
newApps += (info.id -> info)
}
}
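
The condition added to addIfAbsent prefers a completed event log over an in-progress one for the same application id. A standalone sketch of that rule with hypothetical names (the ".inprogress" suffix is assumed; EventLoggingListener.IN_PROGRESS plays that role in the patch):

    object HistoryListingRule {
      val InProgressSuffix = ".inprogress"

      case class Entry(id: String, logPath: String)

      // Keep the existing entry for an app id unless it is still in progress and
      // the candidate is the completed version of the same log.
      def shouldReplace(existing: Option[Entry], candidate: Entry): Boolean = existing match {
        case None => true
        case Some(e) =>
          e.logPath.endsWith(InProgressSuffix) && !candidate.logPath.endsWith(InProgressSuffix)
      }
    }
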
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 9a4adfbbb3d71..823825302658c 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -84,8 +84,12 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case x: DisassociatedEvent =>
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
+ if (x.remoteAddress == driver.anchorPath.address) {
+ logError(s"Driver $x disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"Received irrelevant DisassociatedEvent $x")
+ }
case StopExecutor =>
logInfo("Driver commanded a shutdown")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 42566d1a14093..d8c2e41a7c715 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -41,11 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
+ executorHostname: String,
env: SparkEnv,
isLocal: Boolean = false)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -58,12 +61,12 @@ private[spark] class Executor(
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
+ Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 45633e3de01dd..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -130,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
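
The SparkEnv change earlier in this patch sets spark.executor.id before the MetricsSystem is constructed, which is what lets the branch above produce fully qualified names on executors. A small sketch of the naming rule, using Codahale's MetricRegistry (the object name is hypothetical):

    import com.codahale.metrics.MetricRegistry

    object MetricNaming {
      // Prefix the source name with "<appId>.<executorId>" only when both are
      // available; otherwise fall back to the bare source name (the real code
      // also logs a warning in that case).
      def registryName(appId: Option[String], executorId: Option[String], source: String): String = {
        if (appId.isDefined && executorId.isDefined) {
          MetricRegistry.name(appId.get, executorId.get, source)
        } else {
          source
        }
      }
    }
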
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 97012c7033f9f..ab7410a1f7f99 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -76,10 +76,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
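
The private sc accessor above is what turns the SPARK-5063 misuse into a clear error: the @transient field deserializes to null inside tasks shipped to executors. A hedged sketch of the same idea with hypothetical names:

    // A @transient reference that is only non-null on the driver, re-exposed
    // through an accessor that fails with a clear message when used in a task.
    class DriverOnlyHandle[T <: AnyRef](@transient private var _value: T) extends Serializable {
      def value: T = {
        if (_value == null) {
          throw new IllegalStateException(
            "This handle can only be used on the driver, not inside tasks shipped to executors")
        }
        _value
      }
    }
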
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 6f446c5a95a0a..4307029d44fbb 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,8 +24,10 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
- val TASK_DESERIALIZATION_TIME =
- """Time spent deserializating the task closure on the executor."""
+ val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor."
+
+ val SHUFFLE_READ_BLOCKED_TIME =
+ "Time that the task spent blocked waiting for shuffle data to be read from remote machines."
val INPUT = "Bytes read from Hadoop or from Spark storage."
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 09a936c2234c0..d8be1b20b3acd 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -132,6 +132,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
Task Deserialization Time
+ {if (hasShuffleRead) {
+
+
+
+ Shuffle Read Blocked Time
+
+
+ }}
@@ -167,7 +176,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{if (hasAccumulators) Seq(("Accumulators", "")) else Nil} ++
{if (hasInput) Seq(("Input", "")) else Nil} ++
{if (hasOutput) Seq(("Output", "")) else Nil} ++
- {if (hasShuffleRead) Seq(("Shuffle Read", "")) else Nil} ++
+ {if (hasShuffleRead) {
+ Seq(("Shuffle Read Blocked Time", TaskDetailsClassNames.SHUFFLE_READ_BLOCKED_TIME),
+ ("Shuffle Read", ""))
+ } else {
+ Nil
+ }} ++
{if (hasShuffleWrite) Seq(("Write Time", ""), ("Shuffle Write", "")) else Nil} ++
{if (hasBytesSpilled) Seq(("Shuffle Spill (Memory)", ""), ("Shuffle Spill (Disk)", ""))
else Nil} ++
@@ -271,6 +285,12 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
val outputQuantiles = Output +: getFormattedSizeQuantiles(outputSizes)
+ val shuffleReadBlockedTimes = validTasks.map { case TaskUIData(_, metrics, _) =>
+ metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble
+ }
+ val shuffleReadBlockedQuantiles = Shuffle Read Blocked Time +:
+ getFormattedTimeQuantiles(shuffleReadBlockedTimes)
+
val shuffleReadSizes = validTasks.map { case TaskUIData(_, metrics, _) =>
metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble
}
@@ -308,7 +328,14 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{gettingResultQuantiles} ,
if (hasInput) {inputQuantiles} else Nil,
if (hasOutput) {outputQuantiles} else Nil,
- if (hasShuffleRead) {shuffleReadQuantiles} else Nil,
+ if (hasShuffleRead) {
+
+ {shuffleReadBlockedQuantiles}
+
+ {shuffleReadQuantiles}
+ } else {
+ Nil
+ },
if (hasShuffleWrite) {shuffleWriteQuantiles} else Nil,
if (hasBytesSpilled) {memoryBytesSpilledQuantiles} else Nil,
if (hasBytesSpilled) {diskBytesSpilledQuantiles} else Nil)
@@ -377,6 +404,11 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
.map(m => s"${Utils.bytesToString(m.bytesWritten)}")
.getOrElse("")
+ val maybeShuffleReadBlockedTime = metrics.flatMap(_.shuffleReadMetrics).map(_.fetchWaitTime)
+ val shuffleReadBlockedTimeSortable = maybeShuffleReadBlockedTime.map(_.toString).getOrElse("")
+ val shuffleReadBlockedTimeReadable =
+ maybeShuffleReadBlockedTime.map(ms => UIUtils.formatDuration(ms)).getOrElse("")
+
val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead)
val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
val shuffleReadReadable = maybeShuffleRead.map(Utils.bytesToString).getOrElse("")
@@ -449,6 +481,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}}
{if (hasShuffleRead) {
+
+ {shuffleReadBlockedTimeReadable}
+
{shuffleReadReadable}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
index 2d13bb6ddde42..37cf2c207ba40 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -27,6 +27,7 @@ package org.apache.spark.ui.jobs
private[spark] object TaskDetailsClassNames {
val SCHEDULER_DELAY = "scheduler_delay"
val TASK_DESERIALIZATION_TIME = "deserialization_time"
+ val SHUFFLE_READ_BLOCKED_TIME = "fetch_wait_time"
val RESULT_SERIALIZATION_TIME = "serialization_time"
val GETTING_RESULT_TIME = "getting_result_time"
}
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 7584ae79fc920..21487bc24d58a 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -171,11 +171,11 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
assert(jobB.get() === 100)
}
- ignore("two jobs sharing the same stage") {
+ test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
- // sem2: make sure the first stage is not finished until cancel is issued
+ // twoJobsSharingStageSemaphore:
+ // make sure the first stage is not finished until cancel is issued
val sem1 = new Semaphore(0)
- val sem2 = new Semaphore(0)
sc = new SparkContext("local[2]", "test")
sc.addSparkListener(new SparkListener {
@@ -186,7 +186,7 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
// Create two actions that would share the some stages.
val rdd = sc.parallelize(1 to 10, 2).map { i =>
- sem2.acquire()
+ JobCancellationSuite.twoJobsSharingStageSemaphore.acquire()
(i, i)
}.reduceByKey(_+_)
val f1 = rdd.collectAsync()
@@ -196,13 +196,13 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
future {
sem1.acquire()
f1.cancel()
- sem2.release(10)
+ JobCancellationSuite.twoJobsSharingStageSemaphore.release(10)
}
- // Expect both to fail now.
- // TODO: update this test when we change Spark so cancelling f1 wouldn't affect f2.
+ // Expect f1 to fail due to cancellation,
intercept[SparkException] { f1.get() }
- intercept[SparkException] { f2.get() }
+ // but f2 should not be affected
+ f2.get()
}
def testCount() {
@@ -268,4 +268,5 @@ class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
object JobCancellationSuite {
val taskStartedSemaphore = new Semaphore(0)
val taskCancelledSemaphore = new Semaphore(0)
+ val twoJobsSharingStageSemaphore = new Semaphore(0)
}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index b0a70f012f1f3..af3272692d7a1 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -170,6 +170,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
testPackage.runCallSiteTest(sc)
}
+ test("Broadcast variables cannot be created after SparkContext is stopped (SPARK-5065)") {
+ sc = new SparkContext("local", "test")
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ sc.broadcast(Seq(1, 2, 3))
+ }
+ assert(thrown.getMessage.toLowerCase.contains("stopped"))
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -349,8 +358,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
package object testPackage extends Assertions {
def runCallSiteTest(sc: SparkContext) {
- val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
- val broadcast = sc.broadcast(rdd)
+ val broadcast = sc.broadcast(Array(1, 2, 3, 4))
broadcast.destroy()
val thrown = intercept[SparkException] { broadcast.value }
assert(thrown.getMessage.contains("BroadcastSuite.scala"))
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
index 8379883e065e7..3fbc1a21d10ed 100644
--- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala
@@ -167,6 +167,29 @@ class FsHistoryProviderSuite extends FunSuite with BeforeAndAfter with Matchers
list.size should be (1)
}
+ test("history file is renamed from inprogress to completed") {
+ val conf = new SparkConf()
+ .set("spark.history.fs.logDirectory", testDir.getAbsolutePath())
+ .set("spark.testing", "true")
+ val provider = new FsHistoryProvider(conf)
+
+ val logFile1 = new File(testDir, "app1" + EventLoggingListener.IN_PROGRESS)
+ writeFile(logFile1, true, None,
+ SparkListenerApplicationStart("app1", Some("app1"), 1L, "test"),
+ SparkListenerApplicationEnd(2L)
+ )
+ provider.checkForLogs()
+ val appListBeforeRename = provider.getListing()
+ appListBeforeRename.size should be (1)
+ appListBeforeRename.head.logPath should endWith(EventLoggingListener.IN_PROGRESS)
+
+ logFile1.renameTo(new File(testDir, "app1"))
+ provider.checkForLogs()
+ val appListAfterRename = provider.getListing()
+ appListAfterRename.size should be (1)
+ appListAfterRename.head.logPath should not endWith(EventLoggingListener.IN_PROGRESS)
+ }
+
private def writeFile(file: File, isNewFormat: Boolean, codec: Option[CompressionCodec],
events: SparkListenerEvent*) = {
val out =
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
index 1a28a9a187cd7..372d7aa453008 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerArgumentsTest.scala
@@ -43,7 +43,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
@@ -62,7 +62,7 @@ class WorkerArgumentsTest extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
val conf = new MySparkConf()
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 381ee2d45630f..e33b4bbbb8e4c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -927,4 +927,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
mutableDependencies += dep
}
}
+
+ test("nested RDDs are not supported (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ val nestedRDD: RDD[RDD[Int]] = rdd.mapPartitions { x => Seq(rdd2.map(x => x)).iterator }
+ nestedRDD.count()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("actions cannot be performed inside of transformations (SPARK-5063)") {
+ val rdd: RDD[Int] = sc.parallelize(1 to 100)
+ val rdd2: RDD[Int] = sc.parallelize(1 to 100)
+ val thrown = intercept[SparkException] {
+ rdd.map(x => x * rdd2.count).collect()
+ }
+ assert(thrown.getMessage.contains("SPARK-5063"))
+ }
+
+ test("cannot run actions after SparkContext has been stopped (SPARK-5063)") {
+ val existingRDD = sc.parallelize(1 to 100)
+ sc.stop()
+ val thrown = intercept[IllegalStateException] {
+ existingRDD.count()
+ }
+ assert(thrown.getMessage.contains("shutdown"))
+ }
+
+ test("cannot call methods on a stopped SparkContext (SPARK-5063)") {
+ sc.stop()
+ def assertFails(block: => Any): Unit = {
+ val thrown = intercept[IllegalStateException] {
+ block
+ }
+ assert(thrown.getMessage.contains("stopped"))
+ }
+ assertFails { sc.parallelize(1 to 100) }
+ assertFails { sc.textFile("/nonexistent-path") }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
index dae7bf0e336de..8cf951adb354b 100644
--- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala
@@ -49,7 +49,7 @@ class LocalDirsSuite extends FunSuite {
}
override def clone: SparkConf = {
- new MySparkConf().setAll(settings)
+ new MySparkConf().setAll(getAll)
}
}
// spark.local.dir only contains invalid directories, but that's not a problem since
diff --git a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
index 10541f878476c..1026cb2aa7cae 100644
--- a/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/EventLoopSuite.scala
@@ -41,7 +41,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
(1 to 100).foreach(eventLoop.post)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert((1 to 100) === buffer.toSeq)
}
eventLoop.stop()
@@ -76,7 +76,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
eventLoop.post(1)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(e === receivedError)
}
eventLoop.stop()
@@ -98,7 +98,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
eventLoop.start()
eventLoop.post(1)
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(e === receivedError)
assert(eventLoop.isActive)
}
@@ -153,7 +153,7 @@ class EventLoopSuite extends FunSuite with Timeouts {
}.start()
}
- eventually(timeout(5 seconds), interval(200 millis)) {
+ eventually(timeout(5 seconds), interval(5 millis)) {
assert(threadNum * eventsFromEachThread === receivedEventsCount)
}
eventLoop.stop()
@@ -185,4 +185,22 @@ class EventLoopSuite extends FunSuite with Timeouts {
}
assert(false === eventLoop.isActive)
}
+
+ test("EventLoop: stop in eventThread") {
+ val eventLoop = new EventLoop[Int]("test") {
+
+ override def onReceive(event: Int): Unit = {
+ stop()
+ }
+
+ override def onError(e: Throwable): Unit = {
+ }
+
+ }
+ eventLoop.start()
+ eventLoop.post(1)
+ eventually(timeout(5 seconds), interval(5 millis)) {
+ assert(!eventLoop.isActive)
+ }
+ }
}
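
For context on what these tests exercise, here is a hypothetical minimal event loop with the same surface (start/post/stop/isActive plus onReceive/onError). This is an illustrative sketch only, not Spark's EventLoop implementation:

    import java.util.concurrent.LinkedBlockingDeque
    import java.util.concurrent.atomic.AtomicBoolean

    abstract class SimpleEventLoop[E](name: String) {
      private val queue = new LinkedBlockingDeque[E]()
      private val active = new AtomicBoolean(false)

      private val eventThread = new Thread(name) {
        setDaemon(true)
        override def run(): Unit = {
          try {
            while (active.get()) {
              val event = queue.take()
              try onReceive(event) catch { case t: Throwable => onError(t) }
            }
          } catch {
            case _: InterruptedException => // stop() interrupts a blocked take()
          }
        }
      }

      def start(): Unit = { active.set(true); eventThread.start() }
      def post(event: E): Unit = queue.put(event)
      def stop(): Unit = { active.set(false); eventThread.interrupt() }
      def isActive: Boolean = active.get() && eventThread.isAlive

      protected def onReceive(event: E): Unit
      protected def onError(e: Throwable): Unit
    }
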
diff --git a/docs/configuration.md b/docs/configuration.md
index efbab4085317a..7c5b6d011cfd3 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -197,6 +197,27 @@ Apart from these, the following properties are also available, and may be useful
#### Runtime Environment
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.driver.extraJavaOptions</code></td>
+  <td>(none)</td>
+  <td>
+    A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.driver.extraClassPath</code></td>
+  <td>(none)</td>
+  <td>
+    Extra classpath entries to append to the classpath of the driver.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.driver.extraLibraryPath</code></td>
+  <td>(none)</td>
+  <td>
+    Set a special library path to use when launching the driver JVM.
+  </td>
+</tr>
<tr>
  <td><code>spark.executor.extraJavaOptions</code></td>
  <td>(none)</td>
diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md
index 0e38fe2144e9f..77c0abbbacbd0 100644
--- a/docs/streaming-kafka-integration.md
+++ b/docs/streaming-kafka-integration.md
@@ -29,7 +29,7 @@ title: Spark Streaming + Kafka Integration Guide
streamingContext, [zookeeperQuorum], [group id of the consumer], [per-topic number of Kafka partitions to consume]);
See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html)
- and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
+ and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaKafkaWordCount.java).
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 3bd1deaccfafe..14a87f8436984 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -58,8 +58,8 @@ for applications that involve the REPL (e.g. Spark shell).
Alternatively, if your application is submitted from a machine far from the worker machines (e.g.
locally on your laptop), it is common to use `cluster` mode to minimize network latency between
-the drivers and the executors. Note that `cluster` mode is currently not supported for standalone
-clusters, Mesos clusters, or Python applications.
+the drivers and the executors. Note that `cluster` mode is currently not supported for
+Mesos clusters or Python applications.
For Python applications, simply pass a `.py` file in the place of `<application-jar>` instead of a JAR,

and add Python `.zip`, `.egg` or `.py` files to the search path with `--py-files`.
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
index 2adc63f7ff30e..387c0e421334b 100644
--- a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
+++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala
@@ -76,7 +76,7 @@ object KafkaWordCountProducer {
val Array(brokers, topic, messagesPerSec, wordsPerMessage) = args
- // Zookeper connection properties
+ // Zookeeper connection properties
val props = new Properties()
props.put("metadata.broker.list", brokers)
props.put("serializer.class", "kafka.serializer.StringEncoder")
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
new file mode 100644
index 0000000000000..cf62772b92651
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.examples.mllib.AbstractParams
+import org.apache.spark.ml.recommendation.ALS
+import org.apache.spark.sql.{Row, SQLContext}
+
+/**
+ * An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
+ * Run with
+ * {{{
+ * bin/run-example ml.MovieLensALS
+ * }}}
+ */
+object MovieLensALS {
+
+ case class Rating(userId: Int, movieId: Int, rating: Float, timestamp: Long)
+
+ object Rating {
+ def parseRating(str: String): Rating = {
+ val fields = str.split("::")
+ assert(fields.size == 4)
+ Rating(fields(0).toInt, fields(1).toInt, fields(2).toFloat, fields(3).toLong)
+ }
+ }
+
+ case class Movie(movieId: Int, title: String, genres: Seq[String])
+
+ object Movie {
+ def parseMovie(str: String): Movie = {
+ val fields = str.split("::")
+ assert(fields.size == 3)
+ Movie(fields(0).toInt, fields(1), fields(2).split("|"))
+ }
+ }
+
+ case class Params(
+ ratings: String = null,
+ movies: String = null,
+ maxIter: Int = 10,
+ regParam: Double = 0.1,
+ rank: Int = 10,
+ numBlocks: Int = 10) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("MovieLensALS") {
+ head("MovieLensALS: an example app for ALS on MovieLens data.")
+ opt[String]("ratings")
+ .required()
+ .text("path to a MovieLens dataset of ratings")
+ .action((x, c) => c.copy(ratings = x))
+ opt[String]("movies")
+ .required()
+ .text("path to a MovieLens dataset of movies")
+ .action((x, c) => c.copy(movies = x))
+ opt[Int]("rank")
+ .text(s"rank, default: ${defaultParams.rank}}")
+ .action((x, c) => c.copy(rank = x))
+ opt[Int]("maxIter")
+ .text(s"max number of iterations, default: ${defaultParams.maxIter}")
+ .action((x, c) => c.copy(maxIter = x))
+ opt[Double]("regParam")
+ .text(s"regularization parameter, default: ${defaultParams.regParam}")
+ .action((x, c) => c.copy(regParam = x))
+ opt[Int]("numBlocks")
+ .text(s"number of blocks, default: ${defaultParams.numBlocks}")
+ .action((x, c) => c.copy(numBlocks = x))
+ note(
+ """
+ |Example command line to run this app:
+ |
+ | bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
+ | examples/target/scala-*/spark-examples-*.jar \
+ | --rank 10 --maxIter 15 --regParam 0.1 \
+ | --movies path/to/movielens/movies.dat \
+ | --ratings path/to/movielens/ratings.dat
+ """.stripMargin)
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ } getOrElse {
+ System.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._
+
+ val ratings = sc.textFile(params.ratings).map(Rating.parseRating).cache()
+
+ val numRatings = ratings.count()
+ val numUsers = ratings.map(_.userId).distinct().count()
+ val numMovies = ratings.map(_.movieId).distinct().count()
+
+ println(s"Got $numRatings ratings from $numUsers users on $numMovies movies.")
+
+ val splits = ratings.randomSplit(Array(0.8, 0.2), 0L)
+ val training = splits(0).cache()
+ val test = splits(1).cache()
+
+ val numTraining = training.count()
+ val numTest = test.count()
+ println(s"Training: $numTraining, test: $numTest.")
+
+ ratings.unpersist(blocking = false)
+
+ val als = new ALS()
+ .setUserCol("userId")
+ .setItemCol("movieId")
+ .setRank(params.rank)
+ .setMaxIter(params.maxIter)
+ .setRegParam(params.regParam)
+ .setNumBlocks(params.numBlocks)
+
+ val model = als.fit(training)
+
+ val predictions = model.transform(test).cache()
+
+ // Evaluate the model.
+ // TODO: Create an evaluator to compute RMSE.
+ val mse = predictions.select('rating, 'prediction)
+ .flatMap { case Row(rating: Float, prediction: Float) =>
+ val err = rating.toDouble - prediction
+ val err2 = err * err
+ if (err2.isNaN) {
+ None
+ } else {
+ Some(err2)
+ }
+ }.mean()
+ val rmse = math.sqrt(mse)
+ println(s"Test RMSE = $rmse.")
+
+ // Inspect false positives.
+ predictions.registerTempTable("prediction")
+ sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
+ sqlContext.sql(
+ """
+ |SELECT userId, prediction.movieId, title, rating, prediction
+ | FROM prediction JOIN movie ON prediction.movieId = movie.movieId
+ | WHERE rating <= 1 AND prediction >= 4
+ | LIMIT 100
+ """.stripMargin)
+ .collect()
+ .foreach(println)
+
+ sc.stop()
+ }
+}
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
index 897c7ee12a436..f1550ac2e18ad 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala
@@ -19,7 +19,7 @@ package org.apache.spark.graphx.impl
import scala.reflect.{classTag, ClassTag}
-import org.apache.spark.{OneToOneDependency, Partition, Partitioner, TaskContext}
+import org.apache.spark.{OneToOneDependency, HashPartitioner, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -46,7 +46,7 @@ class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] (
* partitioner that allows co-partitioning with `partitionsRDD`.
*/
override val partitioner =
- partitionsRDD.partitioner.orElse(Some(Partitioner.defaultPartitioner(partitionsRDD)))
+ partitionsRDD.partitioner.orElse(Some(new HashPartitioner(partitions.size)))
override def collect(): Array[Edge[ED]] = this.map(_.copy()).collect()
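
The EdgeRDDImpl change swaps the defaultPartitioner fallback for a HashPartitioner sized to the edge RDD's own partition count, so the reported partitioner always matches the actual layout. A standalone sketch of that rule (the object name is hypothetical):

    import org.apache.spark.{HashPartitioner, Partitioner}
    import org.apache.spark.rdd.RDD

    object PartitionerFallback {
      // When the wrapped RDD has no partitioner, pair it with a HashPartitioner
      // sized to its actual partition count.
      def partitionerFor(rdd: RDD[_]): Partitioner =
        rdd.partitioner.getOrElse(new HashPartitioner(rdd.partitions.length))
    }
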
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
index 9da0064104fb6..ed9876b8dc21c 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphSuite.scala
@@ -386,4 +386,24 @@ class GraphSuite extends FunSuite with LocalSparkContext {
}
}
+ test("non-default number of edge partitions") {
+ val n = 10
+ val defaultParallelism = 3
+ val numEdgePartitions = 4
+ assert(defaultParallelism != numEdgePartitions)
+ val conf = new org.apache.spark.SparkConf()
+ .set("spark.default.parallelism", defaultParallelism.toString)
+ val sc = new SparkContext("local", "test", conf)
+ try {
+ val edges = sc.parallelize((1 to n).map(x => (x: VertexId, 0: VertexId)),
+ numEdgePartitions)
+ val graph = Graph.fromEdgeTuples(edges, 1)
+ val neighborAttrSums = graph.mapReduceTriplets[Int](
+ et => Iterator((et.dstId, et.srcAttr)), _ + _)
+ assert(neighborAttrSums.collect.toSet === Set((0: VertexId, n)))
+ } finally {
+ sc.stop()
+ }
+ }
+
}
diff --git a/make-distribution.sh b/make-distribution.sh
index 4e2f400be3053..0adca7851819b 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -115,7 +115,7 @@ if which git &>/dev/null; then
unset GITREV
fi
-if ! which $MVN &>/dev/null; then
+if ! which "$MVN" &>/dev/null; then
echo -e "Could not locate Maven command: '$MVN'."
echo -e "Specify the Maven command with the --mvn flag"
exit -1;
@@ -171,13 +171,16 @@ cd "$SPARK_HOME"
export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
-BUILD_COMMAND="$MVN clean package -DskipTests $@"
+# Store the command as an array because $MVN variable might have spaces in it.
+# Normal quoting tricks don't work.
+# See: http://mywiki.wooledge.org/BashFAQ/050
+BUILD_COMMAND=("$MVN" clean package -DskipTests $@)
# Actually build the jar
echo -e "\nBuilding with..."
-echo -e "\$ $BUILD_COMMAND\n"
+echo -e "\$ ${BUILD_COMMAND[@]}\n"
-${BUILD_COMMAND}
+"${BUILD_COMMAND[@]}"
# Make directories
rm -rf "$DISTDIR"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
new file mode 100644
index 0000000000000..2d89e76a4c8b2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -0,0 +1,973 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.{util => ju}
+
+import scala.collection.mutable
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
+import org.netlib.util.intW
+
+import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.Cast
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
+import org.apache.spark.util.Utils
+import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Common params for ALS.
+ */
+private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
+ with HasPredictionCol {
+
+ /** Param for rank of the matrix factorization. */
+ val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+ def getRank: Int = get(rank)
+
+ /** Param for number of user blocks. */
+ val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+ def getNumUserBlocks: Int = get(numUserBlocks)
+
+ /** Param for number of item blocks. */
+ val numItemBlocks =
+ new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+ def getNumItemBlocks: Int = get(numItemBlocks)
+
+ /** Param to decide whether to use implicit preference. */
+ val implicitPrefs =
+ new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+ def getImplicitPrefs: Boolean = get(implicitPrefs)
+
+ /** Param for the alpha parameter in the implicit preference formulation. */
+ val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+ def getAlpha: Double = get(alpha)
+
+ /** Param for the column name for user ids. */
+ val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+ def getUserCol: String = get(userCol)
+
+ /** Param for the column name for item ids. */
+ val itemCol =
+ new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+ def getItemCol: String = get(itemCol)
+
+ /** Param for the column name for ratings. */
+ val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+ def getRatingCol: String = get(ratingCol)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @param paramMap extra params
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = this.paramMap ++ paramMap
+ assert(schema(map(userCol)).dataType == IntegerType)
+    assert(schema(map(itemCol)).dataType == IntegerType)
+ val ratingType = schema(map(ratingCol)).dataType
+ assert(ratingType == FloatType || ratingType == DoubleType)
+ val predictionColName = map(predictionCol)
+ assert(!schema.fieldNames.contains(predictionColName),
+ s"Prediction column $predictionColName already exists.")
+ val newFields = schema.fields :+ StructField(map(predictionCol), FloatType, nullable = false)
+ StructType(newFields)
+ }
+}
+
+/**
+ * Model fitted by ALS.
+ */
+class ALSModel private[ml] (
+ override val parent: ALS,
+ override val fittingParamMap: ParamMap,
+ k: Int,
+ userFactors: RDD[(Int, Array[Float])],
+ itemFactors: RDD[(Int, Array[Float])])
+ extends Model[ALSModel] with ALSParams {
+
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ import dataset.sqlContext._
+ import org.apache.spark.ml.recommendation.ALSModel.Factor
+ val map = this.paramMap ++ paramMap
+ // TODO: Add DSL to simplify the code here.
+ val instanceTable = s"instance_$uid"
+ val userTable = s"user_$uid"
+ val itemTable = s"item_$uid"
+ val instances = dataset.as(Symbol(instanceTable))
+ val users = userFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(Symbol(userTable))
+ val items = itemFactors.map { case (id, features) =>
+ Factor(id, features)
+ }.as(Symbol(itemTable))
+ val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
+ if (userFeatures != null && itemFeatures != null) {
+ blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
+ } else {
+ Float.NaN
+ }
+ }
+ val inputColumns = dataset.schema.fieldNames
+ val prediction =
+ predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
+ val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+ instances
+ .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
+ .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+ .select(outputColumns: _*)
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+private object ALSModel {
+ /** Case class to convert factors to SchemaRDDs */
+ private case class Factor(id: Int, features: Seq[Float])
+}
+
+/**
+ * Alternating Least Squares (ALS) matrix factorization.
+ *
+ * ALS attempts to estimate the ratings matrix `R` as the product of two lower-rank matrices,
+ * `X` and `Y`, i.e. `X * Yt = R`. Typically these approximations are called 'factor' matrices.
+ * The general approach is iterative. During each iteration, one of the factor matrices is held
+ * constant, while the other is solved for using least squares. The newly-solved factor matrix is
+ * then held constant while solving for the other factor matrix.
+ *
+ * This is a blocked implementation of the ALS factorization algorithm that groups the two sets
+ * of factors (referred to as "users" and "products") into blocks and reduces communication by only
+ * sending one copy of each user vector to each product block on each iteration, and only for the
+ * product blocks that need that user's feature vector. This is achieved by pre-computing some
+ * information about the ratings matrix to determine the "out-links" of each user (which blocks of
+ * products it will contribute to) and "in-link" information for each product (which of the feature
+ * vectors it receives from each user block it will depend on). This allows us to send only an
+ * array of feature vectors between each user block and product block, and have the product block
+ * find the users' ratings and update the products based on these messages.
+ *
+ * For implicit preference data, the algorithm used is based on
+ * "Collaborative Filtering for Implicit Feedback Datasets", available at
+ * [[http://dx.doi.org/10.1109/ICDM.2008.22]], adapted for the blocked approach used here.
+ *
+ * Essentially instead of finding the low-rank approximations to the rating matrix `R`,
+ * this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to the strength
+ * of indicated user preferences rather than explicit ratings given to items.
+ */
+class ALS extends Estimator[ALSModel] with ALSParams {
+
+ import org.apache.spark.ml.recommendation.ALS.Rating
+
+ def setRank(value: Int): this.type = set(rank, value)
+ def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+ def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+ def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
+ def setAlpha(value: Double): this.type = set(alpha, value)
+ def setUserCol(value: String): this.type = set(userCol, value)
+ def setItemCol(value: String): this.type = set(itemCol, value)
+ def setRatingCol(value: String): this.type = set(ratingCol, value)
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+ def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** Sets both numUserBlocks and numItemBlocks to the specific value. */
+ def setNumBlocks(value: Int): this.type = {
+ setNumUserBlocks(value)
+ setNumItemBlocks(value)
+ this
+ }
+
+ setMaxIter(20)
+ setRegParam(1.0)
+
+ override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
+ import dataset.sqlContext._
+ val map = this.paramMap ++ paramMap
+ val ratings =
+ dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
+ .map { row =>
+ new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
+ val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
+ numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
+ maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
+ alpha = map(alpha))
+ val model = new ALSModel(this, map, map(rank), userFactors, itemFactors)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
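+
+// A minimal usage sketch of the estimator above (illustrative only; `training` and `test`
+// stand for SchemaRDDs with the default "user", "item", and "rating" columns, e.g. built via
+// sqlContext.createSchemaRDD as in ALSSuite):
+//
+//   val als = new ALS().setRank(10).setMaxIter(10).setRegParam(0.1)
+//   val model = als.fit(training)
+//   val predictions = model.transform(test)  // appends the "prediction" column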
+
+private[recommendation] object ALS extends Logging {
+
+ /** Rating class for better code readability. */
+ private[recommendation] case class Rating(user: Int, item: Int, rating: Float)
+
+ /** Cholesky solver for least square problems. */
+ private[recommendation] class CholeskySolver {
+
+ private val upper = "U"
+ private val info = new intW(0)
+
+ /**
+ * Solves a least squares problem with L2 regularization:
+ *
+ * min norm(A x - b)^2^ + lambda * n * norm(x)^2^
+ *
+ * @param ne a [[NormalEquation]] instance that contains AtA, Atb, and n (number of instances)
+ * @param lambda regularization constant, which will be scaled by n
+ * @return the solution x
+ */
+ def solve(ne: NormalEquation, lambda: Double): Array[Float] = {
+ val k = ne.k
+ // Add scaled lambda to the diagonals of AtA.
+ val scaledlambda = lambda * ne.n
+ var i = 0
+ var j = 2
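+      // In the packed upper triangular storage ("U") of AtA, the diagonal entries sit at
+      // indices 0, 2, 5, 9, ...; the stride j grows by one after each diagonal element.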
+ while (i < ne.triK) {
+ ne.ata(i) += scaledlambda
+ i += j
+ j += 1
+ }
+ lapack.dppsv(upper, k, 1, ne.ata, ne.atb, k, info)
+ val code = info.`val`
+ assert(code == 0, s"lapack.dppsv returned $code.")
+ val x = new Array[Float](k)
+ i = 0
+ while (i < k) {
+ x(i) = ne.atb(i).toFloat
+ i += 1
+ }
+ ne.reset()
+ x
+ }
+ }
+
+  /** Represents a normal equation (ALS' subproblem). */
+ private[recommendation] class NormalEquation(val k: Int) extends Serializable {
+
+ /** Number of entries in the upper triangular part of a k-by-k matrix. */
+ val triK = k * (k + 1) / 2
+ /** A^T^ * A */
+ val ata = new Array[Double](triK)
+ /** A^T^ * b */
+ val atb = new Array[Double](k)
+ /** Number of observations. */
+ var n = 0
+
+ private val da = new Array[Double](k)
+ private val upper = "U"
+
+ private def copyToDouble(a: Array[Float]): Unit = {
+ var i = 0
+ while (i < k) {
+ da(i) = a(i)
+ i += 1
+ }
+ }
+
+ /** Adds an observation. */
+ def add(a: Array[Float], b: Float): this.type = {
+ require(a.size == k)
+ copyToDouble(a)
+ blas.dspr(upper, k, 1.0, da, 1, ata)
+ blas.daxpy(k, b.toDouble, da, 1, atb, 1)
+ n += 1
+ this
+ }
+
+ /**
+ * Adds an observation with implicit feedback. Note that this does not increment the counter.
+ */
+ def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = {
+ require(a.size == k)
+      // Extension to the original paper to handle b < 0: the confidence is a function of |b|
+      // instead, so that it is never negative.
+ val confidence = 1.0 + alpha * math.abs(b)
+ copyToDouble(a)
+ blas.dspr(upper, k, confidence - 1.0, da, 1, ata)
+ // For b <= 0, the corresponding preference is 0. So the term below is only added for b > 0.
+ if (b > 0) {
+ blas.daxpy(k, confidence, da, 1, atb, 1)
+ }
+ this
+ }
+
+ /** Merges another normal equation object. */
+ def merge(other: NormalEquation): this.type = {
+ require(other.k == k)
+ blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1)
+ blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1)
+ n += other.n
+ this
+ }
+
+ /** Resets everything to zero, which should be called after each solve. */
+ def reset(): Unit = {
+ ju.Arrays.fill(ata, 0.0)
+ ju.Arrays.fill(atb, 0.0)
+ n = 0
+ }
+ }
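+
+  // A small worked example of the two classes above (illustrative; the expected result
+  // follows from solving the 2x2 normal equations by hand):
+  //
+  //   val ne = new NormalEquation(2)
+  //     .add(Array(1.0f, 2.0f), 3.0f)
+  //     .add(Array(4.0f, 5.0f), 6.0f)
+  //   val x = new CholeskySolver().solve(ne, 0.0)
+  //   // AtA = [17 22; 22 29], Atb = [27, 36], so x is approximately Array(-1.0f, 2.0f).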
+
+ /**
+ * Implementation of the ALS algorithm.
+ */
+ private def train(
+ ratings: RDD[Rating],
+ rank: Int = 10,
+ numUserBlocks: Int = 10,
+ numItemBlocks: Int = 10,
+ maxIter: Int = 10,
+ regParam: Double = 1.0,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = {
+ val userPart = new HashPartitioner(numUserBlocks)
+ val itemPart = new HashPartitioner(numItemBlocks)
+ val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
+ val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
+ val blockRatings = partitionRatings(ratings, userPart, itemPart).cache()
+ val (userInBlocks, userOutBlocks) = makeBlocks("user", blockRatings, userPart, itemPart)
+ // materialize blockRatings and user blocks
+ userOutBlocks.count()
+ val swappedBlockRatings = blockRatings.map {
+ case ((userBlockId, itemBlockId), RatingBlock(userIds, itemIds, localRatings)) =>
+ ((itemBlockId, userBlockId), RatingBlock(itemIds, userIds, localRatings))
+ }
+ val (itemInBlocks, itemOutBlocks) = makeBlocks("item", swappedBlockRatings, itemPart, userPart)
+ // materialize item blocks
+ itemOutBlocks.count()
+ var userFactors = initialize(userInBlocks, rank)
+ var itemFactors = initialize(itemInBlocks, rank)
+ if (implicitPrefs) {
+ for (iter <- 1 to maxIter) {
+ userFactors.setName(s"userFactors-$iter").persist()
+ val previousItemFactors = itemFactors
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder, implicitPrefs, alpha)
+ previousItemFactors.unpersist()
+ itemFactors.setName(s"itemFactors-$iter").persist()
+ val previousUserFactors = userFactors
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder, implicitPrefs, alpha)
+ previousUserFactors.unpersist()
+ }
+ } else {
+ for (iter <- 0 until maxIter) {
+ itemFactors = computeFactors(userFactors, userOutBlocks, itemInBlocks, rank, regParam,
+ userLocalIndexEncoder)
+ userFactors = computeFactors(itemFactors, itemOutBlocks, userInBlocks, rank, regParam,
+ itemLocalIndexEncoder)
+ }
+ }
+ val userIdAndFactors = userInBlocks
+ .mapValues(_.srcIds)
+ .join(userFactors)
+ .values
+ .setName("userFactors")
+ .cache()
+ userIdAndFactors.count()
+ itemFactors.unpersist()
+ val itemIdAndFactors = itemInBlocks
+ .mapValues(_.srcIds)
+ .join(itemFactors)
+ .values
+ .setName("itemFactors")
+ .cache()
+ itemIdAndFactors.count()
+ userInBlocks.unpersist()
+ userOutBlocks.unpersist()
+ itemInBlocks.unpersist()
+ itemOutBlocks.unpersist()
+ blockRatings.unpersist()
+ val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) =>
+ ids.view.zip(factors)
+ }
+ (userOutput, itemOutput)
+ }
+
+ /**
+ * Factor block that stores factors (Array[Float]) in an Array.
+ */
+ private type FactorBlock = Array[Array[Float]]
+
+ /**
+ * Out-link block that stores, for each dst (item/user) block, which src (user/item) factors to
+ * send. For example, outLinkBlock(0) contains the local indices (not the original src IDs) of the
+ * src factors in this block to send to dst block 0.
+ */
+ private type OutBlock = Array[Array[Int]]
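+  // For example (hypothetical values), outLinkBlock(0) == Array(1, 3) means that only
+  // srcFactors(1) and srcFactors(3) of this block need to be sent to dst block 0.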
+
+ /**
+ * In-link block for computing src (user/item) factors. This includes the original src IDs
+ * of the elements within this block as well as encoded dst (item/user) indices and corresponding
+ * ratings. The dst indices are in the form of (blockId, localIndex), which are not the original
+ * dst IDs. To compute src factors, we expect receiving dst factors that match the dst indices.
+ * For example, if we have an in-link record
+ *
+ * {srcId: 0, dstBlockId: 2, dstLocalIndex: 3, rating: 5.0},
+ *
+ * and assume that the dst factors are stored as dstFactors: Map[Int, Array[Array[Float]]], which
+ * is a blockId to dst factors map, the corresponding dst factor of the record is dstFactor(2)(3).
+ *
+ * We use a CSC-like (compressed sparse column) format to store the in-link information. So we can
+ * compute src factors one after another using only one normal equation instance.
+ *
+ * @param srcIds src ids (ordered)
+ * @param dstPtrs dst pointers. Elements in range [dstPtrs(i), dstPtrs(i+1)) of dst indices and
+ * ratings are associated with srcIds(i).
+ * @param dstEncodedIndices encoded dst indices
+ * @param ratings ratings
+ *
+ * @see [[LocalIndexEncoder]]
+ */
+ private[recommendation] case class InBlock(
+ srcIds: Array[Int],
+ dstPtrs: Array[Int],
+ dstEncodedIndices: Array[Int],
+ ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = ratings.size
+
+ require(dstEncodedIndices.size == size)
+ require(dstPtrs.size == srcIds.size + 1)
+ }
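+
+  // For example (hypothetical values), srcIds = Array(0, 5) with dstPtrs = Array(0, 2, 3)
+  // means that entries [0, 2) of dstEncodedIndices/ratings belong to src id 0 and entry 2
+  // belongs to src id 5, exactly like column pointers in a CSC matrix.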
+
+ /**
+ * Initializes factors randomly given the in-link blocks.
+ *
+ * @param inBlocks in-link blocks
+ * @param rank rank
+ * @return initialized factor blocks
+ */
+ private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = {
+ // Choose a unit vector uniformly at random from the unit sphere, but from the
+ // "first quadrant" where all elements are nonnegative. This can be done by choosing
+ // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing.
+ // This appears to create factorizations that have a slightly better reconstruction
+    // (<1%) compared to picking elements uniformly at random in [0,1].
+ inBlocks.map { case (srcBlockId, inBlock) =>
+ val random = new XORShiftRandom(srcBlockId)
+ val factors = Array.fill(inBlock.srcIds.size) {
+ val factor = Array.fill(rank)(random.nextGaussian().toFloat)
+ val nrm = blas.snrm2(rank, factor, 1)
+ blas.sscal(rank, 1.0f / nrm, factor, 1)
+ factor
+ }
+ (srcBlockId, factors)
+ }
+ }
+
+ /**
+ * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays.
+ */
+ private[recommendation]
+ case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) {
+ /** Size of the block. */
+ val size: Int = srcIds.size
+ require(dstIds.size == size)
+ require(ratings.size == size)
+ }
+
+ /**
+ * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing.
+ */
+ private[recommendation] class RatingBlockBuilder extends Serializable {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstIds = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+ var size = 0
+
+ /** Adds a rating. */
+ def add(r: Rating): this.type = {
+ size += 1
+ srcIds += r.user
+ dstIds += r.item
+ ratings += r.rating
+ this
+ }
+
+ /** Merges another [[RatingBlockBuilder]]. */
+ def merge(other: RatingBlock): this.type = {
+ size += other.srcIds.size
+ srcIds ++= other.srcIds
+ dstIds ++= other.dstIds
+ ratings ++= other.ratings
+ this
+ }
+
+ /** Builds a [[RatingBlock]]. */
+ def build(): RatingBlock = {
+ RatingBlock(srcIds.result(), dstIds.result(), ratings.result())
+ }
+ }
+
+ /**
+ * Partitions raw ratings into blocks.
+ *
+ * @param ratings raw ratings
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ *
+ * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock)
+ */
+ private def partitionRatings(
+ ratings: RDD[Rating],
+ srcPart: Partitioner,
+ dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = {
+
+    /* The implementation produces the same result as the following but generates fewer objects.
+
+ ratings.map { r =>
+ ((srcPart.getPartition(r.user), dstPart.getPartition(r.item)), r)
+ }.aggregateByKey(new RatingBlockBuilder)(
+ seqOp = (b, r) => b.add(r),
+ combOp = (b0, b1) => b0.merge(b1.build()))
+ .mapValues(_.build())
+ */
+
+ val numPartitions = srcPart.numPartitions * dstPart.numPartitions
+ ratings.mapPartitions { iter =>
+ val builders = Array.fill(numPartitions)(new RatingBlockBuilder)
+ iter.flatMap { r =>
+ val srcBlockId = srcPart.getPartition(r.user)
+ val dstBlockId = dstPart.getPartition(r.item)
+ val idx = srcBlockId + srcPart.numPartitions * dstBlockId
+ val builder = builders(idx)
+ builder.add(r)
+        if (builder.size >= 2048) { // 2048 ratings * (3 values * 4 bytes) ~= 24 KB per block
+ builders(idx) = new RatingBlockBuilder
+ Iterator.single(((srcBlockId, dstBlockId), builder.build()))
+ } else {
+ Iterator.empty
+ }
+ } ++ {
+ builders.view.zipWithIndex.filter(_._1.size > 0).map { case (block, idx) =>
+ val srcBlockId = idx % srcPart.numPartitions
+ val dstBlockId = idx / srcPart.numPartitions
+ ((srcBlockId, dstBlockId), block.build())
+ }
+ }
+ }.groupByKey().mapValues { blocks =>
+ val builder = new RatingBlockBuilder
+ blocks.foreach(builder.merge)
+ builder.build()
+ }.setName("ratingBlocks")
+ }
+
+ /**
+ * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples.
+ * @param encoder encoder for dst indices
+ */
+ private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) {
+
+ private val srcIds = mutable.ArrayBuilder.make[Int]
+ private val dstEncodedIndices = mutable.ArrayBuilder.make[Int]
+ private val ratings = mutable.ArrayBuilder.make[Float]
+
+ /**
+ * Adds a dst block of (srcId, dstLocalIndex, rating) tuples.
+ *
+ * @param dstBlockId dst block ID
+ * @param srcIds original src IDs
+ * @param dstLocalIndices dst local indices
+ * @param ratings ratings
+ */
+ def add(
+ dstBlockId: Int,
+ srcIds: Array[Int],
+ dstLocalIndices: Array[Int],
+ ratings: Array[Float]): this.type = {
+ val sz = srcIds.size
+ require(dstLocalIndices.size == sz)
+ require(ratings.size == sz)
+ this.srcIds ++= srcIds
+ this.ratings ++= ratings
+ var j = 0
+ while (j < sz) {
+ this.dstEncodedIndices += encoder.encode(dstBlockId, dstLocalIndices(j))
+ j += 1
+ }
+ this
+ }
+
+ /** Builds a [[UncompressedInBlock]]. */
+ def build(): UncompressedInBlock = {
+ new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result())
+ }
+ }
+
+ /**
+ * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays.
+ */
+ private[recommendation] class UncompressedInBlock(
+ val srcIds: Array[Int],
+ val dstEncodedIndices: Array[Int],
+ val ratings: Array[Float]) {
+
+    /** Size of the block. */
+ def size: Int = srcIds.size
+
+ /**
+ * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a
+ * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format.
+ * Sorting is done using Spark's built-in Timsort to avoid generating too many objects.
+ */
+ def compress(): InBlock = {
+ val sz = size
+ assert(sz > 0, "Empty in-link block should not exist.")
+ sort()
+ val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int]
+ val dstCountsBuilder = mutable.ArrayBuilder.make[Int]
+ var preSrcId = srcIds(0)
+ uniqueSrcIdsBuilder += preSrcId
+ var curCount = 1
+ var i = 1
+ var j = 0
+ while (i < sz) {
+ val srcId = srcIds(i)
+ if (srcId != preSrcId) {
+ uniqueSrcIdsBuilder += srcId
+ dstCountsBuilder += curCount
+ preSrcId = srcId
+ j += 1
+ curCount = 0
+ }
+ curCount += 1
+ i += 1
+ }
+ dstCountsBuilder += curCount
+ val uniqueSrcIds = uniqueSrcIdsBuilder.result()
+      val numUniqueSrcIds = uniqueSrcIds.size
+      val dstCounts = dstCountsBuilder.result()
+      val dstPtrs = new Array[Int](numUniqueSrcIds + 1)
+      var sum = 0
+      i = 0
+      while (i < numUniqueSrcIds) {
+ sum += dstCounts(i)
+ i += 1
+ dstPtrs(i) = sum
+ }
+ InBlock(uniqueSrcIds, dstPtrs, dstEncodedIndices, ratings)
+ }
+
+ private def sort(): Unit = {
+ val sz = size
+ // Since there might be interleaved log messages, we insert a unique id for easy pairing.
+ val sortId = Utils.random.nextInt()
+ logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)")
+ val start = System.nanoTime()
+ val sorter = new Sorter(new UncompressedInBlockSort)
+ sorter.sort(this, 0, size, Ordering[IntWrapper])
+ val duration = (System.nanoTime() - start) / 1e9
+ logDebug(s"Sorting took $duration seconds. (sortId = $sortId)")
+ }
+ }
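+
+  // Compression example (the same case is exercised in ALSSuite): srcIds = [1, 0, 2, 3, 0]
+  // is sorted and compressed into srcIds = [0, 1, 2, 3] with dstPtrs = [0, 2, 3, 4, 5], so
+  // the two entries of src id 0 occupy positions [0, 2) of dstEncodedIndices and ratings.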
+
+ /**
+ * A wrapper that holds a primitive integer key.
+ *
+ * @see [[UncompressedInBlockSort]]
+ */
+ private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
+ override def compare(that: IntWrapper): Int = {
+ key.compare(that.key)
+ }
+ }
+
+ /**
+ * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]].
+ */
+ private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] {
+
+ override def newKey(): IntWrapper = new IntWrapper()
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int,
+ reuse: IntWrapper): IntWrapper = {
+ if (reuse == null) {
+ new IntWrapper(data.srcIds(pos))
+ } else {
+ reuse.key = data.srcIds(pos)
+ reuse
+ }
+ }
+
+ override def getKey(
+ data: UncompressedInBlock,
+ pos: Int): IntWrapper = {
+ getKey(data, pos, null)
+ }
+
+ private def swapElements[@specialized(Int, Float) T](
+ data: Array[T],
+ pos0: Int,
+ pos1: Int): Unit = {
+ val tmp = data(pos0)
+ data(pos0) = data(pos1)
+ data(pos1) = tmp
+ }
+
+ override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = {
+ swapElements(data.srcIds, pos0, pos1)
+ swapElements(data.dstEncodedIndices, pos0, pos1)
+ swapElements(data.ratings, pos0, pos1)
+ }
+
+ override def copyRange(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int,
+ length: Int): Unit = {
+ System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length)
+ System.arraycopy(src.dstEncodedIndices, srcPos, dst.dstEncodedIndices, dstPos, length)
+ System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length)
+ }
+
+ override def allocate(length: Int): UncompressedInBlock = {
+ new UncompressedInBlock(
+ new Array[Int](length), new Array[Int](length), new Array[Float](length))
+ }
+
+ override def copyElement(
+ src: UncompressedInBlock,
+ srcPos: Int,
+ dst: UncompressedInBlock,
+ dstPos: Int): Unit = {
+ dst.srcIds(dstPos) = src.srcIds(srcPos)
+ dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos)
+ dst.ratings(dstPos) = src.ratings(srcPos)
+ }
+ }
+
+ /**
+ * Creates in-blocks and out-blocks from rating blocks.
+ * @param prefix prefix for in/out-block names
+ * @param ratingBlocks rating blocks
+ * @param srcPart partitioner for src IDs
+ * @param dstPart partitioner for dst IDs
+ * @return (in-blocks, out-blocks)
+ */
+ private def makeBlocks(
+ prefix: String,
+ ratingBlocks: RDD[((Int, Int), RatingBlock)],
+ srcPart: Partitioner,
+ dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = {
+ val inBlocks = ratingBlocks.map {
+ case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) =>
+ // The implementation is a faster version of
+ // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap
+ val start = System.nanoTime()
+ val dstIdSet = new OpenHashSet[Int](1 << 20)
+ dstIds.foreach(dstIdSet.add)
+ val sortedDstIds = new Array[Int](dstIdSet.size)
+ var i = 0
+ var pos = dstIdSet.nextPos(0)
+ while (pos != -1) {
+ sortedDstIds(i) = dstIdSet.getValue(pos)
+ pos = dstIdSet.nextPos(pos + 1)
+ i += 1
+ }
+ assert(i == dstIdSet.size)
+ ju.Arrays.sort(sortedDstIds)
+ val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size)
+ i = 0
+ while (i < sortedDstIds.size) {
+ dstIdToLocalIndex.update(sortedDstIds(i), i)
+ i += 1
+ }
+ logDebug(
+ "Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
+ val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
+ (srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
+ }.groupByKey(new HashPartitioner(srcPart.numPartitions))
+ .mapValues { iter =>
+ val builder =
+ new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions))
+ iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
+ builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
+ }
+ builder.build().compress()
+ }.setName(prefix + "InBlocks").cache()
+ val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
+ val encoder = new LocalIndexEncoder(dstPart.numPartitions)
+ val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int])
+ var i = 0
+ val seen = new Array[Boolean](dstPart.numPartitions)
+ while (i < srcIds.size) {
+ var j = dstPtrs(i)
+ ju.Arrays.fill(seen, false)
+ while (j < dstPtrs(i + 1)) {
+ val dstBlockId = encoder.blockId(dstEncodedIndices(j))
+ if (!seen(dstBlockId)) {
+ activeIds(dstBlockId) += i // add the local index in this out-block
+ seen(dstBlockId) = true
+ }
+ j += 1
+ }
+ i += 1
+ }
+ activeIds.map { x =>
+ x.result()
+ }
+ }.setName(prefix + "OutBlocks").cache()
+ (inBlocks, outBlocks)
+ }
+
+ /**
+ * Compute dst factors by constructing and solving least square problems.
+ *
+ * @param srcFactorBlocks src factors
+ * @param srcOutBlocks src out-blocks
+ * @param dstInBlocks dst in-blocks
+ * @param rank rank
+ * @param regParam regularization constant
+ * @param srcEncoder encoder for src local indices
+ * @param implicitPrefs whether to use implicit preference
+ * @param alpha the alpha constant in the implicit preference formulation
+ *
+ * @return dst factors
+ */
+ private def computeFactors(
+ srcFactorBlocks: RDD[(Int, FactorBlock)],
+ srcOutBlocks: RDD[(Int, OutBlock)],
+ dstInBlocks: RDD[(Int, InBlock)],
+ rank: Int,
+ regParam: Double,
+ srcEncoder: LocalIndexEncoder,
+ implicitPrefs: Boolean = false,
+ alpha: Double = 1.0): RDD[(Int, FactorBlock)] = {
+ val numSrcBlocks = srcFactorBlocks.partitions.size
+ val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None
+ val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap {
+ case (srcBlockId, (srcOutBlock, srcFactors)) =>
+ srcOutBlock.view.zipWithIndex.map { case (activeIndices, dstBlockId) =>
+ (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
+ }
+ }
+ val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size))
+ dstInBlocks.join(merged).mapValues {
+ case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
+ val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
+ srcFactors.foreach { case (srcBlockId, factors) =>
+ sortedSrcFactors(srcBlockId) = factors
+ }
+ val dstFactors = new Array[Array[Float]](dstIds.size)
+ var j = 0
+ val ls = new NormalEquation(rank)
+ val solver = new CholeskySolver // TODO: add NNLS solver
+ while (j < dstIds.size) {
+ ls.reset()
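+          // For implicit feedback, start from the precomputed Gramian YtY, which accounts for
+          // the baseline (confidence 1, preference 0) contribution of every src factor;
+          // addImplicit then adds only the per-rating (confidence - 1) corrections.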
+ if (implicitPrefs) {
+ ls.merge(YtY.get)
+ }
+ var i = srcPtrs(j)
+ while (i < srcPtrs(j + 1)) {
+ val encoded = srcEncodedIndices(i)
+ val blockId = srcEncoder.blockId(encoded)
+ val localIndex = srcEncoder.localIndex(encoded)
+ val srcFactor = sortedSrcFactors(blockId)(localIndex)
+ val rating = ratings(i)
+ if (implicitPrefs) {
+ ls.addImplicit(srcFactor, rating, alpha)
+ } else {
+ ls.add(srcFactor, rating)
+ }
+ i += 1
+ }
+ dstFactors(j) = solver.solve(ls, regParam)
+ j += 1
+ }
+ dstFactors
+ }
+ }
+
+ /**
+ * Computes the Gramian matrix of user or item factors, which is only used in implicit preference.
+ * Caching of the input factors is handled in [[ALS#train]].
+ */
+ private def computeYtY(factorBlocks: RDD[(Int, FactorBlock)], rank: Int): NormalEquation = {
+ factorBlocks.values.aggregate(new NormalEquation(rank))(
+ seqOp = (ne, factors) => {
+ factors.foreach(ne.add(_, 0.0f))
+ ne
+ },
+ combOp = (ne1, ne2) => ne1.merge(ne2))
+ }
+
+ /**
+ * Encoder for storing (blockId, localIndex) into a single integer.
+ *
+ * We use the leading bits (including the sign bit) to store the block id and the rest to store
+ * the local index. This is based on the assumption that users/items are approximately evenly
+ * partitioned. With this assumption, we should be able to encode two billion distinct values.
+ *
+ * @param numBlocks number of blocks
+ */
+ private[recommendation] class LocalIndexEncoder(numBlocks: Int) extends Serializable {
+
+ require(numBlocks > 0, s"numBlocks must be positive but found $numBlocks.")
+
+ private[this] final val numLocalIndexBits =
+ math.min(java.lang.Integer.numberOfLeadingZeros(numBlocks - 1), 31)
+ private[this] final val localIndexMask = (1 << numLocalIndexBits) - 1
+
+ /** Encodes a (blockId, localIndex) into a single integer. */
+ def encode(blockId: Int, localIndex: Int): Int = {
+ require(blockId < numBlocks)
+ require((localIndex & ~localIndexMask) == 0)
+ (blockId << numLocalIndexBits) | localIndex
+ }
+
+ /** Gets the block id from an encoded index. */
+ @inline
+ def blockId(encoded: Int): Int = {
+ encoded >>> numLocalIndexBits
+ }
+
+ /** Gets the local index from an encoded index. */
+ @inline
+ def localIndex(encoded: Int): Int = {
+ encoded & localIndexMask
+ }
+ }
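+
+  // Encoding example: with numBlocks = 10, numLocalIndexBits is 28, so
+  // encode(blockId = 3, localIndex = 5) == (3 << 28) | 5, from which blockId() and
+  // localIndex() recover 3 and 5, respectively.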
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 6b5c934f015ba..11633e8242313 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -279,45 +279,81 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[VectorWithNorm])
: Array[Array[VectorWithNorm]] = {
- // Initialize each run's center to a random point
+ // Initialize empty centers and point costs.
+ val centers = Array.tabulate(runs)(r => ArrayBuffer.empty[VectorWithNorm])
+ var costs = data.map(_ => Vectors.dense(Array.fill(runs)(Double.PositiveInfinity))).cache()
+
+ // Initialize each run's first center to a random point.
val seed = new XORShiftRandom(this.seed).nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
- val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+ val newCenters = Array.tabulate(runs)(r => ArrayBuffer(sample(r).toDense))
+
+ /** Merges new centers to centers. */
+ def mergeNewCenters(): Unit = {
+ var r = 0
+ while (r < runs) {
+ centers(r) ++= newCenters(r)
+ newCenters(r).clear()
+ r += 1
+ }
+ }
// On each step, sample 2 * k points on average for each run with probability proportional
- // to their squared distance from that run's current centers
+ // to their squared distance from that run's centers. Note that only distances between points
+ // and new centers are computed in each iteration.
var step = 0
while (step < initializationSteps) {
- val bcCenters = data.context.broadcast(centers)
- val sumCosts = data.flatMap { point =>
- (0 until runs).map { r =>
- (r, KMeans.pointCost(bcCenters.value(r), point))
- }
- }.reduceByKey(_ + _).collectAsMap()
- val chosen = data.mapPartitionsWithIndex { (index, points) =>
+ val bcNewCenters = data.context.broadcast(newCenters)
+ val preCosts = costs
+ costs = data.zip(preCosts).map { case (point, cost) =>
+ Vectors.dense(
+ Array.tabulate(runs) { r =>
+ math.min(KMeans.pointCost(bcNewCenters.value(r), point), cost(r))
+ })
+ }.cache()
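+    // costs(r) is now min(previous cost, squared distance to the centers added in the last
+    // step), so each step only computes distances against the newly added centers.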
+ val sumCosts = costs
+ .aggregate(Vectors.zeros(runs))(
+ seqOp = (s, v) => {
+ // s += v
+ axpy(1.0, v, s)
+ s
+ },
+ combOp = (s0, s1) => {
+ // s0 += s1
+ axpy(1.0, s1, s0)
+ s0
+ }
+ )
+ preCosts.unpersist(blocking = false)
+ val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) =>
val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
- points.flatMap { p =>
- (0 until runs).filter { r =>
- rand.nextDouble() < 2.0 * KMeans.pointCost(bcCenters.value(r), p) * k / sumCosts(r)
- }.map((_, p))
+ pointsWithCosts.flatMap { case (p, c) =>
+ val rs = (0 until runs).filter { r =>
+ rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r)
+ }
+        if (rs.length > 0) Some((p, rs)) else None
}
}.collect()
- chosen.foreach { case (r, p) =>
- centers(r) += p.toDense
+ mergeNewCenters()
+ chosen.foreach { case (p, rs) =>
+ rs.foreach(newCenters(_) += p.toDense)
}
step += 1
}
+ mergeNewCenters()
+ costs.unpersist(blocking = false)
+
// Finally, we might have a set of more than k candidate centers for each run; weigh each
// candidate by the number of points in the dataset mapping to it and run a local k-means++
// on the weighted centers to pick just k of them
val bcCenters = data.context.broadcast(centers)
val weightMap = data.flatMap { p =>
- (0 until runs).map { r =>
+ Iterator.tabulate(runs) { r =>
((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0)
}
}.reduceByKey(_ + _).collectAsMap()
- val finalCenters = (0 until runs).map { r =>
+ val finalCenters = (0 until runs).par.map { r =>
val myCenters = centers(r).toArray
val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray
LocalKMeans.kMeansPlusPlus(r, myCenters, myWeights, k, 30)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 7ee0224ad4662..b3022add38469 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -333,7 +333,7 @@ object Vectors {
math.pow(sum, 1.0 / p)
}
}
-
+
/**
* Returns the squared distance between two Vectors.
* @param v1 first Vector.
@@ -341,8 +341,9 @@ object Vectors {
* @return squared distance between two Vectors.
*/
def sqdist(v1: Vector, v2: Vector): Double = {
+ require(v1.size == v2.size, "vector dimension mismatch")
var squaredDistance = 0.0
- (v1, v2) match {
+ (v1, v2) match {
case (v1: SparseVector, v2: SparseVector) =>
val v1Values = v1.values
val v1Indices = v1.indices
@@ -350,12 +351,12 @@ object Vectors {
val v2Indices = v2.indices
val nnzv1 = v1Indices.size
val nnzv2 = v2Indices.size
-
+
var kv1 = 0
var kv2 = 0
while (kv1 < nnzv1 || kv2 < nnzv2) {
var score = 0.0
-
+
if (kv2 >= nnzv2 || (kv1 < nnzv1 && v1Indices(kv1) < v2Indices(kv2))) {
score = v1Values(kv1)
kv1 += 1
@@ -397,7 +398,7 @@ object Vectors {
val nnzv1 = indices.size
val nnzv2 = v2.size
var iv1 = if (nnzv1 > 0) indices(kv1) else -1
-
+
while (kv2 < nnzv2) {
var score = 0.0
if (kv2 != iv1) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index bee951a2e5e26..5f84677be238d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -90,7 +90,7 @@ case class Rating(user: Int, product: Int, rating: Double)
*
* Essentially instead of finding the low-rank approximations to the rating matrix `R`,
* this finds the approximations for a preference matrix `P` where the elements of `P` are 1 if
- * r > 0 and 0 if r = 0. The ratings then act as 'confidence' values related to strength of
+ * r > 0 and 0 if r <= 0. The ratings then act as 'confidence' values related to strength of
* indicated user
* preferences rather than explicit ratings given to items.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index e9304b5e5c650..482dd4b272d1d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -140,6 +140,7 @@ private class RandomForest (
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
+ logDebug("subsamplingRate = " + strategy.subsamplingRate)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
@@ -155,19 +156,12 @@ private class RandomForest (
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val (subsample, withReplacement) = {
- // TODO: Have a stricter check for RF in the strategy
- val isRandomForest = numTrees > 1
- if (isRandomForest) {
- (1.0, true)
- } else {
- (strategy.subsamplingRate, false)
- }
- }
+    val withReplacement = numTrees > 1
val baggedInput
- = BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
- .persist(StorageLevel.MEMORY_AND_DISK)
+      = BaggedPoint.convertToBaggedRDD(treeInput, strategy.subsamplingRate, numTrees,
+        withReplacement, seed).persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index cf51d041c65a9..ed8e6a796f8c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -68,6 +68,15 @@ case class BoostingStrategy(
@Experimental
object BoostingStrategy {
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported: "Classification" or "Regression"
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: String): BoostingStrategy = {
+ defaultParams(Algo.fromString(algo))
+ }
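+
+  // For example, BoostingStrategy.defaultParams("Classification") simply delegates to the
+  // Algo-typed overload below, i.e. BoostingStrategy.defaultParams(Algo.Classification).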
+
/**
* Returns default configuration for the boosting algorithm
* @param algo Learning goal. Supported:
@@ -75,15 +84,15 @@ object BoostingStrategy {
* [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
* @return Configuration for boosting algorithm
*/
- def defaultParams(algo: String): BoostingStrategy = {
- val treeStrategy = Strategy.defaultStrategy(algo)
- treeStrategy.maxDepth = 3
+ def defaultParams(algo: Algo): BoostingStrategy = {
+    val treeStrategy = Strategy.defaultStrategy(algo)
+    treeStrategy.maxDepth = 3
algo match {
- case "Classification" =>
- treeStrategy.numClasses = 2
- new BoostingStrategy(treeStrategy, LogLoss)
- case "Regression" =>
- new BoostingStrategy(treeStrategy, SquaredError)
+ case Algo.Classification =>
+        treeStrategy.numClasses = 2
+        new BoostingStrategy(treeStrategy, LogLoss)
+ case Algo.Regression =>
+        new BoostingStrategy(treeStrategy, SquaredError)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by boosting.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index d5cd89ab94e81..3308adb6752ff 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -156,6 +156,9 @@ class Strategy (
s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
require(maxMemoryInMB <= 10240,
s"DecisionTree Strategy requires maxMemoryInMB <= 10240, but was given $maxMemoryInMB")
+ require(subsamplingRate > 0 && subsamplingRate <= 1,
+ s"DecisionTree Strategy requires subsamplingRate <=1 and >0, but was given " +
+ s"$subsamplingRate")
}
/** Returns a shallow copy of this instance. */
@@ -173,11 +176,19 @@ object Strategy {
* Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
* @param algo "Classification" or "Regression"
*/
- def defaultStrategy(algo: String): Strategy = algo match {
- case "Classification" =>
+ def defaultStrategy(algo: String): Strategy = {
+    defaultStrategy(Algo.fromString(algo))
+ }
+
+ /**
+ * Construct a default set of parameters for [[org.apache.spark.mllib.tree.DecisionTree]]
+ * @param algo Algo.Classification or Algo.Regression
+ */
+  def defaultStrategy(algo: Algo): Strategy = algo match {
+ case Algo.Classification =>
new Strategy(algo = Classification, impurity = Gini, maxDepth = 10,
numClasses = 2)
- case "Regression" =>
+ case Algo.Regression =>
new Strategy(algo = Regression, impurity = Variance, maxDepth = 10,
numClasses = 0)
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 0e02345aa3774..b7950e00786ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -94,6 +94,10 @@ private[tree] class EntropyAggregator(numClasses: Int)
throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"EntropyAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -147,6 +151,7 @@ private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalc
val lbl = label.toInt
require(lbl < stats.length,
s"EntropyCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "Entropy does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 7c83cd48e16a0..c946db9c0d1c8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -90,6 +90,10 @@ private[tree] class GiniAggregator(numClasses: Int)
throw new IllegalArgumentException(s"GiniAggregator given label $label" +
s" but requires label < numClasses (= $statsSize).")
}
+ if (label < 0) {
+ throw new IllegalArgumentException(s"GiniAggregator given label $label" +
+ s"but requires label is non-negative.")
+ }
allStats(offset + label.toInt) += instanceWeight
}
@@ -143,6 +147,7 @@ private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalcula
val lbl = label.toInt
require(lbl < stats.length,
s"GiniCalculator.prob given invalid label: $lbl (should be < ${stats.length}")
+ require(lbl >= 0, "GiniImpurity does not support negative labels")
val cnt = count
if (cnt == 0) {
0
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
new file mode 100644
index 0000000000000..cdd4db1b5b7dc
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -0,0 +1,435 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.recommendation
+
+import java.util.Random
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.scalatest.FunSuite
+
+import org.apache.spark.Logging
+import org.apache.spark.ml.recommendation.ALS._
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
+
+class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
+
+ private var sqlContext: SQLContext = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ }
+
+ test("LocalIndexEncoder") {
+ val random = new Random
+ for (numBlocks <- Seq(1, 2, 5, 10, 20, 50, 100)) {
+ val encoder = new LocalIndexEncoder(numBlocks)
+ val maxLocalIndex = Int.MaxValue / numBlocks
+ val tests = Seq.fill(5)((random.nextInt(numBlocks), random.nextInt(maxLocalIndex))) ++
+ Seq((0, 0), (numBlocks - 1, maxLocalIndex))
+ tests.foreach { case (blockId, localIndex) =>
+ val err = s"Failed with numBlocks=$numBlocks, blockId=$blockId, and localIndex=$localIndex."
+ val encoded = encoder.encode(blockId, localIndex)
+ assert(encoder.blockId(encoded) === blockId, err)
+ assert(encoder.localIndex(encoded) === localIndex, err)
+ }
+ }
+ }
+
+ test("normal equation construction with explict feedback") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 3.0f)
+ .add(Array(4.0f, 5.0f), 6.0f)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 2)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5")
+ // b = np.matrix("3; 6")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(17.0, 22.0, 29.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(27.0, 36.0) relTol 1e-8)
+
+ val ne1 = new NormalEquation(2)
+ .add(Array(7.0f, 8.0f), 9.0f)
+ ne0.merge(ne1)
+ assert(ne0.n === 3)
+ // NumPy code that computes the expected values:
+ // A = np.matrix("1 2; 4 5; 7 8")
+ // b = np.matrix("3; 6; 9")
+ // ata = A.transpose() * A
+ // atb = A.transpose() * b
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(66.0, 78.0, 93.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(90.0, 108.0) relTol 1e-8)
+
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f), 2.0f)
+ }
+ intercept[IllegalArgumentException] {
+ ne0.add(Array(1.0f, 2.0f, 3.0f), 4.0f)
+ }
+ intercept[IllegalArgumentException] {
+ val ne2 = new NormalEquation(3)
+ ne0.merge(ne2)
+ }
+
+ ne0.reset()
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+ }
+
+ test("normal equation construction with implicit feedback") {
+ val k = 2
+ val alpha = 0.5
+ val ne0 = new NormalEquation(k)
+ .addImplicit(Array(-5.0f, -4.0f), -3.0f, alpha)
+ .addImplicit(Array(-2.0f, -1.0f), 0.0f, alpha)
+ .addImplicit(Array(1.0f, 2.0f), 3.0f, alpha)
+ assert(ne0.k === k)
+ assert(ne0.triK === k * (k + 1) / 2)
+ assert(ne0.n === 0) // addImplicit doesn't increase the count.
+ // NumPy code that computes the expected values:
+ // alpha = 0.5
+ // A = np.matrix("-5 -4; -2 -1; 1 2")
+ // b = np.matrix("-3; 0; 3")
+ // b1 = b > 0
+ // c = 1.0 + alpha * np.abs(b)
+ // C = np.diag(c.A1)
+ // I = np.eye(3)
+ // ata = A.transpose() * (C - I) * A
+ // atb = A.transpose() * C * b1
+ assert(Vectors.dense(ne0.ata) ~== Vectors.dense(39.0, 33.0, 30.0) relTol 1e-8)
+ assert(Vectors.dense(ne0.atb) ~== Vectors.dense(2.5, 5.0) relTol 1e-8)
+ }
+
+ test("CholeskySolver") {
+ val k = 2
+ val ne0 = new NormalEquation(k)
+ .add(Array(1.0f, 2.0f), 4.0f)
+ .add(Array(1.0f, 3.0f), 9.0f)
+ .add(Array(1.0f, 4.0f), 16.0f)
+ val ne1 = new NormalEquation(k)
+ .merge(ne0)
+
+ val chol = new CholeskySolver
+ val x0 = chol.solve(ne0, 0.0).map(_.toDouble)
+ // NumPy code that computes the expected solution:
+ // A = np.matrix("1 2; 1 3; 1 4")
+    // b = np.matrix("4; 9; 16")
+ // x0 = np.linalg.lstsq(A, b)[0]
+ assert(Vectors.dense(x0) ~== Vectors.dense(-8.333333, 6.0) relTol 1e-6)
+
+ assert(ne0.n === 0)
+ assert(ne0.ata.forall(_ == 0.0))
+ assert(ne0.atb.forall(_ == 0.0))
+
+ val x1 = chol.solve(ne1, 0.5).map(_.toDouble)
+ // NumPy code that computes the expected solution, where lambda is scaled by n:
+    // x1 = np.linalg.solve(A.transpose() * A + 0.5 * 3 * np.eye(2), A.transpose() * b)
+ assert(Vectors.dense(x1) ~== Vectors.dense(-0.1155556, 3.28) relTol 1e-6)
+ }
+
+ test("RatingBlockBuilder") {
+ val emptyBuilder = new RatingBlockBuilder()
+ assert(emptyBuilder.size === 0)
+ val emptyBlock = emptyBuilder.build()
+ assert(emptyBlock.srcIds.isEmpty)
+ assert(emptyBlock.dstIds.isEmpty)
+ assert(emptyBlock.ratings.isEmpty)
+
+ val builder0 = new RatingBlockBuilder()
+ .add(Rating(0, 1, 2.0f))
+ .add(Rating(3, 4, 5.0f))
+ assert(builder0.size === 2)
+ val builder1 = new RatingBlockBuilder()
+ .add(Rating(6, 7, 8.0f))
+ .merge(builder0.build())
+ assert(builder1.size === 3)
+ val block = builder1.build()
+ val ratings = Seq.tabulate(block.size) { i =>
+ (block.srcIds(i), block.dstIds(i), block.ratings(i))
+ }.toSet
+ assert(ratings === Set((0, 1, 2.0f), (3, 4, 5.0f), (6, 7, 8.0f)))
+ }
+
+ test("UncompressedInBlock") {
+ val encoder = new LocalIndexEncoder(10)
+ val uncompressed = new UncompressedInBlockBuilder(encoder)
+ .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f))
+ .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f))
+ .build()
+ assert(uncompressed.size === 5)
+ val records = Seq.tabulate(uncompressed.size) { i =>
+ val dstEncodedIndex = uncompressed.dstEncodedIndices(i)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ (uncompressed.srcIds(i), dstBlockId, dstLocalIndex, uncompressed.ratings(i))
+ }.toSet
+ val expected =
+ Set((1, 0, 0, 1.0f), (0, 0, 1, 2.0f), (2, 0, 4, 3.0f), (3, 1, 2, 4.0f), (0, 1, 5, 5.0f))
+ assert(records === expected)
+
+ val compressed = uncompressed.compress()
+ assert(compressed.size === 5)
+ assert(compressed.srcIds.toSeq === Seq(0, 1, 2, 3))
+ assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
+ var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
+ var i = 0
+ while (i < compressed.srcIds.size) {
+ var j = compressed.dstPtrs(i)
+ while (j < compressed.dstPtrs(i + 1)) {
+ val dstEncodedIndex = compressed.dstEncodedIndices(j)
+ val dstBlockId = encoder.blockId(dstEncodedIndex)
+ val dstLocalIndex = encoder.localIndex(dstEncodedIndex)
+ decompressed += ((compressed.srcIds(i), dstBlockId, dstLocalIndex, compressed.ratings(j)))
+ j += 1
+ }
+ i += 1
+ }
+ assert(decompressed.toSet === expected)
+ }
+
+ /**
+ * Generates an explicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genExplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ logInfo(s"Generated an explicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates an implicit feedback dataset for testing ALS.
+ * @param numUsers number of users
+ * @param numItems number of items
+ * @param rank rank
+ * @param noiseStd the standard deviation of additive Gaussian noise on training data
+ * @param seed random seed
+ * @return (training, test)
+ */
+ def genImplicitTestData(
+ numUsers: Int,
+ numItems: Int,
+ rank: Int,
+ noiseStd: Double = 0.0,
+ seed: Long = 11L): (RDD[Rating], RDD[Rating]) = {
+ // The assumption of the implicit feedback model is that unobserved ratings are more likely to
+ // be negatives.
+ val positiveFraction = 0.8
+ val negativeFraction = 1.0 - positiveFraction
+ val trainingFraction = 0.6
+ val testFraction = 0.3
+ val totalFraction = trainingFraction + testFraction
+ val random = new Random(seed)
+ val userFactors = genFactors(numUsers, rank, random)
+ val itemFactors = genFactors(numItems, rank, random)
+ val training = ArrayBuffer.empty[Rating]
+ val test = ArrayBuffer.empty[Rating]
+ for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) {
+ val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1)
+ val threshold = if (rating > 0) positiveFraction else negativeFraction
+ val observed = random.nextDouble() < threshold
+ if (observed) {
+ val x = random.nextDouble()
+ if (x < totalFraction) {
+ if (x < trainingFraction) {
+ val noise = noiseStd * random.nextGaussian()
+ training += Rating(userId, itemId, rating + noise.toFloat)
+ } else {
+ test += Rating(userId, itemId, rating)
+ }
+ }
+ }
+ }
+ logInfo(s"Generated an implicit feedback dataset with ${training.size} ratings for training " +
+ s"and ${test.size} for test.")
+ (sc.parallelize(training, 2), sc.parallelize(test, 2))
+ }
+
+ /**
+ * Generates random user/item factors, with i.i.d. values drawn from U(a, b).
+ * @param size number of users/items
+ * @param rank number of features
+ * @param random random number generator
+ * @param a min value of the support (default: -1)
+ * @param b max value of the support (default: 1)
+ * @return a sequence of (ID, factors) pairs
+ */
+ private def genFactors(
+ size: Int,
+ rank: Int,
+ random: Random,
+ a: Float = -1.0f,
+ b: Float = 1.0f): Seq[(Int, Array[Float])] = {
+ require(size > 0 && size < Int.MaxValue / 3)
+ require(b > a)
+ val ids = mutable.Set.empty[Int]
+ while (ids.size < size) {
+ ids += random.nextInt()
+ }
+ val width = b - a
+ ids.toSeq.sorted.map(id => (id, Array.fill(rank)(a + random.nextFloat() * width)))
+ }
+
+ /**
+ * Test ALS using the given training/test splits and parameters.
+ * @param training training dataset
+ * @param test test dataset
+ * @param rank rank of the matrix factorization
+ * @param maxIter max number of iterations
+ * @param regParam regularization constant
+ * @param implicitPrefs whether to use implicit preference
+ * @param numUserBlocks number of user blocks
+ * @param numItemBlocks number of item blocks
+ * @param targetRMSE target test RMSE
+ */
+ def testALS(
+ training: RDD[Rating],
+ test: RDD[Rating],
+ rank: Int,
+ maxIter: Int,
+ regParam: Double,
+ implicitPrefs: Boolean = false,
+ numUserBlocks: Int = 2,
+ numItemBlocks: Int = 3,
+ targetRMSE: Double = 0.05): Unit = {
+ val sqlContext = this.sqlContext
+ import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
+ val als = new ALS()
+ .setRank(rank)
+ .setRegParam(regParam)
+ .setImplicitPrefs(implicitPrefs)
+ .setNumUserBlocks(numUserBlocks)
+ .setNumItemBlocks(numItemBlocks)
+ val alpha = als.getAlpha
+ val model = als.fit(training)
+ val predictions = model.transform(test)
+ .select('rating, 'prediction)
+ .map { case Row(rating: Float, prediction: Float) =>
+ (rating.toDouble, prediction.toDouble)
+ }
+ val rmse =
+ if (implicitPrefs) {
+ // TODO: Use a better (rank-based?) evaluation metric for implicit feedback.
+ // We limit the ratings and the predictions to the interval [0, 1] and compute the weighted RMSE
+ // with the confidence scores as weights.
+ val (totalWeight, weightedSumSq) = predictions.map { case (rating, prediction) =>
+ val confidence = 1.0 + alpha * math.abs(rating)
+ val rating01 = math.max(math.min(rating, 1.0), 0.0)
+ val prediction01 = math.max(math.min(prediction, 1.0), 0.0)
+ val err = prediction01 - rating01
+ (confidence, confidence * err * err)
+ }.reduce { case ((c0, e0), (c1, e1)) =>
+ (c0 + c1, e0 + e1)
+ }
+ math.sqrt(weightedSumSq / totalWeight)
+ } else {
+ val mse = predictions.map { case (rating, prediction) =>
+ val err = rating - prediction
+ err * err
+ }.mean()
+ math.sqrt(mse)
+ }
+ logInfo(s"Test RMSE is $rmse.")
+ assert(rmse < targetRMSE)
+ }
+
+ test("exact rank-1 matrix") {
+ val (training, test) = genExplicitTestData(numUsers = 20, numItems = 40, rank = 1)
+ testALS(training, test, maxIter = 1, rank = 1, regParam = 1e-5, targetRMSE = 0.001)
+ testALS(training, test, maxIter = 1, rank = 2, regParam = 1e-5, targetRMSE = 0.001)
+ }
+
+ test("approximate rank-1 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 1, noiseStd = 0.01)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 0.01, targetRMSE = 0.02)
+ testALS(training, test, maxIter = 2, rank = 2, regParam = 0.01, targetRMSE = 0.02)
+ }
+
+ test("approximate rank-2 matrix") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03)
+ testALS(training, test, maxIter = 4, rank = 3, regParam = 0.01, targetRMSE = 0.03)
+ }
+
+ test("different block settings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ for ((numUserBlocks, numItemBlocks) <- Seq((1, 1), (1, 2), (2, 1), (2, 2))) {
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, targetRMSE = 0.03,
+ numUserBlocks = numUserBlocks, numItemBlocks = numItemBlocks)
+ }
+ }
+
+ test("more blocks than ratings") {
+ val (training, test) =
+ genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ testALS(training, test, maxIter = 2, rank = 1, regParam = 1e-4, targetRMSE = 0.002,
+ numItemBlocks = 5, numUserBlocks = 5)
+ }
+
+ test("implicit feedback") {
+ val (training, test) =
+ genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true,
+ targetRMSE = 0.3)
+ }
+}
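
The implicit-feedback branch of testALS above clips ratings and predictions to [0, 1] and weights each squared error by the confidence 1 + alpha * |rating|. Below is a minimal standalone sketch of that weighting scheme on plain (rating, prediction) pairs rather than an RDD; the sample values and object name are illustrative only.

object WeightedRmseSketch {
  // Confidence-weighted RMSE: higher-|rating| observations count more.
  def weightedRmse(pairs: Seq[(Double, Double)], alpha: Double): Double = {
    val (totalWeight, weightedSumSq) = pairs.map { case (rating, prediction) =>
      val confidence = 1.0 + alpha * math.abs(rating)
      val rating01 = math.max(math.min(rating, 1.0), 0.0)         // clip to [0, 1]
      val prediction01 = math.max(math.min(prediction, 1.0), 0.0) // clip to [0, 1]
      val err = prediction01 - rating01
      (confidence, confidence * err * err)
    }.reduce { case ((c0, e0), (c1, e1)) => (c0 + c1, e0 + e1) }
    math.sqrt(weightedSumSq / totalWeight)
  }

  def main(args: Array[String]): Unit = {
    val pairs = Seq((1.0, 0.9), (0.0, 0.2), (1.0, 0.7))
    println(weightedRmse(pairs, alpha = 1.0)) // a small value indicates a good fit
  }
}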
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index f3b7bfda788fa..e9fc37e000526 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -215,7 +215,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext {
* @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
* @param implicitPrefs flag to test implicit feedback
- * @param bulkPredict flag to test bulk prediciton
+ * @param bulkPredict flag to test bulk prediction
* @param negativeWeights whether the generated data can contain negative values
* @param numUserBlocks number of user blocks to partition users into
* @param numProductBlocks number of product blocks to partition products into
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
new file mode 100644
index 0000000000000..92b498580af03
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/ImpuritySuite.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.impurity.{EntropyAggregator, GiniAggregator}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+/**
+ * Test suite for [[GiniAggregator]] and [[EntropyAggregator]].
+ */
+class ImpuritySuite extends FunSuite with MLlibTestSparkContext {
+ test("Gini impurity does not support negative labels") {
+ val gini = new GiniAggregator(2)
+ intercept[IllegalArgumentException] {
+ gini.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+
+ test("Entropy does not support negative labels") {
+ val entropy = new EntropyAggregator(2)
+ intercept[IllegalArgumentException] {
+ entropy.update(Array(0.0, 1.0, 2.0), 0, -1, 0.0)
+ }
+ }
+}
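
Both tests above expect an IllegalArgumentException for a negative label. The sketch below shows the kind of require-based validation that produces that exception; ToyCountAggregator is hypothetical and only illustrates the contract the suite checks, not Spark's GiniAggregator or EntropyAggregator.

class ToyCountAggregator(numClasses: Int) {
  private val counts = new Array[Double](numClasses)

  // require() throws IllegalArgumentException on invalid labels,
  // which is what intercept[IllegalArgumentException] catches in the tests.
  def update(label: Double, weight: Double): Unit = {
    require(label >= 0, s"Label cannot be negative: $label")
    require(label < numClasses, s"Label $label out of range for $numClasses classes")
    counts(label.toInt) += weight
  }

  def total: Double = counts.sum
}

// Usage mirroring the tests:
// intercept[IllegalArgumentException] { new ToyCountAggregator(2).update(-1, 1.0) }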
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index f7f0f20c6c125..55e963977b54f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -196,6 +196,22 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
+
+ test("subsampling rate in RandomForest"){
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(5, 20)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClasses = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
+ useNodeIdCache = true)
+
+ val rf1 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ strategy.subsamplingRate = 0.5
+ val rf2 = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
+ featureSubsetStrategy = "auto", seed = 123)
+ assert(rf1.toDebugString != rf2.toDebugString)
+ }
+
}
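
The new test trains two forests from the same seed but different subsamplingRate values and asserts that the resulting trees differ. The tiny sketch below illustrates why a fixed seed with a different sampling rate selects different training rows; the sampler and the numbers are illustrative, not Spark's bagging implementation.

object SubsamplingSketch {
  // Keep each of the n rows with probability `rate`, using a fixed seed.
  def bernoulliSample(n: Int, rate: Double, seed: Long): Seq[Int] = {
    val rng = new scala.util.Random(seed)
    (0 until n).filter(_ => rng.nextDouble() < rate)
  }

  def main(args: Array[String]): Unit = {
    val full = bernoulliSample(20, rate = 1.0, seed = 123) // every row kept
    val half = bernoulliSample(20, rate = 0.5, seed = 123) // roughly half the rows
    // Different training subsets generally lead to different trees,
    // which is what the toDebugString comparison in the test asserts.
    println(s"rate=1.0 keeps ${full.size} rows, rate=0.5 keeps ${half.size} rows")
  }
}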
diff --git a/pom.xml b/pom.xml
index b993391b15042..05cb3797fc55b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -117,7 +117,7 @@
2.0.1
0.21.0
shaded-protobuf
- 1.7.5
+ 1.7.10
1.2.17
1.0.4
2.4.1
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 127973b658190..bc5d81f12d746 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -90,6 +90,10 @@ object MimaExcludes {
// SPARK-5297 Java FileStream do not work with custom key/values
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.streaming.api.java.JavaStreamingContext.fileStream")
+ ) ++ Seq(
+ // SPARK-5315 Spark Streaming Java API returns Scala DStream
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.api.java.JavaDStreamLike.reduceByWindow")
)
case v if v.startsWith("1.2") =>
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 64f6a3ca6bf4c..568e21f3803bf 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -229,6 +229,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+ "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+ "not in code that it run on workers. For more information, see SPARK-5063."
+ )
+
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 4977400ac1c05..f4cfe4845dc20 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -141,6 +141,17 @@ def id(self):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
index e59c24adb84af..0e285d6088ec1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala
@@ -160,6 +160,14 @@ final private[streaming] class DStreamGraph extends Serializable with Logging {
}
}
+ /**
+ * Get the maximum remember duration across all the input streams. This is a conservative but
+ * safe remember duration which can be used to perform cleanup operations.
+ */
+ def getMaxInputStreamRememberDuration(): Duration = {
+ inputStreams.map { _.rememberDuration }.maxBy { _.milliseconds }
+ }
+
@throws(classOf[IOException])
private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
logDebug("DStreamGraph.writeObject used")
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index e0542eda1383f..c382a12f4d099 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -211,7 +211,9 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
* @param slideDuration sliding interval of the window (i.e., the interval after which
* the new DStream will generate RDDs); must be a multiple of this
* DStream's batching interval
+ * @deprecated This API is not Java-compatible.
*/
+ @deprecated("Use Java-compatible version of reduceByWindow", "1.3.0")
def reduceByWindow(
reduceFunc: (T, T) => T,
windowDuration: Duration,
@@ -220,6 +222,24 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
}
+ /**
+ * Return a new DStream in which each RDD has a single element generated by reducing all
+ * elements in a sliding window over this DStream.
+ * @param reduceFunc associative reduce function
+ * @param windowDuration width of the window; must be a multiple of this DStream's
+ * batching interval
+ * @param slideDuration sliding interval of the window (i.e., the interval after which
+ * the new DStream will generate RDDs); must be a multiple of this
+ * DStream's batching interval
+ */
+ def reduceByWindow(
+ reduceFunc: JFunction2[T, T, T],
+ windowDuration: Duration,
+ slideDuration: Duration
+ ): JavaDStream[T] = {
+ dstream.reduceByWindow(reduceFunc, windowDuration, slideDuration)
+ }
+
/**
* Return a new DStream in which each RDD has a single element generated by reducing all
* elements in a sliding window over this DStream. However, the reduction is done incrementally
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
index afd3c4bc4c4fe..8be04314c4285 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala
@@ -94,15 +94,4 @@ abstract class ReceiverInputDStream[T: ClassTag](@transient ssc_ : StreamingCont
}
Some(blockRDD)
}
-
- /**
- * Clear metadata that are older than `rememberDuration` of this DStream.
- * This is an internal method that should not be called directly. This
- * implementation overrides the default implementation to clear received
- * block information.
- */
- private[streaming] override def clearMetadata(time: Time) {
- super.clearMetadata(time)
- ssc.scheduler.receiverTracker.cleanupOldMetadata(time - rememberDuration)
- }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
index ab9fa192191aa..7bf3c33319491 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverMessage.scala
@@ -17,7 +17,10 @@
package org.apache.spark.streaming.receiver
-/** Messages sent to the NetworkReceiver. */
+import org.apache.spark.streaming.Time
+
+/** Messages sent to the Receiver. */
private[streaming] sealed trait ReceiverMessage extends Serializable
private[streaming] object StopReceiver extends ReceiverMessage
+private[streaming] case class CleanupOldBlocks(threshTime: Time) extends ReceiverMessage
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
index d7229c2b96d0b..716cf2c7f32fc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.spark.{Logging, SparkEnv, SparkException}
import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.Time
import org.apache.spark.streaming.scheduler._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -82,6 +83,9 @@ private[streaming] class ReceiverSupervisorImpl(
case StopReceiver =>
logInfo("Received stop signal")
stop("Stopped by driver", None)
+ case CleanupOldBlocks(threshTime) =>
+ logDebug("Received delete old batch signal")
+ cleanupOldBlocks(threshTime)
}
def ref = self
@@ -193,4 +197,9 @@ private[streaming] class ReceiverSupervisorImpl(
/** Generate new block ID */
private def nextBlockId = StreamBlockId(streamId, newBlockId.getAndIncrement)
+
+ private def cleanupOldBlocks(cleanupThreshTime: Time): Unit = {
+ logDebug(s"Cleaning up blocks older then $cleanupThreshTime")
+ receivedBlockHandler.cleanupOldBlocks(cleanupThreshTime.milliseconds)
+ }
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 39b66e1130768..8632c94349bf9 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -17,12 +17,14 @@
package org.apache.spark.streaming.scheduler
-import akka.actor.{ActorRef, ActorSystem, Props, Actor}
-import org.apache.spark.{SparkException, SparkEnv, Logging}
-import org.apache.spark.streaming.{Checkpoint, Time, CheckpointWriter}
-import org.apache.spark.streaming.util.{ManualClock, RecurringTimer, Clock}
import scala.util.{Failure, Success, Try}
+import akka.actor.{ActorRef, Props, Actor}
+
+import org.apache.spark.{SparkEnv, Logging}
+import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
+import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer}
+
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
private[scheduler] case class GenerateJobs(time: Time) extends JobGeneratorEvent
@@ -206,9 +208,13 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val timesToReschedule = (pendingTimes ++ downTimes).distinct.sorted(Time.ordering)
logInfo("Batches to reschedule (" + timesToReschedule.size + " batches): " +
timesToReschedule.mkString(", "))
- timesToReschedule.foreach(time =>
+ timesToReschedule.foreach { time =>
+ // Allocate the related blocks when recovering from failure, because some blocks that were
+ // added but not allocated are left dangling in the queue after recovery. We have to allocate
+ // those blocks to the next batch, which is the batch they were supposed to go into.
+ jobScheduler.receiverTracker.allocateBlocksToBatch(time) // allocate received blocks to batch
jobScheduler.submitJobSet(JobSet(time, graph.generateJobs(time)))
- )
+ }
// Restart the timer
timer.start(restartTime.milliseconds)
@@ -238,13 +244,17 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream metadata for the given `time`. */
private def clearMetadata(time: Time) {
ssc.graph.clearMetadata(time)
- jobScheduler.receiverTracker.cleanupOldMetadata(time - graph.batchDuration)
// If checkpointing is enabled, then checkpoint,
// else mark batch to be fully processed
if (shouldCheckpoint) {
eventActor ! DoCheckpoint(time)
} else {
+ // If checkpointing is not enabled, then delete metadata information about
+ // received blocks (block data not saved in any case). Otherwise, wait for
+ // checkpointing of this batch to complete.
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
}
@@ -252,6 +262,11 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
/** Clear DStream checkpoint data for the given `time`. */
private def clearCheckpointData(time: Time) {
ssc.graph.clearCheckpointData(time)
+
+ // All the checkpoint information about which batches have been processed, etc., has
+ // been saved to the checkpoint, so it's safe to delete block metadata and data WAL files
+ val maxRememberDuration = graph.getMaxInputStreamRememberDuration()
+ jobScheduler.receiverTracker.cleanupOldBlocksAndBatches(time - maxRememberDuration)
markBatchFullyProcessed(time)
}
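
clearMetadata and clearCheckpointData above now compute the cleanup threshold as the batch time minus the maximum remember duration over all input streams, so data still needed by the stream that remembers the longest (for example one with window() applied) is not deleted early. Below is a small sketch of that arithmetic, using plain milliseconds and illustrative durations rather than the streaming Time/Duration classes.

object CleanupThresholdSketch {
  def main(args: Array[String]): Unit = {
    val batchTimeMs = 120000L                    // current batch: t = 120 s
    val rememberDurationsMs = Seq(2000L, 6000L)  // plain stream vs. a stream with a 6 s window
    val maxRememberMs = rememberDurationsMs.max  // conservative choice: 6000 ms
    val cleanupThreshMs = batchTimeMs - maxRememberMs
    println(s"Safe to delete blocks/WAL segments older than t = $cleanupThreshMs ms")
  }
}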
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index c3d9d7b6813d3..e19ac939f9ac5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -67,7 +67,7 @@ private[streaming] class ReceivedBlockTracker(
extends Logging {
private type ReceivedBlockQueue = mutable.Queue[ReceivedBlockInfo]
-
+
private val streamIdToUnallocatedBlockQueues = new mutable.HashMap[Int, ReceivedBlockQueue]
private val timeToAllocatedBlocks = new mutable.HashMap[Time, AllocatedBlocks]
private val logManagerOption = createLogManager()
@@ -107,8 +107,14 @@ private[streaming] class ReceivedBlockTracker(
lastAllocatedBatchTime = batchTime
allocatedBlocks
} else {
- throw new SparkException(s"Unexpected allocation of blocks, " +
- s"last batch = $lastAllocatedBatchTime, batch time to allocate = $batchTime ")
+ // This situation occurs when:
+ // 1. The WAL ends with a BatchAllocationEvent but no matching BatchCleanupEvent, so a
+ // fully or partially processed batch job needs to be processed again and the batchTime
+ // will be equal to lastAllocatedBatchTime.
+ // 2. Slow checkpointing makes the recovered batch time older than the lastAllocatedBatchTime
+ // recovered from the WAL.
+ // This situation only occurs during recovery.
+ logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery")
}
}
@@ -150,7 +156,6 @@ private[streaming] class ReceivedBlockTracker(
writeToLog(BatchCleanupEvent(timesToCleanup))
timeToAllocatedBlocks --= timesToCleanup
logManagerOption.foreach(_.cleanupOldLogs(cleanupThreshTime.milliseconds, waitForCompletion))
- log
}
/** Stop the block tracker. */
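
allocateBlocksToBatch above no longer throws when asked to allocate an already-allocated or older batch: during WAL recovery a batch may be re-submitted, and re-allocation has to be a harmless no-op. The simplified, hypothetical tracker below sketches that idempotent behaviour; it is not Spark's ReceivedBlockTracker.

import scala.collection.mutable

class ToyBlockTracker {
  private val unallocated = mutable.Queue[String]()
  private val allocated = mutable.HashMap[Long, Seq[String]]()
  private var lastAllocatedBatchTime = -1L

  def addBlock(block: String): Unit = unallocated += block

  def allocateBlocksToBatch(batchTime: Long): Unit = {
    if (lastAllocatedBatchTime < 0 || batchTime > lastAllocatedBatchTime) {
      // Normal path: drain the queue into this new batch.
      allocated(batchTime) = unallocated.dequeueAll(_ => true).toList
      lastAllocatedBatchTime = batchTime
    } else {
      // Recovery path: the batch was already allocated; keep the previous allocation
      // instead of throwing, so re-submitted batches are handled gracefully.
    }
  }

  def blocksOfBatch(batchTime: Long): Seq[String] = allocated.getOrElse(batchTime, Seq.empty)
}

// Usage mirroring the updated suite: allocate batch 2, then call
// allocateBlocksToBatch(1) again; the blocks previously allocated to batch 1 are unchanged.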
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
index 8dbb42a86e3bd..4f998869731ed 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala
@@ -24,9 +24,8 @@ import scala.language.existentials
import akka.actor._
import org.apache.spark.{Logging, SerializableWritable, SparkEnv, SparkException}
-import org.apache.spark.SparkContext._
import org.apache.spark.streaming.{StreamingContext, Time}
-import org.apache.spark.streaming.receiver.{Receiver, ReceiverSupervisorImpl, StopReceiver}
+import org.apache.spark.streaming.receiver.{CleanupOldBlocks, Receiver, ReceiverSupervisorImpl, StopReceiver}
/**
* Messages used by the NetworkReceiver and the ReceiverTracker to communicate
@@ -119,9 +118,20 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false
}
}
- /** Clean up metadata older than the given threshold time */
- def cleanupOldMetadata(cleanupThreshTime: Time) {
+ /**
+ * Clean up the data and metadata of blocks and batches that are strictly
+ * older than the threshold time. Note that this does not wait for the cleanup to complete.
+ */
+ def cleanupOldBlocksAndBatches(cleanupThreshTime: Time) {
+ // Clean up old block and batch metadata
receivedBlockTracker.cleanupOldBatches(cleanupThreshTime, waitForCompletion = false)
+
+ // Signal the receivers to delete old block data
+ if (ssc.conf.getBoolean("spark.streaming.receiver.writeAheadLog.enable", false)) {
+ logInfo(s"Cleanup old received batch data: $cleanupThreshTime")
+ receiverInfo.values.flatMap { info => Option(info.actor) }
+ .foreach { _ ! CleanupOldBlocks(cleanupThreshTime) }
+ }
}
/** Register a receiver */
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
index 27a28bab83ed5..858ba3c9eb4e5 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/HdfsUtils.scala
@@ -63,7 +63,7 @@ private[streaming] object HdfsUtils {
}
def getFileSystemForPath(path: Path, conf: Configuration): FileSystem = {
- // For local file systems, return the raw loca file system, such calls to flush()
+ // For local file systems, return the raw local file system, such calls to flush()
// actually flushes the stream.
val fs = path.getFileSystem(conf)
fs match {
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index d92e7fe899a09..d4c40745658c2 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -306,7 +306,17 @@ public void testReduce() {
@SuppressWarnings("unchecked")
@Test
- public void testReduceByWindow() {
+ public void testReduceByWindowWithInverse() {
+ testReduceByWindow(true);
+ }
+
+ @SuppressWarnings("unchecked")
+ @Test
+ public void testReduceByWindowWithoutInverse() {
+ testReduceByWindow(false);
+ }
+
+ private void testReduceByWindow(boolean withInverse) {
List> inputData = Arrays.asList(
Arrays.asList(1,2,3),
Arrays.asList(4,5,6),
@@ -319,8 +329,14 @@ public void testReduceByWindow() {
Arrays.asList(24));
JavaDStream stream = JavaTestUtils.attachTestInputStream(ssc, inputData, 1);
- JavaDStream reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ JavaDStream reducedWindowed = null;
+ if (withInverse) {
+ reducedWindowed = stream.reduceByWindow(new IntegerSum(),
new IntegerDifference(), new Duration(2000), new Duration(1000));
+ } else {
+ reducedWindowed = stream.reduceByWindow(new IntegerSum(),
+ new Duration(2000), new Duration(1000));
+ }
JavaTestUtils.attachTestOutputStream(reducedWindowed);
List> result = JavaTestUtils.runStreams(ssc, 4, 4);
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
index 40434b1f9b709..6500608bba87c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala
@@ -28,21 +28,16 @@ import java.io.File
*/
class FailureSuite extends TestSuiteBase with Logging {
- var directory = "FailureSuite"
+ val directory = Utils.createTempDir().getAbsolutePath
val numBatches = 30
override def batchDuration = Milliseconds(1000)
override def useManualClock = false
- override def beforeFunction() {
- super.beforeFunction()
- Utils.deleteRecursively(new File(directory))
- }
-
override def afterFunction() {
- super.afterFunction()
Utils.deleteRecursively(new File(directory))
+ super.afterFunction()
}
test("multiple failures with map") {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index de7e9d624bf6b..fbb7b0bfebafc 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -82,15 +82,15 @@ class ReceivedBlockTrackerSuite
receivedBlockTracker.allocateBlocksToBatch(2)
receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
- // Verify that batch 2 cannot be allocated again
- intercept[SparkException] {
- receivedBlockTracker.allocateBlocksToBatch(2)
- }
+ // Verify that allocating blocks to an older batch is a no-op and
+ // returns the same blocks as previously allocated.
+ receivedBlockTracker.allocateBlocksToBatch(1)
+ receivedBlockTracker.getBlocksOfBatchAndStream(1, streamId) shouldEqual blockInfos
- // Verify that older batches cannot be allocated again
- intercept[SparkException] {
- receivedBlockTracker.allocateBlocksToBatch(1)
- }
+ blockInfos.map(receivedBlockTracker.addBlock)
+ receivedBlockTracker.allocateBlocksToBatch(2)
+ receivedBlockTracker.getBlocksOfBatchAndStream(2, streamId) shouldBe empty
+ receivedBlockTracker.getUnallocatedBlocks(streamId) shouldEqual blockInfos
}
test("block addition, block to batch allocation and cleanup with write ahead log") {
@@ -186,14 +186,14 @@ class ReceivedBlockTrackerSuite
tracker4.getBlocksOfBatchAndStream(batchTime1, streamId) shouldBe empty // should be cleaned
tracker4.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
}
-
+
test("enabling write ahead log but not setting checkpoint dir") {
conf.set("spark.streaming.receiver.writeAheadLog.enable", "true")
intercept[SparkException] {
createTracker(setCheckpointDir = false)
}
}
-
+
test("setting checkpoint dir but not enabling write ahead log") {
// When WAL config is not set, log manager should not be enabled
val tracker1 = createTracker(setCheckpointDir = true)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
index e26c0c6859e57..e8c34a9ee40b9 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala
@@ -17,21 +17,26 @@
package org.apache.spark.streaming
+import java.io.File
import java.nio.ByteBuffer
import java.util.concurrent.Semaphore
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.{StorageLevel, StreamBlockId}
-import org.apache.spark.streaming.receiver.{BlockGenerator, BlockGeneratorListener, Receiver, ReceiverSupervisor}
-import org.scalatest.FunSuite
+import com.google.common.io.Files
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
+import org.apache.spark.SparkConf
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.storage.StreamBlockId
+import org.apache.spark.streaming.receiver._
+import org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._
+
/** Testsuite for testing the network receiver behavior */
-class ReceiverSuite extends FunSuite with Timeouts {
+class ReceiverSuite extends TestSuiteBase with Timeouts with Serializable {
test("receiver life cycle") {
@@ -192,7 +197,6 @@ class ReceiverSuite extends FunSuite with Timeouts {
val minExpectedMessagesPerBlock = expectedMessagesPerBlock - 3
val maxExpectedMessagesPerBlock = expectedMessagesPerBlock + 1
val receivedBlockSizes = recordedBlocks.map { _.size }.mkString(",")
- println(minExpectedMessagesPerBlock, maxExpectedMessagesPerBlock, ":", receivedBlockSizes)
assert(
// the first and last block may be incomplete, so we slice them out
recordedBlocks.drop(1).dropRight(1).forall { block =>
@@ -203,39 +207,91 @@ class ReceiverSuite extends FunSuite with Timeouts {
)
}
-
/**
- * An implementation of NetworkReceiver that is used for testing a receiver's life cycle.
+ * Test whether write ahead logs are generated by receivers
+ * and automatically cleaned up. The cleanup must be aware of the
+ * remember duration of the input streams. E.g., input streams on which window()
+ * has been applied must remember the data for longer, and hence corresponding
+ * WALs should be cleaned later.
*/
- class FakeReceiver extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
- @volatile var otherThread: Thread = null
- @volatile var receiving = false
- @volatile var onStartCalled = false
- @volatile var onStopCalled = false
-
- def onStart() {
- otherThread = new Thread() {
- override def run() {
- receiving = true
- while(!isStopped()) {
- Thread.sleep(10)
- }
+ test("write ahead log - generating and cleaning") {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]") // must be at least 3 as we are going to start 2 receivers
+ .setAppName(framework)
+ .set("spark.ui.enabled", "true")
+ .set("spark.streaming.receiver.writeAheadLog.enable", "true")
+ .set("spark.streaming.receiver.writeAheadLog.rollingInterval", "1")
+ val batchDuration = Milliseconds(500)
+ val tempDirectory = Files.createTempDir()
+ val logDirectory1 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 0))
+ val logDirectory2 = new File(checkpointDirToLogDir(tempDirectory.getAbsolutePath, 1))
+ val allLogFiles1 = new mutable.HashSet[String]()
+ val allLogFiles2 = new mutable.HashSet[String]()
+ logInfo("Temp checkpoint directory = " + tempDirectory)
+
+ def getBothCurrentLogFiles(): (Seq[String], Seq[String]) = {
+ (getCurrentLogFiles(logDirectory1), getCurrentLogFiles(logDirectory2))
+ }
+
+ def getCurrentLogFiles(logDirectory: File): Seq[String] = {
+ try {
+ if (logDirectory.exists()) {
+ logDirectory.listFiles().filter { _.getName.startsWith("log") }.map { _.toString }
+ } else {
+ Seq.empty
}
+ } catch {
+ case e: Exception =>
+ Seq.empty
}
- onStartCalled = true
- otherThread.start()
-
}
- def onStop() {
- onStopCalled = true
- otherThread.join()
+ def printLogFiles(message: String, files: Seq[String]) {
+ logInfo(s"$message (${files.size} files):\n" + files.mkString("\n"))
}
- def reset() {
- receiving = false
- onStartCalled = false
- onStopCalled = false
+ withStreamingContext(new StreamingContext(sparkConf, batchDuration)) { ssc =>
+ tempDirectory.deleteOnExit()
+ val receiver1 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiver2 = ssc.sparkContext.clean(new FakeReceiver(sendData = true))
+ val receiverStream1 = ssc.receiverStream(receiver1)
+ val receiverStream2 = ssc.receiverStream(receiver2)
+ receiverStream1.register()
+ receiverStream2.window(batchDuration * 6).register() // 3 second window
+ ssc.checkpoint(tempDirectory.getAbsolutePath())
+ ssc.start()
+
+ // Run until sufficient WAL files have been generated and
+ // the first WAL files has been deleted
+ eventually(timeout(20 seconds), interval(batchDuration.milliseconds millis)) {
+ val (logFiles1, logFiles2) = getBothCurrentLogFiles()
+ allLogFiles1 ++= logFiles1
+ allLogFiles2 ++= logFiles2
+ if (allLogFiles1.size > 0) {
+ assert(!logFiles1.contains(allLogFiles1.toSeq.sorted.head))
+ }
+ if (allLogFiles2.size > 0) {
+ assert(!logFiles2.contains(allLogFiles2.toSeq.sorted.head))
+ }
+ assert(allLogFiles1.size >= 7)
+ assert(allLogFiles2.size >= 7)
+ }
+ ssc.stop(stopSparkContext = true, stopGracefully = true)
+
+ val sortedAllLogFiles1 = allLogFiles1.toSeq.sorted
+ val sortedAllLogFiles2 = allLogFiles2.toSeq.sorted
+ val (leftLogFiles1, leftLogFiles2) = getBothCurrentLogFiles()
+
+ printLogFiles("Receiver 0: all", sortedAllLogFiles1)
+ printLogFiles("Receiver 0: left", leftLogFiles1)
+ printLogFiles("Receiver 1: all", sortedAllLogFiles2)
+ printLogFiles("Receiver 1: left", leftLogFiles2)
+
+ // Verify that necessary latest log files are not deleted
+ // receiverStream1 needs to retain just the last batch = 1 log file
+ // receiverStream2 needs to retain 3 seconds (3-seconds window) = 3 log files
+ assert(sortedAllLogFiles1.takeRight(1).forall(leftLogFiles1.contains))
+ assert(sortedAllLogFiles2.takeRight(3).forall(leftLogFiles2.contains))
}
}
@@ -315,3 +371,42 @@ class ReceiverSuite extends FunSuite with Timeouts {
}
}
+/**
+ * An implementation of Receiver that is used for testing a receiver's life cycle.
+ */
+class FakeReceiver(sendData: Boolean = false) extends Receiver[Int](StorageLevel.MEMORY_ONLY) {
+ @volatile var otherThread: Thread = null
+ @volatile var receiving = false
+ @volatile var onStartCalled = false
+ @volatile var onStopCalled = false
+
+ def onStart() {
+ otherThread = new Thread() {
+ override def run() {
+ receiving = true
+ var count = 0
+ while(!isStopped()) {
+ if (sendData) {
+ store(count)
+ count += 1
+ }
+ Thread.sleep(10)
+ }
+ }
+ }
+ onStartCalled = true
+ otherThread.start()
+ }
+
+ def onStop() {
+ onStopCalled = true
+ otherThread.join()
+ }
+
+ def reset() {
+ receiving = false
+ onStartCalled = false
+ onStopCalled = false
+ }
+}
+
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 4c35b60c57df3..d00f29665a58f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -60,7 +60,6 @@ private[yarn] class YarnAllocator(
import YarnAllocator._
- // These two complementary data structures are locked on allocatedHostToContainersMap.
// Visible for testing.
val allocatedHostToContainersMap =
new HashMap[String, collection.mutable.Set[ContainerId]]
@@ -355,20 +354,18 @@ private[yarn] class YarnAllocator(
}
}
- allocatedHostToContainersMap.synchronized {
- if (allocatedContainerToHostMap.containsKey(containerId)) {
- val host = allocatedContainerToHostMap.get(containerId).get
- val containerSet = allocatedHostToContainersMap.get(host).get
+ if (allocatedContainerToHostMap.containsKey(containerId)) {
+ val host = allocatedContainerToHostMap.get(containerId).get
+ val containerSet = allocatedHostToContainersMap.get(host).get
- containerSet.remove(containerId)
- if (containerSet.isEmpty) {
- allocatedHostToContainersMap.remove(host)
- } else {
- allocatedHostToContainersMap.update(host, containerSet)
- }
-
- allocatedContainerToHostMap.remove(containerId)
+ containerSet.remove(containerId)
+ if (containerSet.isEmpty) {
+ allocatedHostToContainersMap.remove(host)
+ } else {
+ allocatedHostToContainersMap.update(host, containerSet)
}
+
+ allocatedContainerToHostMap.remove(containerId)
}
}
}