diff --git a/.gitignore b/.gitignore
index b54a3058de659..4f177c82ae5e0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,7 +7,7 @@
sbt/*.jar
.settings
.cache
-.generated-mima-excludes
+.generated-mima*
/build/
work/
out/
diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index 110bd0a9a0c41..55241d33cd3f0 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -80,7 +80,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
test("large number of iterations") {
// This tests whether jobs with a large number of iterations finish in a reasonable time,
// because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
- failAfter(10 seconds) {
+ failAfter(30 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
@@ -101,7 +101,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
- val numSupersteps = 50
+ val numSupersteps = 20
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 7df43a555d562..2cf4e381c1c88 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -38,8 +38,10 @@ else
JAR_CMD="jar"
fi
-# First check if we have a dependencies jar. If so, include binary classes with the deps jar
-if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then
+# A developer option to prepend more recently compiled Spark classes
+if [ -n "$SPARK_PREPEND_CLASSES" ]; then
+ echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\
+ "classes ahead of assembly." >&2
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
@@ -51,17 +53,31 @@ if [ -f "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar ]; then
CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes"
+fi
- ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*-deps.jar 2>/dev/null)
+# Use spark-assembly jar from either RELEASE or assembly directory
+if [ -f "$FWDIR/RELEASE" ]; then
+ assembly_folder="$FWDIR"/lib
else
- # Else use spark-assembly jar from either RELEASE or assembly directory
- if [ -f "$FWDIR/RELEASE" ]; then
- ASSEMBLY_JAR=$(ls "$FWDIR"/lib/spark-assembly*hadoop*.jar 2>/dev/null)
- else
- ASSEMBLY_JAR=$(ls "$ASSEMBLY_DIR"/spark-assembly*hadoop*.jar 2>/dev/null)
- fi
+ assembly_folder="$ASSEMBLY_DIR"
fi
+num_jars=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)
+if [ "$num_jars" -eq "0" ]; then
+ echo "Failed to find Spark assembly in $assembly_folder"
+ echo "You need to build Spark before running this program."
+ exit 1
+fi
+if [ "$num_jars" -gt "1" ]; then
+ jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar")
+ echo "Found multiple Spark assembly jars in $assembly_folder:"
+ echo "$jars_list"
+ echo "Please remove all but one jar."
+ exit 1
+fi
+
+ASSEMBLY_JAR=$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)
+
# Verify that versions of java used to build the jars and run Spark are compatible
jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1)
if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
diff --git a/bin/pyspark b/bin/pyspark
index d0fa56f31913f..0b5ed40e2157d 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -45,7 +45,7 @@ fi
. $FWDIR/bin/load-spark-env.sh
# Figure out which Python executable to use
-if [ -z "$PYSPARK_PYTHON" ] ; then
+if [[ -z "$PYSPARK_PYTHON" ]]; then
PYSPARK_PYTHON="python"
fi
export PYSPARK_PYTHON
@@ -59,7 +59,7 @@ export OLD_PYTHONSTARTUP=$PYTHONSTARTUP
export PYTHONSTARTUP=$FWDIR/python/pyspark/shell.py
# If IPython options are specified, assume user wants to run IPython
-if [ -n "$IPYTHON_OPTS" ]; then
+if [[ -n "$IPYTHON_OPTS" ]]; then
IPYTHON=1
fi
@@ -76,6 +76,16 @@ for i in "$@"; do
done
export PYSPARK_SUBMIT_ARGS
+# For pyspark tests
+if [[ -n "$SPARK_TESTING" ]]; then
+ if [[ -n "$PYSPARK_DOC_TEST" ]]; then
+ exec "$PYSPARK_PYTHON" -m doctest $1
+ else
+ exec "$PYSPARK_PYTHON" $1
+ fi
+ exit
+fi
+
# If a python file is provided, directly run spark-submit.
if [[ "$1" =~ \.py$ ]]; then
echo -e "\nWARNING: Running python applications through ./bin/pyspark is deprecated as of Spark 1.0." 1>&2
diff --git a/bin/spark-class b/bin/spark-class
index e884511010c6c..cfe363a71da31 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -108,23 +108,6 @@ fi
export JAVA_OPTS
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
-if [ ! -f "$FWDIR/RELEASE" ]; then
- # Exit if the user hasn't compiled Spark
- num_jars=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar" | wc -l)
- jars_list=$(ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/ | grep "spark-assembly.*hadoop.*.jar")
- if [ "$num_jars" -eq "0" ]; then
- echo "Failed to find Spark assembly in $FWDIR/assembly/target/scala-$SCALA_VERSION/" >&2
- echo "You need to build Spark before running this program." >&2
- exit 1
- fi
- if [ "$num_jars" -gt "1" ]; then
- echo "Found multiple Spark assembly jars in $FWDIR/assembly/target/scala-$SCALA_VERSION:" >&2
- echo "$jars_list"
- echo "Please remove all but one jar."
- exit 1
- fi
-fi
-
TOOLS_DIR="$FWDIR"/tools
SPARK_TOOLS_JAR=""
if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/*assembly*[0-9Tg].jar ]; then
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index e2d2250982daa..bf3c3a6ceb5ef 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -96,7 +96,7 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
- def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _]) {
+ def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
}
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 2c31cc20211ff..c8c194a111aac 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
/**
* :: DeveloperApi ::
@@ -50,19 +51,24 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to null,
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
*/
@DeveloperApi
-class ShuffleDependency[K, V](
+class ShuffleDependency[K, V, C](
@transient rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializer: Serializer = null)
+ val serializer: Option[Serializer] = None,
+ val keyOrdering: Option[Ordering[K]] = None,
+ val aggregator: Option[Aggregator[K, V, C]] = None)
extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
val shuffleId: Int = rdd.context.newShuffleId()
+ val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+ shuffleId, rdd.partitions.size, this)
+
rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
}
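
ShuffleDependency now takes three type parameters (key, value, combiner) plus optional serializer, key ordering and aggregator, and it registers itself with the ShuffleManager to obtain a ShuffleHandle. A minimal sketch of constructing one directly, assuming an active SparkContext behind the pair RDD; the summing aggregator and the helper name are illustrative only, not part of the patch:

    import org.apache.spark.{Aggregator, HashPartitioner, ShuffleDependency}
    import org.apache.spark.rdd.RDD

    def sumDependency(pairs: RDD[(String, Int)]): ShuffleDependency[String, Int, Int] = {
      // Combiner type C = Int: sum the values seen for each key.
      val agg = new Aggregator[String, Int, Int](
        createCombiner = (v: Int) => v,
        mergeValue = (c: Int, v: Int) => c + v,
        mergeCombiners = (c1: Int, c2: Int) => c1 + c2)
      new ShuffleDependency[String, Int, Int](
        pairs,                    // parent RDD of key-value pairs
        new HashPartitioner(4),   // how the shuffle output is partitioned
        serializer = None,        // None falls back to spark.serializer
        keyOrdering = None,
        aggregator = Some(agg))
    }

The resulting dependency's shuffleHandle is what map and reduce tasks later hand to ShuffleManager.getWriter and getReader.
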
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d721aba709600..35970c2f50892 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -290,6 +290,9 @@ class SparkContext(config: SparkConf) extends Logging {
value <- Option(System.getenv(envKey)).orElse(Option(System.getProperty(propKey)))} {
executorEnvs(envKey) = value
}
+ Option(System.getenv("SPARK_PREPEND_CLASSES")).foreach { v =>
+ executorEnvs("SPARK_PREPEND_CLASSES") = v
+ }
// The Mesos scheduler backend relies on this environment variable to set executor memory.
// TODO: Set this only in the Mesos scheduler.
executorEnvs("SPARK_EXECUTOR_MEMORY") = executorMemory + "m"
@@ -297,7 +300,7 @@ class SparkContext(config: SparkConf) extends Logging {
// Set SPARK_USER for user who is running SparkContext.
val sparkUser = Option {
- Option(System.getProperty("user.name")).getOrElse(System.getenv("SPARK_USER"))
+ Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name"))
}.getOrElse {
SparkContext.SPARK_UNKNOWN_USER
}
@@ -431,12 +434,21 @@ class SparkContext(config: SparkConf) extends Logging {
// Methods for creating RDDs
- /** Distribute a local Scala collection to form an RDD. */
+ /** Distribute a local Scala collection to form an RDD.
+ *
+ * @note Parallelize acts lazily. If `seq` is a mutable collection and is
+ * altered after the call to parallelize and before the first action on the
+ * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
+ * the argument to avoid this.
+ */
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
- /** Distribute a local Scala collection to form an RDD. */
+ /** Distribute a local Scala collection to form an RDD.
+ *
+ * This method is identical to `parallelize`.
+ */
def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
@@ -823,9 +835,11 @@ class SparkContext(config: SparkConf) extends Logging {
}
/**
+ * :: DeveloperApi ::
* Return information about what RDDs are cached, if they are in mem or on disk, how much space
* they take, etc.
*/
+ @DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
}
@@ -837,8 +851,10 @@ class SparkContext(config: SparkConf) extends Logging {
def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
/**
+ * :: DeveloperApi ::
* Return information about blocks stored in all of the slaves
*/
+ @DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
env.blockManager.master.getStorageStatus
}
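
The new @note on parallelize is easy to trip over; a small illustration (assuming an in-scope SparkContext named sc) of why passing a copy of a mutable collection may be the safer choice:

    import scala.collection.mutable.ArrayBuffer

    val data = ArrayBuffer(1, 2, 3)
    val rdd = sc.parallelize(data)        // lazy: the buffer is not sliced yet
    data += 4                             // mutation happens before the first action
    rdd.collect()                         // reflects the mutation, per the note: Array(1, 2, 3, 4)

    val snapshot = sc.parallelize(data.toList)  // pass an immutable copy to freeze the contents
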
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 720151a6b0f84..8dfa8cc4b5b3f 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,6 +34,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.ConnectionManager
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{AkkaUtils, Utils}
@@ -56,7 +57,7 @@ class SparkEnv (
val closureSerializer: Serializer,
val cacheManager: CacheManager,
val mapOutputTracker: MapOutputTracker,
- val shuffleFetcher: ShuffleFetcher,
+ val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
val blockManager: BlockManager,
val connectionManager: ConnectionManager,
@@ -80,7 +81,7 @@ class SparkEnv (
pythonWorkers.foreach { case(key, worker) => worker.stop() }
httpFileServer.stop()
mapOutputTracker.stop()
- shuffleFetcher.stop()
+ shuffleManager.stop()
broadcastManager.stop()
blockManager.stop()
blockManager.master.stop()
@@ -163,13 +164,20 @@ object SparkEnv extends Logging {
def instantiateClass[T](propertyName: String, defaultClassName: String): T = {
val name = conf.get(propertyName, defaultClassName)
val cls = Class.forName(name, true, Utils.getContextOrSparkClassLoader)
- // First try with the constructor that takes SparkConf. If we can't find one,
- // use a no-arg constructor instead.
+ // Look for a constructor taking a SparkConf and a boolean isDriver, then one taking just
+ // SparkConf, then one taking no arguments
try {
- cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ cls.getConstructor(classOf[SparkConf], java.lang.Boolean.TYPE)
+ .newInstance(conf, new java.lang.Boolean(isDriver))
+ .asInstanceOf[T]
} catch {
case _: NoSuchMethodException =>
- cls.getConstructor().newInstance().asInstanceOf[T]
+ try {
+ cls.getConstructor(classOf[SparkConf]).newInstance(conf).asInstanceOf[T]
+ } catch {
+ case _: NoSuchMethodException =>
+ cls.getConstructor().newInstance().asInstanceOf[T]
+ }
}
}
@@ -219,9 +227,6 @@ object SparkEnv extends Logging {
val cacheManager = new CacheManager(blockManager)
- val shuffleFetcher = instantiateClass[ShuffleFetcher](
- "spark.shuffle.fetcher", "org.apache.spark.BlockStoreShuffleFetcher")
-
val httpFileServer = new HttpFileServer(securityManager)
httpFileServer.initialize()
conf.set("spark.fileserver.uri", httpFileServer.serverUri)
@@ -242,6 +247,9 @@ object SparkEnv extends Logging {
"."
}
+ val shuffleManager = instantiateClass[ShuffleManager](
+ "spark.shuffle.manager", "org.apache.spark.shuffle.hash.HashShuffleManager")
+
// Warn about deprecated spark.cache.class property
if (conf.contains("spark.cache.class")) {
logWarning("The spark.cache.class property is no longer being used! Specify storage " +
@@ -255,7 +263,7 @@ object SparkEnv extends Logging {
closureSerializer,
cacheManager,
mapOutputTracker,
- shuffleFetcher,
+ shuffleManager,
broadcastManager,
blockManager,
connectionManager,
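
instantiateClass now probes three constructor shapes in order: (SparkConf, Boolean isDriver), then (SparkConf), then no-arg. A stand-alone sketch of that lookup order, using a toy Config class in place of SparkConf so the snippet runs without Spark on the classpath; all names here are hypothetical:

    object ConstructorLookupDemo {
      class Config
      class WithFlag(conf: Config, isDriver: Boolean) {
        override def toString: String = s"WithFlag(isDriver=$isDriver)"
      }

      // Mirrors SparkEnv.instantiateClass: prefer (Config, Boolean), then (Config), then no-arg.
      def instantiate[T](cls: Class[_], conf: Config, isDriver: Boolean): T = {
        try {
          cls.getConstructor(classOf[Config], java.lang.Boolean.TYPE)
            .newInstance(conf, java.lang.Boolean.valueOf(isDriver))
            .asInstanceOf[T]
        } catch {
          case _: NoSuchMethodException =>
            try {
              cls.getConstructor(classOf[Config]).newInstance(conf).asInstanceOf[T]
            } catch {
              case _: NoSuchMethodException =>
                cls.getConstructor().newInstance().asInstanceOf[T]
            }
        }
      }

      def main(args: Array[String]): Unit = {
        // Picks the (Config, Boolean) constructor, so the isDriver flag reaches the instance.
        println(instantiate[WithFlag](classOf[WithFlag], new Config, isDriver = true))
      }
    }
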
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 7dcfbf741c4f1..14fa9d8135afe 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -228,6 +228,50 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
: PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, partitioner: Partitioner, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, partitioner)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, numPartitions: Int, seqFunc: JFunction2[U, V, U],
+ combFunc: JFunction2[U, U, U]): JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue, numPartitions)(seqFunc, combFunc))
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's.
+ * The former operation is used for merging values within a partition, and the latter is used for
+ * merging values between partitions. To avoid memory allocation, both of these functions are
+ * allowed to modify and return their first argument instead of creating a new U.
+ */
+ def aggregateByKey[U](zeroValue: U, seqFunc: JFunction2[U, V, U], combFunc: JFunction2[U, U, U]):
+ JavaPairRDD[K, U] = {
+ implicit val ctag: ClassTag[U] = fakeClassTag
+ fromRDD(rdd.aggregateByKey(zeroValue)(seqFunc, combFunc))
+ }
+
/**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index aeb159adc31d9..c371dc3a51c73 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -81,7 +81,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) extends
case "kill" =>
val driverId = driverArgs.driverId
- val killFuture = masterActor ! RequestKillDriver(driverId)
+ masterActor ! RequestKillDriver(driverId)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index d27e0e1f15c65..d09136de49807 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -23,9 +23,10 @@ import akka.actor.ActorRef
import com.google.common.base.Charsets
import com.google.common.io.Files
-import org.apache.spark.Logging
+import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.util.logging.FileAppender
/**
* Manages the execution of one executor process.
@@ -42,12 +43,15 @@ private[spark] class ExecutorRunner(
val sparkHome: File,
val workDir: File,
val workerUrl: String,
+ val conf: SparkConf,
var state: ExecutorState.Value)
extends Logging {
val fullId = appId + "/" + execId
var workerThread: Thread = null
var process: Process = null
+ var stdoutAppender: FileAppender = null
+ var stderrAppender: FileAppender = null
// NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might
// make sense to remove this in the future.
@@ -76,6 +80,13 @@ private[spark] class ExecutorRunner(
if (process != null) {
logInfo("Killing process!")
process.destroy()
+ process.waitFor()
+ if (stdoutAppender != null) {
+ stdoutAppender.stop()
+ }
+ if (stderrAppender != null) {
+ stderrAppender.stop()
+ }
val exitCode = process.waitFor()
worker ! ExecutorStateChanged(appId, execId, state, message, Some(exitCode))
}
@@ -137,11 +148,11 @@ private[spark] class ExecutorRunner(
// Redirect its stdout and stderr to files
val stdout = new File(executorDir, "stdout")
- CommandUtils.redirectStream(process.getInputStream, stdout)
+ stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
Files.write(header, stderr, Charsets.UTF_8)
- CommandUtils.redirectStream(process.getErrorStream, stderr)
+ stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
// Wait for it to exit; this is actually a bad thing if it happens, because we expect to run
// long-lived processes only. However, in the future, we might restart the executor a few
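
ExecutorRunner now pipes the child process's stdout and stderr through FileAppender instead of a raw stream redirect, and stops the appenders only after the process has exited. A rough sketch of that lifecycle; FileAppender is an internal utility, so the sketch assumes code living under an org.apache.spark subpackage, and the object and method names are hypothetical:

    package org.apache.spark.demo  // only because the appender utilities are internal to Spark

    import java.io.File
    import org.apache.spark.SparkConf
    import org.apache.spark.util.logging.FileAppender

    object AppenderLifecycle {
      /** Start appenders for a child process, mirroring the launch path in ExecutorRunner. */
      def pipeToFiles(proc: Process, dir: File, conf: SparkConf): (FileAppender, FileAppender) = {
        val stdout = FileAppender(proc.getInputStream, new File(dir, "stdout"), conf)
        val stderr = FileAppender(proc.getErrorStream, new File(dir, "stderr"), conf)
        (stdout, stderr)
      }

      /** Tear-down order used when the executor is killed: destroy, wait, then stop appenders. */
      def shutdown(proc: Process, stdout: FileAppender, stderr: FileAppender): Unit = {
        proc.destroy()
        proc.waitFor()
        stdout.stop()   // flushes any bytes still buffered for the rolled log files
        stderr.stop()
      }
    }
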
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 100de26170a50..a0ecaf709f8e2 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -235,7 +235,7 @@ private[spark] class Worker(
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
self, workerId, host,
appDesc.sparkHome.map(userSparkHome => new File(userSparkHome)).getOrElse(sparkHome),
- workDir, akkaUrl, ExecutorState.RUNNING)
+ workDir, akkaUrl, conf, ExecutorState.RUNNING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
index 8381f59672ea3..6a5ffb1b71bfb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala
@@ -24,8 +24,10 @@ import scala.xml.Node
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
+import org.apache.spark.Logging
+import org.apache.spark.util.logging.{FileAppender, RollingFileAppender}
-private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") {
+private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging {
private val worker = parent.worker
private val workDir = parent.workDir
@@ -39,21 +41,18 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") {
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
- val path = (appId, executorId, driverId) match {
+ val logDir = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
- s"${workDir.getPath}/$appId/$executorId/$logType"
+ s"${workDir.getPath}/$appId/$executorId/"
case (None, None, Some(d)) =>
- s"${workDir.getPath}/$driverId/$logType"
+ s"${workDir.getPath}/$driverId/"
case _ =>
throw new Exception("Request must specify either application or driver identifiers")
}
- val (startByte, endByte) = getByteRange(path, offset, byteLength)
- val file = new File(path)
- val logLength = file.length
-
- val pre = s"==== Bytes $startByte-$endByte of $logLength of $path ====\n"
- pre + Utils.offsetBytes(path, startByte, endByte)
+ val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength)
+ val pre = s"==== Bytes $startByte-$endByte of $logLength of $logDir$logType ====\n"
+ pre + logText
}
def render(request: HttpServletRequest): Seq[Node] = {
@@ -65,19 +64,16 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") {
val offset = Option(request.getParameter("offset")).map(_.toLong)
val byteLength = Option(request.getParameter("byteLength")).map(_.toInt).getOrElse(defaultBytes)
- val (path, params) = (appId, executorId, driverId) match {
+ val (logDir, params) = (appId, executorId, driverId) match {
case (Some(a), Some(e), None) =>
- (s"${workDir.getPath}/$a/$e/$logType", s"appId=$a&executorId=$e")
+ (s"${workDir.getPath}/$a/$e/", s"appId=$a&executorId=$e")
case (None, None, Some(d)) =>
- (s"${workDir.getPath}/$d/$logType", s"driverId=$d")
+ (s"${workDir.getPath}/$d/", s"driverId=$d")
case _ =>
throw new Exception("Request must specify either application or driver identifiers")
}
- val (startByte, endByte) = getByteRange(path, offset, byteLength)
- val file = new File(path)
- val logLength = file.length
- val logText = {Utils.offsetBytes(path, startByte, endByte)}
+ val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength)
val linkToMaster = <p><a href={worker.activeMasterWebUiUrl}>Back to Master</a></p>
val range = <span>Bytes {startByte.toString} - {endByte.toString} of {logLength}</span>
@@ -127,23 +123,37 @@ private[spark] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") {
UIUtils.basicSparkPage(content, logType + " log page for " + appId)
}
- /** Determine the byte range for a log or log page. */
- private def getByteRange(path: String, offset: Option[Long], byteLength: Int): (Long, Long) = {
- val defaultBytes = 100 * 1024
- val maxBytes = 1024 * 1024
- val file = new File(path)
- val logLength = file.length()
- val getOffset = offset.getOrElse(logLength - defaultBytes)
- val startByte =
- if (getOffset < 0) {
- 0L
- } else if (getOffset > logLength) {
- logLength
- } else {
- getOffset
+ /** Get the part of the log files given the offset and desired length of bytes */
+ private def getLog(
+ logDirectory: String,
+ logType: String,
+ offsetOption: Option[Long],
+ byteLength: Int
+ ): (String, Long, Long, Long) = {
+ try {
+ val files = RollingFileAppender.getSortedRolledOverFiles(logDirectory, logType)
+ logDebug(s"Sorted log files of type $logType in $logDirectory:\n${files.mkString("\n")}")
+
+ val totalLength = files.map { _.length }.sum
+ val offset = offsetOption.getOrElse(totalLength - byteLength)
+ val startIndex = {
+ if (offset < 0) {
+ 0L
+ } else if (offset > totalLength) {
+ totalLength
+ } else {
+ offset
+ }
}
- val logPageLength = math.min(byteLength, maxBytes)
- val endByte = math.min(startByte + logPageLength, logLength)
- (startByte, endByte)
+ val endIndex = math.min(startIndex + byteLength, totalLength)
+ logDebug(s"Getting log from $startIndex to $endIndex")
+ val logText = Utils.offsetBytes(files, startIndex, endIndex)
+ logDebug(s"Got log of length ${logText.length} bytes")
+ (logText, startIndex, endIndex, totalLength)
+ } catch {
+ case e: Exception =>
+ logError(s"Error getting $logType logs from directory $logDirectory", e)
+ ("Error getting logs due to exception: " + e.getMessage, 0, 0, 0)
+ }
}
}
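
For a quick feel of the byte-range clamping in getLog above, here is the same arithmetic extracted into a pure function (assuming the end index is capped at startIndex + byteLength, the "desired length of bytes"), with two sample evaluations:

    def byteRange(totalLength: Long, offsetOption: Option[Long], byteLength: Int): (Long, Long) = {
      val offset = offsetOption.getOrElse(totalLength - byteLength)
      val startIndex =
        if (offset < 0) 0L
        else if (offset > totalLength) totalLength
        else offset
      val endIndex = math.min(startIndex + byteLength, totalLength)
      (startIndex, endIndex)
    }

    byteRange(10000L, None, 4096)       // (5904, 10000): the default view shows the tail of the logs
    byteRange(10000L, Some(-5L), 4096)  // (0, 4096): negative offsets are clamped to the start
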
diff --git a/core/src/main/scala/org/apache/spark/network/Connection.scala b/core/src/main/scala/org/apache/spark/network/Connection.scala
index 3ffaaab23d0f5..3b6298a26d7c5 100644
--- a/core/src/main/scala/org/apache/spark/network/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/Connection.scala
@@ -210,7 +210,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
var nextMessageToBeUsed = 0
def addMessage(message: Message) {
- messages.synchronized{
+ messages.synchronized {
/* messages += message */
messages.enqueue(message)
logDebug("Added [" + message + "] to outbox for sending to " +
@@ -223,7 +223,7 @@ class SendingConnection(val address: InetSocketAddress, selector_ : Selector,
while (!messages.isEmpty) {
/* nextMessageToBeUsed = nextMessageToBeUsed % messages.size */
/* val message = messages(nextMessageToBeUsed) */
- val message = messages.dequeue
+ val message = messages.dequeue()
val chunk = message.getChunkForSending(defaultChunkSize)
if (chunk.isDefined) {
messages.enqueue(message)
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 5dd5fd0047c0d..cf1c985c2fff9 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -250,14 +250,14 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
try {
while(!selectorThread.isInterrupted) {
while (! registerRequests.isEmpty) {
- val conn: SendingConnection = registerRequests.dequeue
+ val conn: SendingConnection = registerRequests.dequeue()
addListeners(conn)
conn.connect()
addConnection(conn)
}
while(!keyInterestChangeRequests.isEmpty) {
- val (key, ops) = keyInterestChangeRequests.dequeue
+ val (key, ops) = keyInterestChangeRequests.dequeue()
try {
if (key.isValid) {
@@ -532,9 +532,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
}
return
}
- var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId.toString())
- var message = securityMsgResp.toBufferMessage
+ val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
sendSecurityMessage(waitingConn.getRemoteConnectionManagerId(), message)
} catch {
@@ -568,9 +568,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
logDebug("Server sasl not completed: " + connection.connectionId)
}
if (replyToken != null) {
- var securityMsgResp = SecurityMessage.fromResponse(replyToken,
+ val securityMsgResp = SecurityMessage.fromResponse(replyToken,
securityMsg.getConnectionId)
- var message = securityMsgResp.toBufferMessage
+ val message = securityMsgResp.toBufferMessage
if (message == null) throw new Exception("Error creating security Message")
sendSecurityMessage(connection.getRemoteConnectionManagerId(), message)
}
@@ -618,7 +618,7 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
return true
}
}
- return false
+ false
}
private def handleMessage(
@@ -694,9 +694,9 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf,
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
- var securityMsg = SecurityMessage.fromResponse(firstResponse,
+ val securityMsg = SecurityMessage.fromResponse(firstResponse,
conn.connectionId.toString())
- var message = securityMsg.toBufferMessage
+ val message = securityMsg.toBufferMessage
if (message == null) throw new Exception("Error creating security message")
connectionsAwaitingSasl += ((conn.connectionId, conn))
sendSecurityMessage(connManagerId, message)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 9ff76892aed32..5951865e56c9d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap}
import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.ShuffleHandle
private[spark] sealed trait CoGroupSplitDep extends Serializable
@@ -44,7 +45,7 @@ private[spark] case class NarrowCoGroupSplitDep(
}
}
-private[spark] case class ShuffleCoGroupSplitDep(shuffleId: Int) extends CoGroupSplitDep
+private[spark] case class ShuffleCoGroupSplitDep(handle: ShuffleHandle) extends CoGroupSplitDep
private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep])
extends Partition with Serializable {
@@ -74,10 +75,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Seq[CoGroup]
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -88,7 +90,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[Any, Any](rdd, part, serializer)
+ new ShuffleDependency[K, Any, CoGroupCombiner](rdd, part, serializer)
}
}
}
@@ -100,8 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
array(i) = new CoGroupPartition(i, rdds.zipWithIndex.map { case (rdd, j) =>
// Assume each RDD contributed a single dependency, and get it
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -126,11 +128,11 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
val it = rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, Any]]]
rddIterators += ((it, depNum))
- case ShuffleCoGroupSplitDep(shuffleId) =>
+ case ShuffleCoGroupSplitDep(handle) =>
// Read map outputs of shuffle
- val fetcher = SparkEnv.get.shuffleFetcher
- val ser = Serializer.getSerializer(serializer)
- val it = fetcher.fetch[Product2[K, Any]](shuffleId, split.index, context, ser)
+ val it = SparkEnv.get.shuffleManager
+ .getReader(handle, split.index, split.index + 1, context)
+ .read()
rddIterators += ((it, depNum))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 8909980957058..b6ad9b6c3e168 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -118,6 +118,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
combineByKey(createCombiner, mergeValue, mergeCombiners, new HashPartitioner(numPartitions))
}
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ // Serialize the zero value to a byte array so that we can get a new clone of it on each key
+ val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroArray = new Array[Byte](zeroBuffer.limit)
+ zeroBuffer.get(zeroArray)
+
+ lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ def createZero() = cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
+
+ combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U, numPartitions: Int)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, new HashPartitioner(numPartitions))(seqOp, combOp)
+ }
+
+ /**
+ * Aggregate the values of each key, using given combine functions and a neutral "zero value".
+ * This function can return a different result type, U, than the type of the values in this RDD,
+ * V. Thus, we need one operation for merging a V into a U and one operation for merging two U's,
+ * as in scala.TraversableOnce. The former operation is used for merging values within a
+ * partition, and the latter is used for merging values between partitions. To avoid memory
+ * allocation, both of these functions are allowed to modify and return their first argument
+ * instead of creating a new U.
+ */
+ def aggregateByKey[U: ClassTag](zeroValue: U)(seqOp: (U, V) => U,
+ combOp: (U, U) => U): RDD[(K, U)] = {
+ aggregateByKey(zeroValue, defaultPartitioner(self))(seqOp, combOp)
+ }
+
/**
* Merge the values for each key using an associative function and a neutral "zero value" which
* may be added to the result an arbitrary number of times, and must not change the result
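
A short usage example for the new aggregateByKey, computing a per-key (sum, count) and then a mean; it assumes an in-scope SparkContext named sc and the usual SparkContext._ implicits:

    import org.apache.spark.SparkContext._

    val pairs = sc.parallelize(Seq(("a", 1), ("a", 3), ("b", 5)))

    // U = (Int, Int) differs from the value type V = Int: a zero value, then a
    // within-partition merge and an across-partition merge.
    val sumCount = pairs.aggregateByKey((0, 0))(
      (acc, v) => (acc._1 + v, acc._2 + 1),
      (a, b) => (a._1 + b._1, a._2 + b._2))

    val mean = sumCount.mapValues { case (sum, count) => sum.toDouble / count }
    // mean.collect() gives something like Array(("a", 2.0), ("b", 5.0)); ordering may vary
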
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index fb12738f499f3..446f369c9ea16 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1190,7 +1190,7 @@ abstract class RDD[T: ClassTag](
/** User code that created this RDD (e.g. `textFile`, `parallelize`). */
@transient private[spark] val creationSiteInfo = Utils.getCallSiteInfo
- private[spark] def getCreationSite: String = creationSiteInfo.toString
+ private[spark] def getCreationSite: String = Option(creationSiteInfo).getOrElse("").toString
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 802b0bdfb2d59..bb108ef163c56 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -42,10 +42,11 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
part: Partitioner)
extends RDD[P](prev.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, P] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -60,9 +61,10 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
}
override def compute(split: Partition, context: TaskContext): Iterator[P] = {
- val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
- val ser = Serializer.getSerializer(serializer)
- SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, ser)
+ val dep = dependencies.head.asInstanceOf[ShuffleDependency[K, V, V]]
+ SparkEnv.get.shuffleManager.getReader(dep.shuffleHandle, split.index, split.index + 1, context)
+ .read()
+ .asInstanceOf[Iterator[P]]
}
override def clearDependencies() {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 9a09c05bbc959..ed24ea22a661c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -54,10 +54,11 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializer: Serializer = null
+ private var serializer: Option[Serializer] = None
+ /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
- this.serializer = serializer
+ this.serializer = Option(serializer)
this
}
@@ -79,8 +80,8 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
// Each CoGroupPartition will depend on rdd1 and rdd2
array(i) = new CoGroupPartition(i, Seq(rdd1, rdd2).zipWithIndex.map { case (rdd, j) =>
dependencies(j) match {
- case s: ShuffleDependency[_, _] =>
- new ShuffleCoGroupSplitDep(s.shuffleId)
+ case s: ShuffleDependency[_, _, _] =>
+ new ShuffleCoGroupSplitDep(s.shuffleHandle)
case _ =>
new NarrowCoGroupSplitDep(rdd, i, rdd.partitions(i))
}
@@ -93,7 +94,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = {
val partition = p.asInstanceOf[CoGroupPartition]
- val ser = Serializer.getSerializer(serializer)
val map = new JHashMap[K, ArrayBuffer[V]]
def getSeq(k: K): ArrayBuffer[V] = {
val seq = map.get(k)
@@ -109,9 +109,10 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
case NarrowCoGroupSplitDep(rdd, _, itsSplit) =>
rdd.iterator(itsSplit, context).asInstanceOf[Iterator[Product2[K, V]]].foreach(op)
- case ShuffleCoGroupSplitDep(shuffleId) =>
- val iter = SparkEnv.get.shuffleFetcher.fetch[Product2[K, V]](shuffleId, partition.index,
- context, ser)
+ case ShuffleCoGroupSplitDep(handle) =>
+ val iter = SparkEnv.get.shuffleManager
+ .getReader(handle, partition.index, partition.index + 1, context)
+ .read()
iter.foreach(op)
}
// the first dep is rdd1; add all values to the map
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index e09a4221e8315..3c85b5a2ae776 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -190,7 +190,7 @@ class DAGScheduler(
* The jobId value passed in will be used if the stage doesn't already exist with
* a lower jobId (jobId always increases across jobs.)
*/
- private def getShuffleMapStage(shuffleDep: ShuffleDependency[_,_], jobId: Int): Stage = {
+ private def getShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): Stage = {
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
@@ -210,7 +210,7 @@ class DAGScheduler(
private def newStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: Option[ShuffleDependency[_,_]],
+ shuffleDep: Option[ShuffleDependency[_, _, _]],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -233,7 +233,7 @@ class DAGScheduler(
private def newOrUsedStage(
rdd: RDD[_],
numTasks: Int,
- shuffleDep: ShuffleDependency[_,_],
+ shuffleDep: ShuffleDependency[_, _, _],
jobId: Int,
callSite: Option[String] = None)
: Stage =
@@ -269,7 +269,7 @@ class DAGScheduler(
// we can't do it in its constructor because # of partitions is unknown
for (dep <- r.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
parents += getShuffleMapStage(shufDep, jobId)
case _ =>
visit(dep.rdd)
@@ -290,7 +290,7 @@ class DAGScheduler(
if (getCacheLocs(rdd).contains(Nil)) {
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
missing += mapStage
@@ -1088,7 +1088,7 @@ class DAGScheduler(
visitedRdds += rdd
for (dep <- rdd.dependencies) {
dep match {
- case shufDep: ShuffleDependency[_,_] =>
+ case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getShuffleMapStage(shufDep, stage.jobId)
if (!mapStage.isAvailable) {
visitedStages += mapStage
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index ed0f56f1abdf5..0098b5a59d1a5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -29,6 +29,7 @@ import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.rdd.{RDD, RDDCheckpointData}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
+import org.apache.spark.shuffle.ShuffleWriter
private[spark] object ShuffleMapTask {
@@ -37,7 +38,7 @@ private[spark] object ShuffleMapTask {
// expensive on the master node if it needs to launch thousands of tasks.
private val serializedInfoCache = new HashMap[Int, Array[Byte]]
- def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = {
+ def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
synchronized {
val old = serializedInfoCache.get(stageId).orNull
if (old != null) {
@@ -56,12 +57,12 @@ private[spark] object ShuffleMapTask {
}
}
- def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_,_]) = {
+ def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
val ser = SparkEnv.get.closureSerializer.newInstance()
val objIn = ser.deserializeStream(in)
val rdd = objIn.readObject().asInstanceOf[RDD[_]]
- val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_,_]]
+ val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
(rdd, dep)
}
@@ -99,7 +100,7 @@ private[spark] object ShuffleMapTask {
private[spark] class ShuffleMapTask(
stageId: Int,
var rdd: RDD[_],
- var dep: ShuffleDependency[_,_],
+ var dep: ShuffleDependency[_, _, _],
_partitionId: Int,
@transient private var locs: Seq[TaskLocation])
extends Task[MapStatus](stageId, _partitionId)
@@ -141,66 +142,22 @@ private[spark] class ShuffleMapTask(
}
override def runTask(context: TaskContext): MapStatus = {
- val numOutputSplits = dep.partitioner.numPartitions
metrics = Some(context.taskMetrics)
-
- val blockManager = SparkEnv.get.blockManager
- val shuffleBlockManager = blockManager.shuffleBlockManager
- var shuffle: ShuffleWriterGroup = null
- var success = false
-
+ var writer: ShuffleWriter[Any, Any] = null
try {
- // Obtain all the block writers for shuffle blocks.
- val ser = Serializer.getSerializer(dep.serializer)
- shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser)
-
- // Write the map output to its associated buckets.
+ val manager = SparkEnv.get.shuffleManager
+ writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
for (elem <- rdd.iterator(split, context)) {
- val pair = elem.asInstanceOf[Product2[Any, Any]]
- val bucketId = dep.partitioner.getPartition(pair._1)
- shuffle.writers(bucketId).write(pair)
- }
-
- // Commit the writes. Get the size of each bucket block (total block size).
- var totalBytes = 0L
- var totalTime = 0L
- val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
- writer.commit()
- writer.close()
- val size = writer.fileSegment().length
- totalBytes += size
- totalTime += writer.timeWriting()
- MapOutputTracker.compressSize(size)
+ writer.write(elem.asInstanceOf[Product2[Any, Any]])
}
-
- // Update shuffle metrics.
- val shuffleMetrics = new ShuffleWriteMetrics
- shuffleMetrics.shuffleBytesWritten = totalBytes
- shuffleMetrics.shuffleWriteTime = totalTime
- metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
-
- success = true
- new MapStatus(blockManager.blockManagerId, compressedSizes)
- } catch { case e: Exception =>
- // If there is an exception from running the task, revert the partial writes
- // and throw the exception upstream to Spark.
- if (shuffle != null && shuffle.writers != null) {
- for (writer <- shuffle.writers) {
- writer.revertPartialWrites()
- writer.close()
+ return writer.stop(success = true).get
+ } catch {
+ case e: Exception =>
+ if (writer != null) {
+ writer.stop(success = false)
}
- }
- throw e
+ throw e
} finally {
- // Release the writers back to the shuffle block manager.
- if (shuffle != null && shuffle.writers != null) {
- try {
- shuffle.releaseWriters(success)
- } catch {
- case e: Exception => logError("Failed to release shuffle writers", e)
- }
- }
- // Execute the callbacks on task completion.
context.executeOnCompleteCallbacks()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 5c1fc30e4a557..3bf9713f728c6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -40,7 +40,7 @@ private[spark] class Stage(
val id: Int,
val rdd: RDD[_],
val numTasks: Int,
- val shuffleDep: Option[ShuffleDependency[_,_]], // Output shuffle if stage is a map stage
+ val shuffleDep: Option[ShuffleDependency[_, _, _]], // Output shuffle if stage is a map stage
val parents: List[Stage],
val jobId: Int,
callSite: Option[String])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 99d305b36a959..df59f444b7a0e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -71,7 +71,8 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
case ex: Exception =>
- taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+ logError("Exception while getting task result", ex)
+ taskSetManager.abort("Exception while getting task result: %s".format(ex))
}
}
})
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index ee26970a3d874..f2f5cea469c61 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -52,6 +52,10 @@ object Serializer {
def getSerializer(serializer: Serializer): Serializer = {
if (serializer == null) SparkEnv.get.serializer else serializer
}
+
+ def getSerializer(serializer: Option[Serializer]): Serializer = {
+ serializer.getOrElse(SparkEnv.get.serializer)
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
similarity index 66%
rename from core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
rename to core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
index a4f69b6b22b2c..b36c457d6d514 100644
--- a/core/src/main/scala/org/apache/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala
@@ -15,22 +15,16 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle
+import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner}
import org.apache.spark.serializer.Serializer
-private[spark] abstract class ShuffleFetcher {
-
- /**
- * Fetch the shuffle outputs for a given ShuffleDependency.
- * @return An iterator over the elements of the fetched shuffle outputs.
- */
- def fetch[T](
- shuffleId: Int,
- reduceId: Int,
- context: TaskContext,
- serializer: Serializer = SparkEnv.get.serializer): Iterator[T]
-
- /** Stop the fetcher */
- def stop() {}
-}
+/**
+ * A basic ShuffleHandle implementation that just captures registerShuffle's parameters.
+ */
+private[spark] class BaseShuffleHandle[K, V, C](
+ shuffleId: Int,
+ val numMaps: Int,
+ val dependency: ShuffleDependency[K, V, C])
+ extends ShuffleHandle(shuffleId)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
new file mode 100644
index 0000000000000..13c7115f88afa
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleHandle.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * An opaque handle to a shuffle, used by a ShuffleManager to pass information about it to tasks.
+ *
+ * @param shuffleId ID of the shuffle
+ */
+private[spark] abstract class ShuffleHandle(val shuffleId: Int) extends Serializable {}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
new file mode 100644
index 0000000000000..9c859b8b4a118
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.{TaskContext, ShuffleDependency}
+
+/**
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
+ * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * with it, and executors (or tasks running locally in the driver) can ask to read and write data.
+ *
+ * NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
+ * boolean isDriver as parameters.
+ */
+private[spark] trait ShuffleManager {
+ /**
+ * Register a shuffle with the manager and obtain a handle for it to pass to tasks.
+ */
+ def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C]
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ def unregisterShuffle(shuffleId: Int)
+
+ /** Shut down this ShuffleManager. */
+ def stop(): Unit
+}
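
To make the contract concrete, here is a hedged sketch of a manager that simply forwards every call to the built-in hash-based one. Because the shuffle traits are internal to Spark, the sketch assumes it is compiled inside an org.apache.spark subpackage; the class and package names are hypothetical, and it assumes HashShuffleManager exposes a single-SparkConf constructor as suggested by the SparkEnv default above:

    package org.apache.spark.shuffle.demo

    import org.apache.spark.{ShuffleDependency, SparkConf, TaskContext}
    import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager, ShuffleReader, ShuffleWriter}
    import org.apache.spark.shuffle.hash.HashShuffleManager

    /** Forwards every call to HashShuffleManager; useful only as a skeleton to start from. */
    class ForwardingShuffleManager(conf: SparkConf, isDriver: Boolean) extends ShuffleManager {

      // The (SparkConf, Boolean) constructor is the shape SparkEnv.instantiateClass tries first.
      private val delegate = new HashShuffleManager(conf)

      override def registerShuffle[K, V, C](
          shuffleId: Int,
          numMaps: Int,
          dependency: ShuffleDependency[K, V, C]): ShuffleHandle =
        delegate.registerShuffle[K, V, C](shuffleId, numMaps, dependency)

      override def getWriter[K, V](
          handle: ShuffleHandle,
          mapId: Int,
          context: TaskContext): ShuffleWriter[K, V] =
        delegate.getWriter[K, V](handle, mapId, context)

      override def getReader[K, C](
          handle: ShuffleHandle,
          startPartition: Int,
          endPartition: Int,
          context: TaskContext): ShuffleReader[K, C] =
        delegate.getReader[K, C](handle, startPartition, endPartition, context)

      override def unregisterShuffle(shuffleId: Int): Unit = delegate.unregisterShuffle(shuffleId)

      override def stop(): Unit = delegate.stop()
    }

    // Such a class would be selected through configuration, e.g.
    //   spark.shuffle.manager=org.apache.spark.shuffle.demo.ForwardingShuffleManager
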
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
new file mode 100644
index 0000000000000..b30e366d06006
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+/**
+ * Obtained inside a reduce task to read combined records from the mappers.
+ */
+private[spark] trait ShuffleReader[K, C] {
+ /** Read the combined key-values for this reduce task */
+ def read(): Iterator[Product2[K, C]]
+
+ /** Close this reader */
+ def stop(): Unit
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
new file mode 100644
index 0000000000000..ead3ebd652ca5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleWriter.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import org.apache.spark.scheduler.MapStatus
+
+/**
+ * Obtained inside a map task to write out records to the shuffle system.
+ */
+private[spark] trait ShuffleWriter[K, V] {
+ /** Write a record to this task's output */
+ def write(record: Product2[K, V]): Unit
+
+ /** Close this writer, passing along whether the map completed */
+ def stop(success: Boolean): Option[MapStatus]
+}
diff --git a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
similarity index 96%
rename from core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
rename to core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index a67392441ed29..b05b6ea345df3 100644
--- a/core/src/main/scala/org/apache/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark
+package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
@@ -24,17 +24,16 @@ import org.apache.spark.executor.ShuffleReadMetrics
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
+import org.apache.spark._
-private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-
- override def fetch[T](
+private[hash] object BlockStoreShuffleFetcher extends Logging {
+ def fetch[T](
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
{
-
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
new file mode 100644
index 0000000000000..5b0940ecce29d
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark._
+import org.apache.spark.shuffle._
+
+/**
+ * A ShuffleManager that uses hashing and creates one output file per reduce partition on each
+ * mapper (possibly reusing these files across waves of tasks).
+ */
+class HashShuffleManager(conf: SparkConf) extends ShuffleManager {
+ /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. */
+ override def registerShuffle[K, V, C](
+ shuffleId: Int,
+ numMaps: Int,
+ dependency: ShuffleDependency[K, V, C]): ShuffleHandle = {
+ new BaseShuffleHandle(shuffleId, numMaps, dependency)
+ }
+
+ /**
+ * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
+ * Called on executors by reduce tasks.
+ */
+ override def getReader[K, C](
+ handle: ShuffleHandle,
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext): ShuffleReader[K, C] = {
+ new HashShuffleReader(
+ handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context)
+ }
+
+ /** Get a writer for a given partition. Called on executors by map tasks. */
+ override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext)
+ : ShuffleWriter[K, V] = {
+ new HashShuffleWriter(handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context)
+ }
+
+ /** Remove a shuffle's metadata from the ShuffleManager. */
+ override def unregisterShuffle(shuffleId: Int): Unit = {}
+
+ /** Shut down this ShuffleManager. */
+ override def stop(): Unit = {}
+}
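
A rough sketch of the read path a reduce task would follow against this manager; the helper and variable names are illustrative, not the actual task code.

    // Sketch only: fetch and iterate one reduce partition's combined records.
    import org.apache.spark.TaskContext
    import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager}

    def readOnePartition[K, C](
        manager: ShuffleManager,
        handle: ShuffleHandle,
        reduceId: Int,
        context: TaskContext): Iterator[Product2[K, C]] = {
      // HashShuffleReader requires endPartition == startPartition + 1,
      // i.e. exactly one reduce partition per call.
      val reader = manager.getReader[K, C](handle, reduceId, reduceId + 1, context)
      reader.read()
    }
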
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
new file mode 100644
index 0000000000000..f6a790309a587
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.TaskContext
+
+class HashShuffleReader[K, C](
+ handle: BaseShuffleHandle[K, _, C],
+ startPartition: Int,
+ endPartition: Int,
+ context: TaskContext)
+ extends ShuffleReader[K, C]
+{
+ require(endPartition == startPartition + 1,
+ "Hash shuffle currently only supports fetching one partition")
+
+ /** Read the combined key-values for this reduce task */
+ override def read(): Iterator[Product2[K, C]] = {
+ BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context,
+ Serializer.getSerializer(handle.dependency.serializer))
+ }
+
+ /** Close this reader. Not yet implemented; calling this throws NotImplementedError. */
+ override def stop(): Unit = ???
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
new file mode 100644
index 0000000000000..4c6749098c110
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleWriter}
+import org.apache.spark.{Logging, MapOutputTracker, SparkEnv, TaskContext}
+import org.apache.spark.storage.BlockObjectWriter
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.scheduler.MapStatus
+
+class HashShuffleWriter[K, V](
+ handle: BaseShuffleHandle[K, V, _],
+ mapId: Int,
+ context: TaskContext)
+ extends ShuffleWriter[K, V] with Logging {
+
+ private val dep = handle.dependency
+ private val numOutputSplits = dep.partitioner.numPartitions
+ private val metrics = context.taskMetrics
+ private var stopping = false
+
+ private val blockManager = SparkEnv.get.blockManager
+ private val shuffleBlockManager = blockManager.shuffleBlockManager
+ private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
+ private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+
+ /** Write a record to this task's output */
+ override def write(record: Product2[K, V]): Unit = {
+ val pair = record.asInstanceOf[Product2[Any, Any]]
+ val bucketId = dep.partitioner.getPartition(pair._1)
+ shuffle.writers(bucketId).write(pair)
+ }
+
+ /** Close this writer, passing along whether the map completed */
+ override def stop(success: Boolean): Option[MapStatus] = {
+ try {
+ if (stopping) {
+ return None
+ }
+ stopping = true
+ if (success) {
+ try {
+ return Some(commitWritesAndBuildStatus())
+ } catch {
+ case e: Exception =>
+ revertWrites()
+ throw e
+ }
+ } else {
+ revertWrites()
+ return None
+ }
+ } finally {
+ // Release the writers back to the shuffle block manager.
+ if (shuffle != null && shuffle.writers != null) {
+ try {
+ shuffle.releaseWriters(success)
+ } catch {
+ case e: Exception => logError("Failed to release shuffle writers", e)
+ }
+ }
+ }
+ }
+
+ private def commitWritesAndBuildStatus(): MapStatus = {
+ // Commit the writes. Get the size of each bucket block (total block size).
+ var totalBytes = 0L
+ var totalTime = 0L
+ val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
+ writer.commit()
+ writer.close()
+ val size = writer.fileSegment().length
+ totalBytes += size
+ totalTime += writer.timeWriting()
+ MapOutputTracker.compressSize(size)
+ }
+
+ // Update shuffle metrics.
+ val shuffleMetrics = new ShuffleWriteMetrics
+ shuffleMetrics.shuffleBytesWritten = totalBytes
+ shuffleMetrics.shuffleWriteTime = totalTime
+ metrics.shuffleWriteMetrics = Some(shuffleMetrics)
+
+ new MapStatus(blockManager.blockManagerId, compressedSizes)
+ }
+
+ private def revertWrites(): Unit = {
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
+ }
+ }
+}
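
The corresponding write path for a map task, as a sketch only (not the actual ShuffleMapTask code); the helper name is hypothetical.

    // Sketch only: write a map task's records and commit them on success.
    import org.apache.spark.TaskContext
    import org.apache.spark.scheduler.MapStatus
    import org.apache.spark.shuffle.{ShuffleHandle, ShuffleManager}

    def writeMapOutput[K, V](
        manager: ShuffleManager,
        handle: ShuffleHandle,
        mapId: Int,
        context: TaskContext,
        records: Iterator[Product2[K, V]]): MapStatus = {
      val writer = manager.getWriter[K, V](handle, mapId, context)
      try {
        records.foreach(writer.write)
        // stop(success = true) commits the per-bucket files and returns the MapStatus.
        writer.stop(success = true).get
      } catch {
        case e: Exception =>
          writer.stop(success = false)
          throw e
      }
    }
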
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 3b1b6df089b8e..4ce28bb0cf059 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -862,6 +862,59 @@ private[spark] object Utils extends Logging {
Source.fromBytes(buff).mkString
}
+ /**
+ * Return a string containing data across a set of files. The `start`
+ * and `end` offsets are based on the cumulative size of all the files taken in
+ * the given order. See the figure below for more details.
+ */
+ def offsetBytes(files: Seq[File], start: Long, end: Long): String = {
+ val fileLengths = files.map { _.length }
+ val startIndex = math.max(start, 0)
+ val endIndex = math.min(end, fileLengths.sum)
+ val fileToLength = files.zip(fileLengths).toMap
+ logDebug("Log files: \n" + fileToLength.mkString("\n"))
+
+ val stringBuffer = new StringBuffer((endIndex - startIndex).toInt)
+ var sum = 0L
+ for (file <- files) {
+ val startIndexOfFile = sum
+ val endIndexOfFile = sum + fileToLength(file)
+ logDebug(s"Processing file $file, " +
+ s"with start index = $startIndexOfFile, end index = $endIndex")
+
+ /*
+ ____________
+ range 1: | |
+ | case A |
+
+ files: |==== file 1 ====|====== file 2 ======|===== file 3 =====|
+
+ | case B . case C . case D |
+ range 2: |___________.____________________.______________|
+ */
+
+ if (startIndex <= startIndexOfFile && endIndex >= endIndexOfFile) {
+ // Case C: read the whole file
+ stringBuffer.append(offsetBytes(file.getAbsolutePath, 0, fileToLength(file)))
+ } else if (startIndex > startIndexOfFile && startIndex < endIndexOfFile) {
+ // Case A and B: read from [start of required range] to [end of file / end of range]
+ val effectiveStartIndex = startIndex - startIndexOfFile
+ val effectiveEndIndex = math.min(endIndex - startIndexOfFile, fileToLength(file))
+ stringBuffer.append(Utils.offsetBytes(
+ file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex))
+ } else if (endIndex > startIndexOfFile && endIndex < endIndexOfFile) {
+ // Case D: read from [start of file] to [end of required range]
+ val effectiveStartIndex = math.max(startIndex - startIndexOfFile, 0)
+ val effectiveEndIndex = endIndex - startIndexOfFile
+ stringBuffer.append(Utils.offsetBytes(
+ file.getAbsolutePath, effectiveStartIndex, effectiveEndIndex))
+ }
+ sum += fileToLength(file)
+ logDebug(s"After processing file $file, string built is ${stringBuffer.toString}}")
+ }
+ stringBuffer.toString
+ }
+
/**
* Clone an object using a Spark serializer.
*/
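
As a worked example of the case analysis in offsetBytes, a sketch in the style of the UtilsSuite test added below; the file contents are arbitrary.

    // Sketch: with three 10-byte files, offsetBytes(files, 5, 25) exercises
    // case B on the first file, case C on the second, and case D on the third.
    import java.io.File
    import com.google.common.base.Charsets
    import com.google.common.io.Files
    import org.apache.spark.util.Utils

    val dir = Files.createTempDir()
    val files = (1 to 3).map(i => new File(dir, i.toString))
    Files.write("0123456789", files(0), Charsets.UTF_8)
    Files.write("abcdefghij", files(1), Charsets.UTF_8)
    Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8)
    // Bytes 5-24 of the concatenation: tail of file 1, all of file 2, head of file 3.
    assert(Utils.offsetBytes(files, 5, 25) == "56789" + "abcdefghij" + "ABCDE")
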
diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
new file mode 100644
index 0000000000000..8e9c3036d09c2
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.logging
+
+import java.io.{File, FileOutputStream, InputStream}
+
+import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.util.{IntParam, Utils}
+
+/**
+ * Continuously appends the data from an input stream into the given file.
+ */
+private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSize: Int = 8192)
+ extends Logging {
+ @volatile private var outputStream: FileOutputStream = null
+ @volatile private var markedForStop = false // has the appender been asked to stop
+ @volatile private var stopped = false // has the appender stopped
+
+ // Thread that reads the input stream and writes to file
+ private val writingThread = new Thread("File appending thread for " + file) {
+ setDaemon(true)
+ override def run() {
+ Utils.logUncaughtExceptions {
+ appendStreamToFile()
+ }
+ }
+ }
+ writingThread.start()
+
+ /**
+ * Wait for the appender to stop appending, either because the input stream is closed
+ * or because of an error while appending
+ */
+ def awaitTermination() {
+ synchronized {
+ if (!stopped) {
+ wait()
+ }
+ }
+ }
+
+ /** Stop the appender */
+ def stop() {
+ markedForStop = true
+ }
+
+ /** Continuously read chunks from the input stream and append to the file */
+ protected def appendStreamToFile() {
+ try {
+ logDebug("Started appending thread")
+ openFile()
+ val buf = new Array[Byte](bufferSize)
+ var n = 0
+ while (!markedForStop && n != -1) {
+ n = inputStream.read(buf)
+ if (n != -1) {
+ appendToFile(buf, n)
+ }
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Error writing stream to file $file", e)
+ } finally {
+ closeFile()
+ synchronized {
+ stopped = true
+ notifyAll()
+ }
+ }
+ }
+
+ /** Append bytes to the file output stream */
+ protected def appendToFile(bytes: Array[Byte], len: Int) {
+ if (outputStream == null) {
+ openFile()
+ }
+ outputStream.write(bytes, 0, len)
+ }
+
+ /** Open the file output stream */
+ protected def openFile() {
+ outputStream = new FileOutputStream(file, false)
+ logDebug(s"Opened file $file")
+ }
+
+ /** Close the file output stream */
+ protected def closeFile() {
+ outputStream.flush()
+ outputStream.close()
+ logDebug(s"Closed file $file")
+ }
+}
+
+/**
+ * Companion object to [[org.apache.spark.util.logging.FileAppender]] which has helper
+ * functions to choose the correct type of FileAppender based on SparkConf configuration.
+ */
+private[spark] object FileAppender extends Logging {
+
+ /** Create the right appender based on Spark configuration */
+ def apply(inputStream: InputStream, file: File, conf: SparkConf): FileAppender = {
+
+ import RollingFileAppender._
+
+ val rollingStrategy = conf.get(STRATEGY_PROPERTY, STRATEGY_DEFAULT)
+ val rollingSizeBytes = conf.get(SIZE_PROPERTY, SIZE_DEFAULT)
+ val rollingInterval = conf.get(INTERVAL_PROPERTY, INTERVAL_DEFAULT)
+
+ def createTimeBasedAppender() = {
+ val validatedParams: Option[(Long, String)] = rollingInterval match {
+ case "daily" =>
+ logInfo(s"Rolling executor logs enabled for $file with daily rolling")
+ Some(24 * 60 * 60 * 1000L, "--YYYY-MM-dd")
+ case "hourly" =>
+ logInfo(s"Rolling executor logs enabled for $file with hourly rolling")
+ Some(60 * 60 * 1000L, "--YYYY-MM-dd--HH")
+ case "minutely" =>
+ logInfo(s"Rolling executor logs enabled for $file with rolling every minute")
+ Some(60 * 1000L, "--YYYY-MM-dd--HH-mm")
+ case IntParam(seconds) =>
+ logInfo(s"Rolling executor logs enabled for $file with rolling $seconds seconds")
+ Some(seconds * 1000L, "--YYYY-MM-dd--HH-mm-ss")
+ case _ =>
+ logWarning(s"Illegal interval for rolling executor logs [$rollingInterval], " +
+ s"rolling logs not enabled")
+ None
+ }
+ validatedParams.map {
+ case (interval, pattern) =>
+ new RollingFileAppender(
+ inputStream, file, new TimeBasedRollingPolicy(interval, pattern), conf)
+ }.getOrElse {
+ new FileAppender(inputStream, file)
+ }
+ }
+
+ def createSizeBasedAppender() = {
+ rollingSizeBytes match {
+ case IntParam(bytes) =>
+ logInfo(s"Rolling executor logs enabled for $file with rolling every $bytes bytes")
+ new RollingFileAppender(inputStream, file, new SizeBasedRollingPolicy(bytes), conf)
+ case _ =>
+ logWarning(
+ s"Illegal size [$rollingSizeBytes] for rolling executor logs, rolling logs not enabled")
+ new FileAppender(inputStream, file)
+ }
+ }
+
+ rollingStrategy match {
+ case "" =>
+ new FileAppender(inputStream, file)
+ case "time" =>
+ createTimeBasedAppender()
+ case "size" =>
+ createSizeBasedAppender()
+ case _ =>
+ logWarning(
+ s"Illegal strategy [$rollingStrategy] for rolling executor logs, " +
+ s"rolling logs not enabled")
+ new FileAppender(inputStream, file)
+ }
+ }
+}
+
+
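
Since FileAppender.apply is driven purely by configuration, a sketch of the SparkConf settings that select a size-based rolling appender may help. Property names are taken from the constants in RollingFileAppender's companion object; the stream and file are placeholders for an executor's stdout.

    // Sketch: configure size-based rolling and build an appender from the conf.
    import java.io.{File, PipedInputStream, PipedOutputStream}
    import org.apache.spark.SparkConf
    import org.apache.spark.util.logging.FileAppender

    val conf = new SparkConf()
      .set("spark.executor.logs.rolling.strategy", "size")         // STRATEGY_PROPERTY
      .set("spark.executor.logs.rolling.size.maxBytes", "1048576") // SIZE_PROPERTY, roll at ~1 MB
      .set("spark.executor.logs.rolling.maxRetainedFiles", "5")    // RETAINED_FILES_PROPERTY

    val stdoutStream = new PipedInputStream(new PipedOutputStream())
    // With the settings above, apply() returns a size-based RollingFileAppender.
    val appender = FileAppender(stdoutStream, new File("stdout"), conf)
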
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
new file mode 100644
index 0000000000000..1bbbd20cf076f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.logging
+
+import java.io.{File, FileFilter, InputStream}
+
+import org.apache.commons.io.FileUtils
+import org.apache.spark.SparkConf
+import RollingFileAppender._
+
+/**
+ * Continuously appends data from the input stream into the given file, and rolls
+ * the file over whenever the given rolling policy dictates. The rolled over files
+ * are named by appending a suffix generated by the rolling policy.
+ *
+ * @param inputStream Input stream to read data from
+ * @param activeFile File to write data to
+ * @param rollingPolicy Policy based on which files will be rolled over.
+ * @param conf SparkConf that is used to pass on extra configurations
+ * @param bufferSize Optional buffer size. Used mainly for testing.
+ */
+private[spark] class RollingFileAppender(
+ inputStream: InputStream,
+ activeFile: File,
+ val rollingPolicy: RollingPolicy,
+ conf: SparkConf,
+ bufferSize: Int = DEFAULT_BUFFER_SIZE
+ ) extends FileAppender(inputStream, activeFile, bufferSize) {
+
+ private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1)
+
+ /** Stop the appender */
+ override def stop() {
+ super.stop()
+ }
+
+ /** Append bytes to file after rolling over is necessary */
+ override protected def appendToFile(bytes: Array[Byte], len: Int) {
+ if (rollingPolicy.shouldRollover(len)) {
+ rollover()
+ rollingPolicy.rolledOver()
+ }
+ super.appendToFile(bytes, len)
+ rollingPolicy.bytesWritten(len)
+ }
+
+ /** Rollover the file, by closing the output stream and moving it over */
+ private def rollover() {
+ try {
+ closeFile()
+ moveFile()
+ openFile()
+ if (maxRetainedFiles > 0) {
+ deleteOldFiles()
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Error rolling over $activeFile", e)
+ }
+ }
+
+ /** Move the active log file to a new rollover file */
+ private def moveFile() {
+ val rolloverSuffix = rollingPolicy.generateRolledOverFileSuffix()
+ val rolloverFile = new File(
+ activeFile.getParentFile, activeFile.getName + rolloverSuffix).getAbsoluteFile
+ try {
+ logDebug(s"Attempting to rollover file $activeFile to file $rolloverFile")
+ if (activeFile.exists) {
+ if (!rolloverFile.exists) {
+ FileUtils.moveFile(activeFile, rolloverFile)
+ logInfo(s"Rolled over $activeFile to $rolloverFile")
+ } else {
+ // In case the rollover file name clashes, make a unique file name.
+ // The resultant file names are long and ugly, so this is used only
+ // if there is a name collision. This can be avoided by using
+ // the right pattern such that name collisions do not occur.
+ var i = 0
+ var altRolloverFile: File = null
+ do {
+ altRolloverFile = new File(activeFile.getParent,
+ s"${activeFile.getName}$rolloverSuffix--$i").getAbsoluteFile
+ i += 1
+ } while (i < 10000 && altRolloverFile.exists)
+
+ logWarning(s"Rollover file $rolloverFile already exists, " +
+ s"rolled over $activeFile to file $altRolloverFile")
+ FileUtils.moveFile(activeFile, altRolloverFile)
+ }
+ } else {
+ logWarning(s"File $activeFile does not exist")
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Error moving file $activeFile to $rolloverFile", e)
+ }
+ }
+
+ /** Retain only last few files */
+ private[util] def deleteOldFiles() {
+ try {
+ val rolledoverFiles = activeFile.getParentFile.listFiles(new FileFilter {
+ def accept(f: File): Boolean = {
+ f.getName.startsWith(activeFile.getName) && f != activeFile
+ }
+ }).sorted
+ val filesToBeDeleted = rolledoverFiles.take(
+ math.max(0, rolledoverFiles.size - maxRetainedFiles))
+ filesToBeDeleted.foreach { file =>
+ logInfo(s"Deleting file executor log file ${file.getAbsolutePath}")
+ file.delete()
+ }
+ } catch {
+ case e: Exception =>
+ logError("Error cleaning logs in directory " + activeFile.getParentFile.getAbsolutePath, e)
+ }
+ }
+}
+
+/**
+ * Companion object to [[org.apache.spark.util.logging.RollingFileAppender]]. Defines
+ * names of configurations that configure rolling file appenders.
+ */
+private[spark] object RollingFileAppender {
+ val STRATEGY_PROPERTY = "spark.executor.logs.rolling.strategy"
+ val STRATEGY_DEFAULT = ""
+ val INTERVAL_PROPERTY = "spark.executor.logs.rolling.time.interval"
+ val INTERVAL_DEFAULT = "daily"
+ val SIZE_PROPERTY = "spark.executor.logs.rolling.size.maxBytes"
+ val SIZE_DEFAULT = (1024 * 1024).toString
+ val RETAINED_FILES_PROPERTY = "spark.executor.logs.rolling.maxRetainedFiles"
+ val DEFAULT_BUFFER_SIZE = 8192
+
+ /**
+ * Get the sorted list of rolled over files. This assumes that all the rolled
+ * over file names are prefixed with the `activeFileName`, and that the active file
+ * contains the latest logs. It sorts all the rolled over logs (those prefixed
+ * with `activeFileName`) and appends the active file at the end.
+ */
+ def getSortedRolledOverFiles(directory: String, activeFileName: String): Seq[File] = {
+ val rolledOverFiles = new File(directory).getAbsoluteFile.listFiles.filter { file =>
+ val fileName = file.getName
+ fileName.startsWith(activeFileName) && fileName != activeFileName
+ }.sorted
+ val activeFile = {
+ val file = new File(directory, activeFileName).getAbsoluteFile
+ if (file.exists) Some(file) else None
+ }
+ rolledOverFiles ++ activeFile
+ }
+}
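
A sketch of how a caller such as a log viewer might use getSortedRolledOverFiles; the directory path is a placeholder and the file names shown are merely examples of the daily suffix pattern.

    // Sketch: list everything written for an executor's stdout, oldest first.
    import org.apache.spark.util.logging.RollingFileAppender

    val logFiles = RollingFileAppender.getSortedRolledOverFiles(
      "/tmp/spark-worker/app-0000/0", "stdout")
    // Rolled-over files (e.g. stdout--2014-06-09, stdout--2014-06-10) come first,
    // sorted by name, and the active "stdout" file with the latest data is last.
    logFiles.foreach(f => println(s"${f.getName}: ${f.length} bytes"))
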
diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
new file mode 100644
index 0000000000000..84e5c3c917dcb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.logging
+
+import java.text.SimpleDateFormat
+import java.util.Calendar
+
+import org.apache.spark.Logging
+
+/**
+ * Defines the policy based on which [[org.apache.spark.util.logging.RollingFileAppender]] will
+ * generate rolling files.
+ */
+private[spark] trait RollingPolicy {
+
+ /** Whether rollover should be initiated at this moment */
+ def shouldRollover(bytesToBeWritten: Long): Boolean
+
+ /** Notify that rollover has occurred */
+ def rolledOver()
+
+ /** Notify that bytes have been written */
+ def bytesWritten(bytes: Long)
+
+ /** Get the desired name of the rollover file */
+ def generateRolledOverFileSuffix(): String
+}
+
+/**
+ * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled
+ * over at a fixed interval.
+ */
+private[spark] class TimeBasedRollingPolicy(
+ var rolloverIntervalMillis: Long,
+ rollingFileSuffixPattern: String,
+ checkIntervalConstraint: Boolean = true // set to false while testing
+ ) extends RollingPolicy with Logging {
+
+ import TimeBasedRollingPolicy._
+ if (checkIntervalConstraint && rolloverIntervalMillis < MINIMUM_INTERVAL_SECONDS * 1000L) {
+ logWarning(s"Rolling interval [${rolloverIntervalMillis/1000L} seconds] is too small. " +
+ s"Setting the interval to the acceptable minimum of $MINIMUM_INTERVAL_SECONDS seconds.")
+ rolloverIntervalMillis = MINIMUM_INTERVAL_SECONDS * 1000L
+ }
+
+ @volatile private var nextRolloverTime = calculateNextRolloverTime()
+ private val formatter = new SimpleDateFormat(rollingFileSuffixPattern)
+
+ /** Should rollover if current time has exceeded next rollover time */
+ def shouldRollover(bytesToBeWritten: Long): Boolean = {
+ System.currentTimeMillis > nextRolloverTime
+ }
+
+ /** Rollover has occurred, so find the next time to rollover */
+ def rolledOver() {
+ nextRolloverTime = calculateNextRolloverTime()
+ logDebug(s"Current time: ${System.currentTimeMillis}, next rollover time: " + nextRolloverTime)
+ }
+
+ def bytesWritten(bytes: Long) { } // nothing to do
+
+ private def calculateNextRolloverTime(): Long = {
+ val now = System.currentTimeMillis()
+ val targetTime = (
+ math.ceil(now.toDouble / rolloverIntervalMillis) * rolloverIntervalMillis
+ ).toLong
+ logDebug(s"Next rollover time is $targetTime")
+ targetTime
+ }
+
+ def generateRolledOverFileSuffix(): String = {
+ formatter.format(Calendar.getInstance.getTime)
+ }
+}
+
+private[spark] object TimeBasedRollingPolicy {
+ val MINIMUM_INTERVAL_SECONDS = 60L // 1 minute
+}
+
+/**
+ * Defines a [[org.apache.spark.util.logging.RollingPolicy]] by which files will be rolled
+ * over after reaching a particular size.
+ */
+private[spark] class SizeBasedRollingPolicy(
+ var rolloverSizeBytes: Long,
+ checkSizeConstraint: Boolean = true // set to false while testing
+ ) extends RollingPolicy with Logging {
+
+ import SizeBasedRollingPolicy._
+ if (checkSizeConstraint && rolloverSizeBytes < MINIMUM_SIZE_BYTES) {
+ logWarning(s"Rolling size [$rolloverSizeBytes bytes] is too small. " +
+ s"Setting the size to the acceptable minimum of $MINIMUM_SIZE_BYTES bytes.")
+ rolloverSizeBytes = MINIMUM_SIZE_BYTES
+ }
+
+ @volatile private var bytesWrittenSinceRollover = 0L
+ val formatter = new SimpleDateFormat("--YYYY-MM-dd--HH-mm-ss--SSSS")
+
+ /** Should rollover if the next set of bytes is going to exceed the size limit */
+ def shouldRollover(bytesToBeWritten: Long): Boolean = {
+ logInfo(s"$bytesToBeWritten + $bytesWrittenSinceRollover > $rolloverSizeBytes")
+ bytesToBeWritten + bytesWrittenSinceRollover > rolloverSizeBytes
+ }
+
+ /** Rollover has occurred, so reset the counter */
+ def rolledOver() {
+ bytesWrittenSinceRollover = 0
+ }
+
+ /** Increment the bytes that have been written in the current file */
+ def bytesWritten(bytes: Long) {
+ bytesWrittenSinceRollover += bytes
+ }
+
+ /** Get the desired name of the rollover file */
+ def generateRolledOverFileSuffix(): String = {
+ formatter.format(Calendar.getInstance.getTime)
+ }
+}
+
+private[spark] object SizeBasedRollingPolicy {
+ val MINIMUM_SIZE_BYTES = RollingFileAppender.DEFAULT_BUFFER_SIZE * 10
+}
+
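
calculateNextRolloverTime in the time-based policy rounds the current time up to the next interval boundary; a small worked example follows. The real code uses milliseconds since the epoch; milliseconds since midnight are used here only to keep the numbers readable.

    // Worked example: hourly rolling (60 * 60 * 1000 ms). A "current time" of
    // 10:23:45.120 rounds up to the 11:00:00.000 boundary.
    val intervalMs = 60L * 60 * 1000
    val now = 37425120L  // 10h + 23m + 45.120s expressed in milliseconds
    val next = (math.ceil(now.toDouble / intervalMs) * intervalMs).toLong
    assert(next == 39600000L)  // 11 * 60 * 60 * 1000
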
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 50a62129116f1..ef41bfb88de9d 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -317,6 +317,37 @@ public Integer call(Integer a, Integer b) {
Assert.assertEquals(33, sum);
}
+ @Test
+ public void aggregateByKey() {
+ JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(
+ Arrays.asList(
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(3, 2),
+ new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(5, 3)), 2);
+
+ Map<Integer, Set<Integer>> sets = pairs.aggregateByKey(new HashSet<Integer>(),
+ new Function2<Set<Integer>, Integer, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Integer b) {
+ a.add(b);
+ return a;
+ }
+ },
+ new Function2<Set<Integer>, Set<Integer>, Set<Integer>>() {
+ @Override
+ public Set<Integer> call(Set<Integer> a, Set<Integer> b) {
+ a.addAll(b);
+ return a;
+ }
+ }).collectAsMap();
+ Assert.assertEquals(3, sets.size());
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1)), sets.get(1));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(2)), sets.get(3));
+ Assert.assertEquals(new HashSet<Integer>(Arrays.asList(1, 3)), sets.get(5));
+ }
+
@SuppressWarnings("unchecked")
@Test
public void foldByKey() {
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 4e7c34e6d1ada..3aab88e9e9196 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark
import scala.collection.mutable
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
-class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+class AccumulatorSuite extends FunSuite with Matchers with LocalSparkContext {
implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index dc2db66df60e0..13b415cccb647 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -201,7 +201,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
def newPairRDD = newRDD.map(_ -> 1)
def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
def newBroadcast = sc.broadcast(1 to 100)
- def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _]]) = {
+ def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
getAllDependencies(dep.rdd)
@@ -211,8 +211,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
// Get all the shuffle dependencies
val shuffleDeps = getAllDependencies(rdd)
- .filter(_.isInstanceOf[ShuffleDependency[_, _]])
- .map(_.asInstanceOf[ShuffleDependency[_, _]])
+ .filter(_.isInstanceOf[ShuffleDependency[_, _, _]])
+ .map(_.asInstanceOf[ShuffleDependency[_, _, _]])
(rdd, shuffleDeps)
}
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 14ddd6f1ec08f..41c294f727b3c 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import org.scalatest.BeforeAndAfter
import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.scalatest.time.{Millis, Span}
import org.apache.spark.SparkContext._
@@ -31,7 +31,7 @@ class NotSerializableClass
class NotSerializableExn(val notSer: NotSerializableClass) extends Throwable() {}
-class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
with LocalSparkContext {
val clusterUrl = "local-cluster[2,1,512]"
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index 2c8ef405c944c..a57430e829ced 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -25,7 +25,7 @@ import scala.concurrent.duration._
import scala.concurrent.future
import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
@@ -35,7 +35,7 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart}
* (e.g. count) as well as multi-job action (e.g. take). We test the local and cluster schedulers
* in both FIFO and fair scheduling modes.
*/
-class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
+class JobCancellationSuite extends FunSuite with Matchers with BeforeAndAfter
with LocalSparkContext {
override def afterEach() {
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index be6508a40ea61..47112ce66d695 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.SparkContext._
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
@@ -26,7 +26,7 @@ import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD
import org.apache.spark.serializer.KryoSerializer
import org.apache.spark.util.MutablePair
-class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+class ShuffleSuite extends FunSuite with Matchers with LocalSparkContext {
val conf = new SparkConf(loadDefaults = false)
@@ -58,7 +58,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// default Java serializer cannot handle the non serializable class.
val c = new ShuffledRDD[Int, NonJavaSerializableClass, (Int, NonJavaSerializableClass)](
b, new HashPartitioner(NUM_BLOCKS)).setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 10)
@@ -97,7 +97,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
.setSerializer(new KryoSerializer(conf))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
@@ -122,7 +122,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
// NOTE: The default Java serializer should create zero-sized blocks
val c = new ShuffledRDD[Int, Int, (Int, Int)](b, new HashPartitioner(10))
- val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId
+ val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleId
assert(c.count === 4)
val blockSizes = (0 until NUM_BLOCKS).flatMap { id =>
diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
index d6b93f5fedd3b..4161aede1d1d0 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
@@ -18,9 +18,9 @@
package org.apache.spark.deploy
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class ClientSuite extends FunSuite with ShouldMatchers {
+class ClientSuite extends FunSuite with Matchers {
test("correctly validates driver jar URL's") {
ClientArguments.isValidJarUrl("http://someHost:8080/foo.jar") should be (true)
ClientArguments.isValidJarUrl("file://some/path/to/a/jarFile.jar") should be (true)
diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
index bfae32dae0dc5..01ab2d549325c 100644
--- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala
@@ -28,6 +28,7 @@ import org.scalatest.FunSuite
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse}
import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, RecoveryState, WorkerInfo}
import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner}
+import org.apache.spark.SparkConf
class JsonProtocolSuite extends FunSuite {
@@ -116,7 +117,8 @@ class JsonProtocolSuite extends FunSuite {
}
def createExecutorRunner(): ExecutorRunner = {
new ExecutorRunner("appId", 123, createAppDesc(), 4, 1234, null, "workerId", "host",
- new File("sparkHome"), new File("workDir"), "akka://worker", ExecutorState.RUNNING)
+ new File("sparkHome"), new File("workDir"), "akka://worker",
+ new SparkConf, ExecutorState.RUNNING)
}
def createDriverRunner(): DriverRunner = {
new DriverRunner("driverId", new File("workDir"), new File("sparkHome"), createDriverDesc(),
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 02427a4a83506..565c53e9529ff 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -25,9 +25,9 @@ import org.apache.spark.{SparkConf, SparkContext, SparkEnv, SparkException, Test
import org.apache.spark.deploy.SparkSubmit._
import org.apache.spark.util.Utils
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class SparkSubmitSuite extends FunSuite with ShouldMatchers {
+class SparkSubmitSuite extends FunSuite with Matchers {
def beforeAll() {
System.setProperty("spark.testing", "true")
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 8ae387fa0be6f..e5f748d55500d 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -22,6 +22,7 @@ import java.io.File
import org.scalatest.FunSuite
import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
+import org.apache.spark.SparkConf
class ExecutorRunnerTest extends FunSuite {
test("command includes appId") {
@@ -32,7 +33,7 @@ class ExecutorRunnerTest extends FunSuite {
sparkHome, "appUiUrl")
val appId = "12345-worker321-9876"
val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321", f(sparkHome.getOrElse(".")),
- f("ooga"), "blah", ExecutorState.RUNNING)
+ f("ooga"), "blah", new SparkConf, ExecutorState.RUNNING)
assert(er.getCommandSeq.last === appId)
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
index 9ddafc451878d..0b9004448a63e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala
@@ -30,6 +30,19 @@ import org.apache.spark.SparkContext._
import org.apache.spark.{Partitioner, SharedSparkContext}
class PairRDDFunctionsSuite extends FunSuite with SharedSparkContext {
+ test("aggregateByKey") {
+ val pairs = sc.parallelize(Array((1, 1), (1, 1), (3, 2), (5, 1), (5, 3)), 2)
+
+ val sets = pairs.aggregateByKey(new HashSet[Int]())(_ += _, _ ++= _).collect()
+ assert(sets.size === 3)
+ val valuesFor1 = sets.find(_._1 == 1).get._2
+ assert(valuesFor1.toList.sorted === List(1))
+ val valuesFor3 = sets.find(_._1 == 3).get._2
+ assert(valuesFor3.toList.sorted === List(2))
+ val valuesFor5 = sets.find(_._1 == 5).get._2
+ assert(valuesFor5.toList.sorted === List(1, 3))
+ }
+
test("groupByKey") {
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)))
val groups = pairs.groupByKey().collect()
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2e70a59c2f53e..e94a1e76d410c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -24,6 +24,7 @@ import org.scalatest.FunSuite
import org.apache.spark._
import org.apache.spark.SparkContext._
+import org.apache.spark.util.Utils
class RDDSuite extends FunSuite with SharedSparkContext {
@@ -65,6 +66,13 @@ class RDDSuite extends FunSuite with SharedSparkContext {
}
}
+ test("serialization") {
+ val empty = new EmptyRDD[Int](sc)
+ val serial = Utils.serialize(empty)
+ val deserial: EmptyRDD[Int] = Utils.deserialize(serial)
+ assert(!deserial.toString().isEmpty())
+ }
+
test("countApproxDistinct") {
def error(est: Long, size: Long) = math.abs(est - size) / size.toDouble
diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
index d0619559bb457..656917628f7a8 100644
--- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala
@@ -18,12 +18,12 @@
package org.apache.spark.rdd
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{Logging, SharedSparkContext}
import org.apache.spark.SparkContext._
-class SortingSuite extends FunSuite with SharedSparkContext with ShouldMatchers with Logging {
+class SortingSuite extends FunSuite with SharedSparkContext with Matchers with Logging {
test("sortByKey") {
val pairs = sc.parallelize(Array((1, 0), (2, 0), (0, 0), (3, 0)), 2)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 5426e578a9ddd..be506e0287a16 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -22,13 +22,13 @@ import java.util.concurrent.Semaphore
import scala.collection.mutable
import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{LocalSparkContext, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.executor.TaskMetrics
-class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
+class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers
with BeforeAndAfter with BeforeAndAfterAll {
/** Length of time to wait while draining listener events. */
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 81bd8257bc155..d7dbe5164b7f6 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -28,7 +28,7 @@ import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
-import org.scalatest.matchers.ShouldMatchers._
+import org.scalatest.Matchers
import org.scalatest.time.SpanSugar._
import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
@@ -39,7 +39,8 @@ import org.apache.spark.util.{AkkaUtils, ByteBufferInputStream, SizeEstimator, U
import scala.language.implicitConversions
import scala.language.postfixOps
-class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester {
+class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
+ with PrivateMethodTester {
private val conf = new SparkConf(false)
var store: BlockManager = null
var store2: BlockManager = null
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 8c06a2d9aa4ab..91b4c7b0dd962 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -18,14 +18,14 @@
package org.apache.spark.ui.jobs
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.{LocalSparkContext, SparkConf, Success}
import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics}
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
-class JobProgressListenerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matchers {
test("test LRU eviction of stages") {
val conf = new SparkConf()
conf.set("spark.ui.retainedStages", 5.toString)
diff --git a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
index 63642461e4465..090d48ec921a1 100644
--- a/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/DistributionSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
/**
*
*/
-class DistributionSuite extends FunSuite with ShouldMatchers {
+class DistributionSuite extends FunSuite with Matchers {
test("summary") {
val d = new Distribution((1 to 100).toArray.map{_.toDouble})
val stats = d.statCounter
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
new file mode 100644
index 0000000000000..53d7f5c6072e6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.io._
+
+import scala.collection.mutable.HashSet
+import scala.reflect._
+
+import org.apache.commons.io.{FileUtils, IOUtils}
+import org.apache.spark.{Logging, SparkConf}
+import org.scalatest.{BeforeAndAfter, FunSuite}
+import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender}
+
+class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
+
+ val testFile = new File("FileAppenderSuite-test-" + System.currentTimeMillis).getAbsoluteFile
+
+ before {
+ cleanup()
+ }
+
+ after {
+ cleanup()
+ }
+
+ test("basic file appender") {
+ val testString = (1 to 1000).mkString(", ")
+ val inputStream = IOUtils.toInputStream(testString)
+ val appender = new FileAppender(inputStream, testFile)
+ inputStream.close()
+ appender.awaitTermination()
+ assert(FileUtils.readFileToString(testFile) === testString)
+ }
+
+ test("rolling file appender - time-based rolling") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val rolloverIntervalMillis = 100
+ val durationMillis = 1000
+ val numRollovers = durationMillis / rolloverIntervalMillis
+ val textToAppend = (1 to numRollovers).map( _.toString * 10 )
+
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new TimeBasedRollingPolicy(rolloverIntervalMillis, s"--HH-mm-ss-SSSS", false),
+ new SparkConf(), 10)
+
+ testRolling(appender, testOutputStream, textToAppend, rolloverIntervalMillis)
+ }
+
+ test("rolling file appender - size-based rolling") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val rolloverSize = 1000
+ val textToAppend = (1 to 3).map( _.toString * 1000 )
+
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new SizeBasedRollingPolicy(rolloverSize, false), new SparkConf(), 99)
+
+ val files = testRolling(appender, testOutputStream, textToAppend, 0)
+ files.foreach { file =>
+ logInfo(file.toString + ": " + file.length + " bytes")
+ assert(file.length <= rolloverSize)
+ }
+ }
+
+ test("rolling file appender - cleaning") {
+ // setup input stream and appender
+ val testOutputStream = new PipedOutputStream()
+ val testInputStream = new PipedInputStream(testOutputStream, 100 * 1000)
+ val conf = new SparkConf().set(RollingFileAppender.RETAINED_FILES_PROPERTY, "10")
+ val appender = new RollingFileAppender(testInputStream, testFile,
+ new SizeBasedRollingPolicy(1000, false), conf, 10)
+
+ // send data to appender through the input stream, and wait for the data to be written
+ val allGeneratedFiles = new HashSet[String]()
+ val items = (1 to 10).map { _.toString * 10000 }
+ for (i <- 0 until items.size) {
+ testOutputStream.write(items(i).getBytes("UTF8"))
+ testOutputStream.flush()
+ allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName).map(_.toString)
+
+ Thread.sleep(10)
+ }
+ testOutputStream.close()
+ appender.awaitTermination()
+ logInfo("Appender closed")
+
+ // verify whether the earliest file has been deleted
+ val rolledOverFiles = allGeneratedFiles.filter { _ != testFile.toString }.toArray.sorted
+ logInfo(s"All rolled over files generated:${rolledOverFiles.size}\n" + rolledOverFiles.mkString("\n"))
+ assert(rolledOverFiles.size > 2)
+ val earliestRolledOverFile = rolledOverFiles.head
+ val existingRolledOverFiles = RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName).map(_.toString)
+ logInfo("Existing rolled over files:\n" + existingRolledOverFiles.mkString("\n"))
+ assert(!existingRolledOverFiles.toSet.contains(earliestRolledOverFile))
+ }
+
+ test("file appender selection") {
+ // Test whether FileAppender.apply() returns the right type of the FileAppender based
+ // on SparkConf settings.
+
+ def testAppenderSelection[ExpectedAppender: ClassTag, ExpectedRollingPolicy](
+ properties: Seq[(String, String)], expectedRollingPolicyParam: Long = -1): FileAppender = {
+
+ // Set spark conf properties
+ val conf = new SparkConf
+ properties.foreach { p =>
+ conf.set(p._1, p._2)
+ }
+
+ // Create and test file appender
+ val inputStream = new PipedInputStream(new PipedOutputStream())
+ val appender = FileAppender(inputStream, new File("stdout"), conf)
+ assert(appender.isInstanceOf[ExpectedAppender])
+ assert(appender.getClass.getSimpleName ===
+ classTag[ExpectedAppender].runtimeClass.getSimpleName)
+ if (appender.isInstanceOf[RollingFileAppender]) {
+ val rollingPolicy = appender.asInstanceOf[RollingFileAppender].rollingPolicy
+ rollingPolicy.isInstanceOf[ExpectedRollingPolicy]
+ val policyParam = if (rollingPolicy.isInstanceOf[TimeBasedRollingPolicy]) {
+ rollingPolicy.asInstanceOf[TimeBasedRollingPolicy].rolloverIntervalMillis
+ } else {
+ rollingPolicy.asInstanceOf[SizeBasedRollingPolicy].rolloverSizeBytes
+ }
+ assert(policyParam === expectedRollingPolicyParam)
+ }
+ appender
+ }
+
+ import RollingFileAppender._
+
+ def rollingStrategy(strategy: String) = Seq(STRATEGY_PROPERTY -> strategy)
+ def rollingSize(size: String) = Seq(SIZE_PROPERTY -> size)
+ def rollingInterval(interval: String) = Seq(INTERVAL_PROPERTY -> interval)
+
+ val msInDay = 24 * 60 * 60 * 1000L
+ val msInHour = 60 * 60 * 1000L
+ val msInMinute = 60 * 1000L
+
+ // test no strategy -> no rolling
+ testAppenderSelection[FileAppender, Any](Seq.empty)
+
+ // test time based rolling strategy
+ testAppenderSelection[RollingFileAppender, Any](rollingStrategy("time"), msInDay)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("daily"), msInDay)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("hourly"), msInHour)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("minutely"), msInMinute)
+ testAppenderSelection[RollingFileAppender, TimeBasedRollingPolicy](
+ rollingStrategy("time") ++ rollingInterval("123456789"), 123456789 * 1000L)
+ testAppenderSelection[FileAppender, Any](
+ rollingStrategy("time") ++ rollingInterval("xyz"))
+
+ // test size based rolling strategy
+ testAppenderSelection[RollingFileAppender, SizeBasedRollingPolicy](
+ rollingStrategy("size") ++ rollingSize("123456789"), 123456789)
+ testAppenderSelection[FileAppender, Any](rollingSize("xyz"))
+
+ // test illegal strategy
+ testAppenderSelection[FileAppender, Any](rollingStrategy("xyz"))
+ }
+
+ /**
+ * Run the rolling file appender with data and see whether all the data was written correctly
+ * across rolled over files.
+ */
+ def testRolling(
+ appender: FileAppender,
+ outputStream: OutputStream,
+ textToAppend: Seq[String],
+ sleepTimeBetweenTexts: Long
+ ): Seq[File] = {
+ // send data to appender through the input stream, and wait for the data to be written
+ val expectedText = textToAppend.mkString("")
+ for (i <- 0 until textToAppend.size) {
+ outputStream.write(textToAppend(i).getBytes("UTF8"))
+ outputStream.flush()
+ Thread.sleep(sleepTimeBetweenTexts)
+ }
+ logInfo("Data sent to appender")
+ outputStream.close()
+ appender.awaitTermination()
+ logInfo("Appender closed")
+
+    // verify whether all the data written to the rolled-over files is the same as expected
+ val generatedFiles = RollingFileAppender.getSortedRolledOverFiles(
+ testFile.getParentFile.toString, testFile.getName)
+ logInfo("Filtered files: \n" + generatedFiles.mkString("\n"))
+ assert(generatedFiles.size > 1)
+ val allText = generatedFiles.map { file =>
+ FileUtils.readFileToString(file)
+ }.mkString("")
+ assert(allText === expectedText)
+ generatedFiles
+ }
+
+  /** Delete all the generated rolled-over files */
+ def cleanup() {
+ testFile.getParentFile.listFiles.filter { file =>
+ file.getName.startsWith(testFile.getName)
+ }.foreach { _.delete() }
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
index 32d74d0500b72..cf438a3d72a06 100644
--- a/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/NextIteratorSuite.scala
@@ -22,9 +22,9 @@ import java.util.NoSuchElementException
import scala.collection.mutable.Buffer
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
-class NextIteratorSuite extends FunSuite with ShouldMatchers {
+class NextIteratorSuite extends FunSuite with Matchers {
test("one iteration") {
val i = new StubIterator(Buffer(1))
i.hasNext should be === true
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index 0aad882ed76a8..1ee936bc78f49 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -140,6 +140,38 @@ class UtilsSuite extends FunSuite {
Utils.deleteRecursively(tmpDir2)
}
+ test("reading offset bytes across multiple files") {
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val files = (1 to 3).map(i => new File(tmpDir, i.toString))
+ Files.write("0123456789", files(0), Charsets.UTF_8)
+ Files.write("abcdefghij", files(1), Charsets.UTF_8)
+ Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8)
+
+ // Read first few bytes in the 1st file
+ assert(Utils.offsetBytes(files, 0, 5) === "01234")
+
+ // Read bytes within the 1st file
+ assert(Utils.offsetBytes(files, 5, 8) === "567")
+
+ // Read bytes across 1st and 2nd file
+ assert(Utils.offsetBytes(files, 8, 18) === "89abcdefgh")
+
+ // Read bytes across 1st, 2nd and 3rd file
+ assert(Utils.offsetBytes(files, 5, 24) === "56789abcdefghijABCD")
+
+ // Read some nonexistent bytes in the beginning
+ assert(Utils.offsetBytes(files, -5, 18) === "0123456789abcdefgh")
+
+ // Read some nonexistent bytes at the end
+ assert(Utils.offsetBytes(files, 18, 35) === "ijABCDEFGHIJ")
+
+ // Read some nonexistent bytes on both ends
+ assert(Utils.offsetBytes(files, -5, 35) === "0123456789abcdefghijABCDEFGHIJ")
+
+ Utils.deleteRecursively(tmpDir)
+ }
+
test("deserialize long value") {
val testval : Long = 9730889947L
val bbuf = ByteBuffer.allocate(8)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index b024c89d94d33..6a70877356409 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+class OpenHashMapSuite extends FunSuite with Matchers {
test("size for specialized, primitive value (int)") {
val capacity = 1024
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index ff4a98f5dcd4a..68a03e3a0970f 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -18,11 +18,11 @@
package org.apache.spark.util.collection
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+class OpenHashSetSuite extends FunSuite with Matchers {
test("size for specialized, primitive int") {
val loadFactor = 0.7
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index e3fca173908e9..8c7df7d73dcd3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.SizeEstimator
-class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with Matchers {
test("size for specialized, primitive key, value (int, int)") {
val capacity = 1024
diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
index 0865c6386f7cd..e15fd59a5a8bb 100644
--- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util.random
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.util.Utils.times
import scala.language.reflectiveCalls
-class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+class XORShiftRandomSuite extends FunSuite with Matchers {
def fixture = new {
val seed = 1L
diff --git a/dev/mima b/dev/mima
index ab6bd4469b0e8..b68800d6d0173 100755
--- a/dev/mima
+++ b/dev/mima
@@ -23,6 +23,9 @@ set -o pipefail
FWDIR="$(cd `dirname $0`/..; pwd)"
cd $FWDIR
+echo -e "q\n" | sbt/sbt oldDeps/update
+
+export SPARK_CLASSPATH=`find lib_managed \( -name '*spark*jar' -a -type f \) -printf "%p:" `
./bin/spark-class org.apache.spark.tools.GenerateMIMAIgnore
echo -e "q\n" | sbt/sbt mima-report-binary-issues | grep -v -e "info.*Resolving"
ret_val=$?
@@ -31,5 +34,5 @@ if [ $ret_val != 0 ]; then
echo "NOTE: Exceptions to binary compatibility can be added in project/MimaExcludes.scala"
fi
-rm -f .generated-mima-excludes
+rm -f .generated-mima*
exit $ret_val
diff --git a/dev/run-tests b/dev/run-tests
index 93d6692f83ca8..c82a47ebb618b 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -73,9 +73,6 @@ fi
echo "========================================================================="
echo "Running PySpark tests"
echo "========================================================================="
-if [ -z "$PYSPARK_PYTHON" ]; then
- export PYSPARK_PYTHON=/usr/local/bin/python2.7
-fi
./python/run-tests
echo "========================================================================="
diff --git a/docs/configuration.md b/docs/configuration.md
index 71fafa573467f..b84104cc7e653 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -784,6 +784,45 @@ Apart from these, the following properties are also available, and may be useful
higher memory usage in Spark.
+<tr>
+  <td><code>spark.executor.logs.rolling.strategy</code></td>
+  <td>(none)</td>
+  <td>
+    Set the strategy for rolling executor logs. It is disabled by default. It can
+    be set to "time" (time-based rolling) or "size" (size-based rolling). For "time",
+    use <code>spark.executor.logs.rolling.time.interval</code> to set the rolling interval.
+    For "size", use <code>spark.executor.logs.rolling.size.maxBytes</code> to set
+    the maximum file size for rolling.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.time.interval</code></td>
+  <td>daily</td>
+  <td>
+    Set the time interval by which the executor logs will be rolled over.
+    Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or
+    any interval in seconds. See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
+    for automatic cleaning of old logs.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.size.maxBytes</code></td>
+  <td>(none)</td>
+  <td>
+    Set the maximum size of the file, in bytes, by which the executor logs will be rolled over.
+    Rolling is disabled by default. See <code>spark.executor.logs.rolling.maxRetainedFiles</code>
+    for automatic cleaning of old logs.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.executor.logs.rolling.maxRetainedFiles</code></td>
+  <td>(none)</td>
+  <td>
+    Sets the number of latest rolling log files that will be retained by the system.
+    Older log files will be deleted. Disabled by default.
+  </td>
+</tr>
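As a quick companion to the properties documented above, here is a minimal sketch (not part of this patch) of enabling size-based rolling from application code. SparkConf.set is standard Spark API; the property names come from the table above, while the 128 MB and 10-file values are purely illustrative.

import org.apache.spark.SparkConf

// Illustrative values only: roll executor logs at ~128 MB and keep the 10 newest files.
val conf = new SparkConf()
  .set("spark.executor.logs.rolling.strategy", "size")
  .set("spark.executor.logs.rolling.size.maxBytes", (128 * 1024 * 1024).toString)
  .set("spark.executor.logs.rolling.maxRetainedFiles", "10")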
#### Cluster Managers
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 7989e02dfb732..79784682bfd1b 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -890,6 +890,10 @@ for details.
  <td> <b>reduceByKey</b>(<i>func</i>, [<i>numTasks</i>]) </td>
  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, V) pairs where the values for each key are aggregated using the given reduce function. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
</tr>
+<tr>
+  <td> <b>aggregateByKey</b>(<i>zeroValue</i>)(<i>seqOp</i>, <i>combOp</i>, [<i>numTasks</i>]) </td>
+  <td> When called on a dataset of (K, V) pairs, returns a dataset of (K, U) pairs where the values for each key are aggregated using the given combine functions and a neutral "zero" value. Allows an aggregated value type that is different than the input value type, while avoiding unnecessary allocations. Like in <code>groupByKey</code>, the number of reduce tasks is configurable through an optional second argument. </td>
+</tr>
<tr>
  <td> <b>sortByKey</b>([<i>ascending</i>], [<i>numTasks</i>]) </td>
  <td> When called on a dataset of (K, V) pairs where K implements Ordered, returns a dataset of (K, V) pairs sorted by keys in ascending or descending order, as specified in the boolean <code>ascending</code> argument. </td>
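To complement the aggregateByKey row above, a small Scala sketch assuming an existing SparkContext named sc; the pairs/distinctPerKey names are made up for illustration.

// Collect the distinct Int values per key into a Set[Int]; the result value type (Set[Int])
// differs from the input value type (Int), which is exactly what aggregateByKey allows.
val pairs = sc.parallelize(Seq(("a", 1), ("a", 2), ("a", 2), ("b", 3)))
val distinctPerKey = pairs.aggregateByKey(Set.empty[Int])(
  (set, v) => set + v,   // seqOp: fold a value into the per-partition set
  (s1, s2) => s1 ++ s2)  // combOp: merge sets coming from different partitions
// distinctPerKey.collect() would yield something like Array(("a", Set(1, 2)), ("b", Set(3)))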
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 9d5748ba4bc23..52a89cb2481ca 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -200,6 +200,7 @@ def get_spark_shark_version(opts):
sys.exit(1)
return (version, spark_shark_map[version])
+
# Attempt to resolve an appropriate AMI given the architecture and
# region of the request.
def get_spark_ami(opts):
@@ -418,6 +419,16 @@ def launch_cluster(conn, opts, cluster_name):
master_nodes = master_res.instances
print "Launched master in %s, regid = %s" % (zone, master_res.id)
+ # Give the instances descriptive names
+ for master in master_nodes:
+ master.add_tag(
+ key='Name',
+ value='spark-{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id))
+ for slave in slave_nodes:
+ slave.add_tag(
+ key='Name',
+ value='spark-{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id))
+
# Return all the instances
return (master_nodes, slave_nodes)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index 6eb41e7ba36fb..28e201d279f41 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -50,6 +50,8 @@ object MovieLensALS {
numIterations: Int = 20,
lambda: Double = 1.0,
rank: Int = 10,
+ numUserBlocks: Int = -1,
+ numProductBlocks: Int = -1,
implicitPrefs: Boolean = false)
def main(args: Array[String]) {
@@ -67,8 +69,14 @@ object MovieLensALS {
.text(s"lambda (smoothing constant), default: ${defaultParams.lambda}")
.action((x, c) => c.copy(lambda = x))
opt[Unit]("kryo")
- .text(s"use Kryo serialization")
+ .text("use Kryo serialization")
.action((_, c) => c.copy(kryo = true))
+ opt[Int]("numUserBlocks")
+ .text(s"number of user blocks, default: ${defaultParams.numUserBlocks} (auto)")
+ .action((x, c) => c.copy(numUserBlocks = x))
+ opt[Int]("numProductBlocks")
+ .text(s"number of product blocks, default: ${defaultParams.numProductBlocks} (auto)")
+ .action((x, c) => c.copy(numProductBlocks = x))
opt[Unit]("implicitPrefs")
.text("use implicit preference")
.action((_, c) => c.copy(implicitPrefs = true))
@@ -160,6 +168,8 @@ object MovieLensALS {
.setIterations(params.numIterations)
.setLambda(params.lambda)
.setImplicitPrefs(params.implicitPrefs)
+ .setUserBlocks(params.numUserBlocks)
+ .setProductBlocks(params.numProductBlocks)
.run(training)
val rmse = computeRmse(model, test, params.implicitPrefs)
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 5be33f1d5c428..ed35e34ad45ab 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -71,12 +71,12 @@ class SparkFlumeEvent() extends Externalizable {
for (i <- 0 until numHeaders) {
val keyLength = in.readInt()
val keyBuff = new Array[Byte](keyLength)
- in.read(keyBuff)
+ in.readFully(keyBuff)
val key : String = Utils.deserialize(keyBuff)
val valLength = in.readInt()
val valBuff = new Array[Byte](valLength)
- in.read(valBuff)
+ in.readFully(valBuff)
val value : String = Utils.deserialize(valBuff)
headers.put(key, value)
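Context for the read -> readFully change above, as a standalone sketch that is not part of the patch: InputStream.read may return after filling only part of the buffer, which could leave keyBuff/valBuff partially populated, while DataInput.readFully keeps reading until the buffer is full or throws EOFException.

import java.io.{ByteArrayInputStream, DataInputStream}

val bytes = Array.fill[Byte](16)(1.toByte)
val buf = new Array[Byte](16)

// read() may legally return after reading fewer than buf.length bytes
val n = new DataInputStream(new ByteArrayInputStream(bytes)).read(buf)

// readFully() only returns once buf is completely filled (or throws EOFException)
new DataInputStream(new ByteArrayInputStream(bytes)).readFully(buf)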
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index d743bd7dd1825..cc56fd6ef28d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -61,7 +61,7 @@ private[recommendation] case class InLinkBlock(
* A more compact class to represent a rating than Tuple3[Int, Int, Double].
*/
@Experimental
-case class Rating(val user: Int, val product: Int, val rating: Double)
+case class Rating(user: Int, product: Int, rating: Double)
/**
* Alternating Least Squares matrix factorization.
@@ -93,7 +93,8 @@ case class Rating(val user: Int, val product: Int, val rating: Double)
* preferences rather than explicit ratings given to items.
*/
class ALS private (
- private var numBlocks: Int,
+ private var numUserBlocks: Int,
+ private var numProductBlocks: Int,
private var rank: Int,
private var iterations: Int,
private var lambda: Double,
@@ -106,14 +107,31 @@ class ALS private (
* Constructs an ALS instance with default parameters: {numBlocks: -1, rank: 10, iterations: 10,
* lambda: 0.01, implicitPrefs: false, alpha: 1.0}.
*/
- def this() = this(-1, 10, 10, 0.01, false, 1.0)
+ def this() = this(-1, -1, 10, 10, 0.01, false, 1.0)
/**
- * Set the number of blocks to parallelize the computation into; pass -1 for an auto-configured
- * number of blocks. Default: -1.
+   * Set the number of both user blocks and product blocks to parallelize the computation into;
+   * pass -1 for an auto-configured number of blocks. Default: -1.
*/
def setBlocks(numBlocks: Int): ALS = {
- this.numBlocks = numBlocks
+ this.numUserBlocks = numBlocks
+ this.numProductBlocks = numBlocks
+ this
+ }
+
+ /**
+ * Set the number of user blocks to parallelize the computation.
+ */
+ def setUserBlocks(numUserBlocks: Int): ALS = {
+ this.numUserBlocks = numUserBlocks
+ this
+ }
+
+ /**
+ * Set the number of product blocks to parallelize the computation.
+ */
+ def setProductBlocks(numProductBlocks: Int): ALS = {
+ this.numProductBlocks = numProductBlocks
this
}
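A minimal sketch (not part of the patch) of how the new per-side setters might be used together; it assumes an existing RDD[Rating] named ratings, and the block counts are illustrative.

import org.apache.spark.mllib.recommendation.{ALS, Rating}
import org.apache.spark.rdd.RDD

def trainWithSeparateBlocks(ratings: RDD[Rating]) = {
  new ALS()
    .setUserBlocks(4)      // partition users into 4 blocks
    .setProductBlocks(2)   // partition products into 2 blocks
    .setRank(10)
    .setIterations(10)
    .run(ratings)
}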
@@ -176,31 +194,32 @@ class ALS private (
def run(ratings: RDD[Rating]): MatrixFactorizationModel = {
val sc = ratings.context
- val numBlocks = if (this.numBlocks == -1) {
+ val numUserBlocks = if (this.numUserBlocks == -1) {
math.max(sc.defaultParallelism, ratings.partitions.size / 2)
} else {
- this.numBlocks
+ this.numUserBlocks
}
-
- val partitioner = new Partitioner {
- val numPartitions = numBlocks
-
- def getPartition(x: Any): Int = {
- Utils.nonNegativeMod(byteswap32(x.asInstanceOf[Int]), numPartitions)
- }
+ val numProductBlocks = if (this.numProductBlocks == -1) {
+ math.max(sc.defaultParallelism, ratings.partitions.size / 2)
+ } else {
+ this.numProductBlocks
}
- val ratingsByUserBlock = ratings.map{ rating =>
- (partitioner.getPartition(rating.user), rating)
+ val userPartitioner = new ALSPartitioner(numUserBlocks)
+ val productPartitioner = new ALSPartitioner(numProductBlocks)
+
+ val ratingsByUserBlock = ratings.map { rating =>
+ (userPartitioner.getPartition(rating.user), rating)
}
- val ratingsByProductBlock = ratings.map{ rating =>
- (partitioner.getPartition(rating.product),
+ val ratingsByProductBlock = ratings.map { rating =>
+ (productPartitioner.getPartition(rating.product),
Rating(rating.product, rating.user, rating.rating))
}
- val (userInLinks, userOutLinks) = makeLinkRDDs(numBlocks, ratingsByUserBlock, partitioner)
+ val (userInLinks, userOutLinks) =
+ makeLinkRDDs(numUserBlocks, numProductBlocks, ratingsByUserBlock, productPartitioner)
val (productInLinks, productOutLinks) =
- makeLinkRDDs(numBlocks, ratingsByProductBlock, partitioner)
+ makeLinkRDDs(numProductBlocks, numUserBlocks, ratingsByProductBlock, userPartitioner)
userInLinks.setName("userInLinks")
userOutLinks.setName("userOutLinks")
productInLinks.setName("productInLinks")
@@ -232,27 +251,27 @@ class ALS private (
users.setName(s"users-$iter").persist()
val YtY = Some(sc.broadcast(computeYtY(users)))
val previousProducts = products
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
- alpha, YtY)
+ products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
+ userPartitioner, rank, lambda, alpha, YtY)
previousProducts.unpersist()
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
products.setName(s"products-$iter").persist()
val XtX = Some(sc.broadcast(computeYtY(products)))
val previousUsers = users
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
- alpha, XtX)
+ users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
+ productPartitioner, rank, lambda, alpha, XtX)
previousUsers.unpersist()
}
} else {
for (iter <- 1 to iterations) {
// perform ALS update
logInfo("Re-computing I given U (Iteration %d/%d)".format(iter, iterations))
- products = updateFeatures(users, userOutLinks, productInLinks, partitioner, rank, lambda,
- alpha, YtY = None)
+ products = updateFeatures(numProductBlocks, users, userOutLinks, productInLinks,
+ userPartitioner, rank, lambda, alpha, YtY = None)
products.setName(s"products-$iter")
logInfo("Re-computing U given I (Iteration %d/%d)".format(iter, iterations))
- users = updateFeatures(products, productOutLinks, userInLinks, partitioner, rank, lambda,
- alpha, YtY = None)
+ users = updateFeatures(numUserBlocks, products, productOutLinks, userInLinks,
+ productPartitioner, rank, lambda, alpha, YtY = None)
users.setName(s"users-$iter")
}
}
@@ -340,9 +359,10 @@ class ALS private (
/**
* Flatten out blocked user or product factors into an RDD of (id, factor vector) pairs
*/
- private def unblockFactors(blockedFactors: RDD[(Int, Array[Array[Double]])],
- outLinks: RDD[(Int, OutLinkBlock)]) = {
- blockedFactors.join(outLinks).flatMap{ case (b, (factors, outLinkBlock)) =>
+ private def unblockFactors(
+ blockedFactors: RDD[(Int, Array[Array[Double]])],
+ outLinks: RDD[(Int, OutLinkBlock)]): RDD[(Int, Array[Double])] = {
+ blockedFactors.join(outLinks).flatMap { case (b, (factors, outLinkBlock)) =>
for (i <- 0 until factors.length) yield (outLinkBlock.elementIds(i), factors(i))
}
}
@@ -351,14 +371,14 @@ class ALS private (
* Make the out-links table for a block of the users (or products) dataset given the list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
- private def makeOutLinkBlock(numBlocks: Int, ratings: Array[Rating],
- partitioner: Partitioner): OutLinkBlock = {
+ private def makeOutLinkBlock(numProductBlocks: Int, ratings: Array[Rating],
+ productPartitioner: Partitioner): OutLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
- val shouldSend = Array.fill(numUsers)(new BitSet(numBlocks))
+ val shouldSend = Array.fill(numUsers)(new BitSet(numProductBlocks))
for (r <- ratings) {
- shouldSend(userIdToPos(r.user))(partitioner.getPartition(r.product)) = true
+ shouldSend(userIdToPos(r.user))(productPartitioner.getPartition(r.product)) = true
}
OutLinkBlock(userIds, shouldSend)
}
@@ -367,18 +387,17 @@ class ALS private (
* Make the in-links table for a block of the users (or products) dataset given a list of
* (user, product, rating) values for the users in that block (or the opposite for products).
*/
- private def makeInLinkBlock(numBlocks: Int, ratings: Array[Rating],
- partitioner: Partitioner): InLinkBlock = {
+ private def makeInLinkBlock(numProductBlocks: Int, ratings: Array[Rating],
+ productPartitioner: Partitioner): InLinkBlock = {
val userIds = ratings.map(_.user).distinct.sorted
- val numUsers = userIds.length
val userIdToPos = userIds.zipWithIndex.toMap
// Split out our ratings by product block
- val blockRatings = Array.fill(numBlocks)(new ArrayBuffer[Rating])
+ val blockRatings = Array.fill(numProductBlocks)(new ArrayBuffer[Rating])
for (r <- ratings) {
- blockRatings(partitioner.getPartition(r.product)) += r
+ blockRatings(productPartitioner.getPartition(r.product)) += r
}
- val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numBlocks)
- for (productBlock <- 0 until numBlocks) {
+ val ratingsForBlock = new Array[Array[(Array[Int], Array[Double])]](numProductBlocks)
+ for (productBlock <- 0 until numProductBlocks) {
// Create an array of (product, Seq(Rating)) ratings
val groupedRatings = blockRatings(productBlock).groupBy(_.product).toArray
// Sort them by product ID
@@ -400,14 +419,16 @@ class ALS private (
* the users (or (blockId, (p, u, r)) for the products). We create these simultaneously to avoid
* having to shuffle the (blockId, (u, p, r)) RDD twice, or to cache it.
*/
- private def makeLinkRDDs(numBlocks: Int, ratings: RDD[(Int, Rating)], partitioner: Partitioner)
- : (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) =
- {
- val grouped = ratings.partitionBy(new HashPartitioner(numBlocks))
+ private def makeLinkRDDs(
+ numUserBlocks: Int,
+ numProductBlocks: Int,
+ ratingsByUserBlock: RDD[(Int, Rating)],
+ productPartitioner: Partitioner): (RDD[(Int, InLinkBlock)], RDD[(Int, OutLinkBlock)]) = {
+ val grouped = ratingsByUserBlock.partitionBy(new HashPartitioner(numUserBlocks))
val links = grouped.mapPartitionsWithIndex((blockId, elements) => {
- val ratings = elements.map{_._2}.toArray
- val inLinkBlock = makeInLinkBlock(numBlocks, ratings, partitioner)
- val outLinkBlock = makeOutLinkBlock(numBlocks, ratings, partitioner)
+ val ratings = elements.map(_._2).toArray
+ val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner)
+ val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner)
Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
}, true)
val inLinks = links.mapValues(_._1)
@@ -439,26 +460,24 @@ class ALS private (
* It returns an RDD of new feature vectors for each user block.
*/
private def updateFeatures(
+ numUserBlocks: Int,
products: RDD[(Int, Array[Array[Double]])],
productOutLinks: RDD[(Int, OutLinkBlock)],
userInLinks: RDD[(Int, InLinkBlock)],
- partitioner: Partitioner,
+ productPartitioner: Partitioner,
rank: Int,
lambda: Double,
alpha: Double,
- YtY: Option[Broadcast[DoubleMatrix]])
- : RDD[(Int, Array[Array[Double]])] =
- {
- val numBlocks = products.partitions.size
+ YtY: Option[Broadcast[DoubleMatrix]]): RDD[(Int, Array[Array[Double]])] = {
productOutLinks.join(products).flatMap { case (bid, (outLinkBlock, factors)) =>
- val toSend = Array.fill(numBlocks)(new ArrayBuffer[Array[Double]])
- for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numBlocks) {
+ val toSend = Array.fill(numUserBlocks)(new ArrayBuffer[Array[Double]])
+ for (p <- 0 until outLinkBlock.elementIds.length; userBlock <- 0 until numUserBlocks) {
if (outLinkBlock.shouldSend(p)(userBlock)) {
toSend(userBlock) += factors(p)
}
}
toSend.zipWithIndex.map{ case (buf, idx) => (idx, (bid, buf.toArray)) }
- }.groupByKey(partitioner)
+ }.groupByKey(productPartitioner)
.join(userInLinks)
.mapValues{ case (messages, inLinkBlock) =>
updateBlock(messages, inLinkBlock, rank, lambda, alpha, YtY)
@@ -475,7 +494,7 @@ class ALS private (
{
// Sort the incoming block factor messages by block ID and make them an array
val blockFactors = messages.toSeq.sortBy(_._1).map(_._2).toArray // Array[Array[Double]]
- val numBlocks = blockFactors.length
+ val numProductBlocks = blockFactors.length
val numUsers = inLinkBlock.elementIds.length
// We'll sum up the XtXes using vectors that represent only the lower-triangular part, since
@@ -488,9 +507,12 @@ class ALS private (
val tempXtX = DoubleMatrix.zeros(triangleSize)
val fullXtX = DoubleMatrix.zeros(rank, rank)
+ // Count the number of ratings each user gives to provide user-specific regularization
+ val numRatings = Array.fill(numUsers)(0)
+
// Compute the XtX and Xy values for each user by adding products it rated in each product
// block
- for (productBlock <- 0 until numBlocks) {
+ for (productBlock <- 0 until numProductBlocks) {
var p = 0
while (p < blockFactors(productBlock).length) {
val x = wrapDoubleArray(blockFactors(productBlock)(p))
@@ -500,6 +522,7 @@ class ALS private (
if (implicitPrefs) {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
// Extension to the original paper to handle rs(i) < 0. confidence is a function
// of |rs(i)| instead so that it is never negative:
val confidence = 1 + alpha * abs(rs(i))
@@ -515,6 +538,7 @@ class ALS private (
} else {
var i = 0
while (i < us.length) {
+ numRatings(us(i)) += 1
userXtX(us(i)).addi(tempXtX)
SimpleBlas.axpy(rs(i), x, userXy(us(i)))
i += 1
@@ -531,9 +555,10 @@ class ALS private (
// Compute the full XtX matrix from the lower-triangular part we got above
fillFullMatrix(userXtX(index), fullXtX)
// Add regularization
+ val regParam = numRatings(index) * lambda
var i = 0
while (i < rank) {
- fullXtX.data(i * rank + i) += lambda
+ fullXtX.data(i * rank + i) += regParam
i += 1
}
// Solve the resulting matrix, which is symmetric and positive-definite
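A hedged reading of the numRatings/regParam change above: scaling lambda by the number of ratings per user (or product) matches weighted-lambda regularization, so in the explicit-feedback case each local solve becomes (with n_u = numRatings(u), Y_u the stacked factors of the products rated by user u, and r_u the corresponding ratings):

\[
\big(Y_u^\top Y_u + \lambda\, n_u I\big)\, x_u = Y_u^\top r_u,
\qquad
x_u = \arg\min_{x} \sum_{i \in R(u)} \big(r_{ui} - x^\top y_i\big)^2 + \lambda\, n_u \lVert x \rVert^2 .
\]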
@@ -579,6 +604,23 @@ class ALS private (
}
}
+/**
+ * Partitioner for ALS.
+ */
+private[recommendation] class ALSPartitioner(override val numPartitions: Int) extends Partitioner {
+ override def getPartition(key: Any): Int = {
+ Utils.nonNegativeMod(byteswap32(key.asInstanceOf[Int]), numPartitions)
+ }
+
+ override def equals(obj: Any): Boolean = {
+ obj match {
+ case p: ALSPartitioner =>
+ this.numPartitions == p.numPartitions
+ case _ =>
+ false
+ }
+ }
+}
/**
* Top-level methods for calling Alternating Least Squares (ALS) matrix factorization.
@@ -606,7 +648,7 @@ object ALS {
blocks: Int,
seed: Long
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0, seed).run(ratings)
}
/**
@@ -629,7 +671,7 @@ object ALS {
lambda: Double,
blocks: Int
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, false, 1.0).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, false, 1.0).run(ratings)
}
/**
@@ -689,7 +731,7 @@ object ALS {
alpha: Double,
seed: Long
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, true, alpha, seed).run(ratings)
}
/**
@@ -714,7 +756,7 @@ object ALS {
blocks: Int,
alpha: Double
): MatrixFactorizationModel = {
- new ALS(blocks, rank, iterations, lambda, true, alpha).run(ratings)
+ new ALS(blocks, blocks, rank, iterations, lambda, true, alpha).run(ratings)
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 4d7b984e3ec29..44b757b6a1fb7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import scala.collection.JavaConversions._
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
@@ -56,7 +56,7 @@ object LogisticRegressionSuite {
}
}
-class LogisticRegressionSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Matchers {
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
prediction != expected.label
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 8a16284118cf7..951b4f7c6e6f4 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import scala.collection.JavaConversions._
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
@@ -61,7 +61,7 @@ object GradientDescentSuite {
}
}
-class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers {
test("Assert the loss is decreasing.") {
val nPoints = 10000
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
index 820eca9b1bf65..4b1850659a18e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.mllib.optimization
import org.scalatest.FunSuite
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
-class LBFGSSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
val nPoints = 10000
val A = 2.0
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
index 37c9b9d085841..81bebec8c7a39 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/ALSSuite.scala
@@ -121,6 +121,10 @@ class ALSSuite extends FunSuite with LocalSparkContext {
testALS(100, 200, 2, 15, 0.7, 0.4, true, false, true)
}
+ test("rank-2 matrices with different user and product blocks") {
+ testALS(100, 200, 2, 15, 0.7, 0.4, numUserBlocks = 4, numProductBlocks = 2)
+ }
+
test("pseudorandomness") {
val ratings = sc.parallelize(ALSSuite.generateRatings(10, 20, 5, 0.5, false, false)._1, 2)
val model11 = ALS.train(ratings, 5, 1, 1.0, 2, 1)
@@ -153,35 +157,52 @@ class ALSSuite extends FunSuite with LocalSparkContext {
}
test("NNALS, rank 2") {
- testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, false)
+ testALS(100, 200, 2, 15, 0.7, 0.4, false, false, false, -1, -1, false)
}
/**
* Test if we can correctly factorize R = U * P where U and P are of known rank.
*
- * @param users number of users
- * @param products number of products
- * @param features number of features (rank of problem)
- * @param iterations number of iterations to run
- * @param samplingRate what fraction of the user-product pairs are known
+ * @param users number of users
+ * @param products number of products
+ * @param features number of features (rank of problem)
+ * @param iterations number of iterations to run
+ * @param samplingRate what fraction of the user-product pairs are known
* @param matchThreshold max difference allowed to consider a predicted rating correct
- * @param implicitPrefs flag to test implicit feedback
- * @param bulkPredict flag to test bulk prediciton
+ * @param implicitPrefs flag to test implicit feedback
+   * @param bulkPredict      flag to test bulk prediction
* @param negativeWeights whether the generated data can contain negative values
- * @param numBlocks number of blocks to partition users and products into
+ * @param numUserBlocks number of user blocks to partition users into
+ * @param numProductBlocks number of product blocks to partition products into
* @param negativeFactors whether the generated user/product factors can have negative entries
*/
- def testALS(users: Int, products: Int, features: Int, iterations: Int,
- samplingRate: Double, matchThreshold: Double, implicitPrefs: Boolean = false,
- bulkPredict: Boolean = false, negativeWeights: Boolean = false, numBlocks: Int = -1,
- negativeFactors: Boolean = true)
- {
+ def testALS(
+ users: Int,
+ products: Int,
+ features: Int,
+ iterations: Int,
+ samplingRate: Double,
+ matchThreshold: Double,
+ implicitPrefs: Boolean = false,
+ bulkPredict: Boolean = false,
+ negativeWeights: Boolean = false,
+ numUserBlocks: Int = -1,
+ numProductBlocks: Int = -1,
+ negativeFactors: Boolean = true) {
val (sampledRatings, trueRatings, truePrefs) = ALSSuite.generateRatings(users, products,
features, samplingRate, implicitPrefs, negativeWeights, negativeFactors)
- val model = (new ALS().setBlocks(numBlocks).setRank(features).setIterations(iterations)
- .setAlpha(1.0).setImplicitPrefs(implicitPrefs).setLambda(0.01).setSeed(0L)
- .setNonnegative(!negativeFactors).run(sc.parallelize(sampledRatings)))
+ val model = new ALS()
+ .setUserBlocks(numUserBlocks)
+ .setProductBlocks(numProductBlocks)
+ .setRank(features)
+ .setIterations(iterations)
+ .setAlpha(1.0)
+ .setImplicitPrefs(implicitPrefs)
+ .setLambda(0.01)
+ .setSeed(0L)
+ .setNonnegative(!negativeFactors)
+ .run(sc.parallelize(sampledRatings))
val predictedU = new DoubleMatrix(users, features)
for ((u, vec) <- model.userFeatures.collect(); i <- 0 until features) {
@@ -208,8 +229,9 @@ class ALSSuite extends FunSuite with LocalSparkContext {
val prediction = predictedRatings.get(u, p)
val correct = trueRatings.get(u, p)
if (math.abs(prediction - correct) > matchThreshold) {
- fail("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s".format(
- u, p, correct, prediction, trueRatings, predictedRatings, predictedU, predictedP))
+ fail(("Model failed to predict (%d, %d): %f vs %f\ncorr: %s\npred: %s\nU: %s\n P: %s")
+ .format(u, p, correct, prediction, trueRatings, predictedRatings, predictedU,
+ predictedP))
}
}
} else {
diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala
index 1477809943573..bb2d73741c3bf 100644
--- a/project/MimaBuild.scala
+++ b/project/MimaBuild.scala
@@ -15,16 +15,26 @@
* limitations under the License.
*/
-import com.typesafe.tools.mima.core.{MissingTypesProblem, MissingClassProblem, ProblemFilters}
+import com.typesafe.tools.mima.core._
+import com.typesafe.tools.mima.core.MissingClassProblem
+import com.typesafe.tools.mima.core.MissingTypesProblem
import com.typesafe.tools.mima.core.ProblemFilters._
import com.typesafe.tools.mima.plugin.MimaKeys.{binaryIssueFilters, previousArtifact}
import com.typesafe.tools.mima.plugin.MimaPlugin.mimaDefaultSettings
import sbt._
object MimaBuild {
+
+ def excludeMember(fullName: String) = Seq(
+ ProblemFilters.exclude[MissingMethodProblem](fullName),
+ ProblemFilters.exclude[MissingFieldProblem](fullName),
+ ProblemFilters.exclude[IncompatibleResultTypeProblem](fullName),
+ ProblemFilters.exclude[IncompatibleMethTypeProblem](fullName),
+ ProblemFilters.exclude[IncompatibleFieldTypeProblem](fullName)
+ )
+
// Exclude a single class and its corresponding object
- def excludeClass(className: String) = {
- Seq(
+ def excludeClass(className: String) = Seq(
excludePackage(className),
ProblemFilters.exclude[MissingClassProblem](className),
ProblemFilters.exclude[MissingTypesProblem](className),
@@ -32,7 +42,7 @@ object MimaBuild {
ProblemFilters.exclude[MissingClassProblem](className + "$"),
ProblemFilters.exclude[MissingTypesProblem](className + "$")
)
- }
+
// Exclude a Spark class, that is in the package org.apache.spark
def excludeSparkClass(className: String) = {
excludeClass("org.apache.spark." + className)
@@ -49,20 +59,25 @@ object MimaBuild {
val defaultExcludes = Seq()
// Read package-private excludes from file
- val excludeFilePath = (base.getAbsolutePath + "/.generated-mima-excludes")
- val excludeFile = file(excludeFilePath)
+ val classExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-class-excludes")
+ val memberExcludeFilePath = file(base.getAbsolutePath + "/.generated-mima-member-excludes")
+
val ignoredClasses: Seq[String] =
- if (!excludeFile.exists()) {
+ if (!classExcludeFilePath.exists()) {
Seq()
} else {
- IO.read(excludeFile).split("\n")
+ IO.read(classExcludeFilePath).split("\n")
}
+ val ignoredMembers: Seq[String] =
+ if (!memberExcludeFilePath.exists()) {
+ Seq()
+ } else {
+ IO.read(memberExcludeFilePath).split("\n")
+ }
-
- val externalExcludeFileClasses = ignoredClasses.flatMap(excludeClass)
-
- defaultExcludes ++ externalExcludeFileClasses ++ MimaExcludes.excludes
+ defaultExcludes ++ ignoredClasses.flatMap(excludeClass) ++
+ ignoredMembers.flatMap(excludeMember) ++ MimaExcludes.excludes
}
def mimaSettings(sparkHome: File) = mimaDefaultSettings ++ Seq(
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index dd7efceb23c96..ee629794f60ad 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -52,7 +52,18 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.api.java.JavaRDDLike.countApproxDistinct$default$1"),
ProblemFilters.exclude[MissingMethodProblem](
- "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
+ "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.rdd.PairRDDFunctions.org$apache$spark$rdd$PairRDDFunctions$$"
+ + "createZero$1")
+ ) ++
+ Seq( // Ignore some private methods in ALS.
+ ProblemFilters.exclude[MissingMethodProblem](
+      "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$updateFeatures"),
+ ProblemFilters.exclude[MissingMethodProblem]( // The only public constructor is the one without arguments.
+ "org.apache.spark.mllib.recommendation.ALS.this"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.recommendation.ALS.org$apache$spark$mllib$recommendation$ALS$$$default$7")
) ++
MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index ad2897d3fb46d..2d60a44f04f6f 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -59,6 +59,10 @@ object SparkBuild extends Build {
lazy val core = Project("core", file("core"), settings = coreSettings)
+  /** The following project only exists to pull previous artifacts of Spark for generating
+      Mima ignores. For more information see: SPARK-2071 */
+ lazy val oldDeps = Project("oldDeps", file("dev"), settings = oldDepsSettings)
+
def replDependencies = Seq[ProjectReference](core, graphx, bagel, mllib, sql) ++ maybeHiveRef
lazy val repl = Project("repl", file("repl"), settings = replSettings)
@@ -86,7 +90,16 @@ object SparkBuild extends Build {
lazy val assemblyProj = Project("assembly", file("assembly"), settings = assemblyProjSettings)
.dependsOn(core, graphx, bagel, mllib, streaming, repl, sql) dependsOn(maybeYarn: _*) dependsOn(maybeHive: _*) dependsOn(maybeGanglia: _*)
- lazy val assembleDeps = TaskKey[Unit]("assemble-deps", "Build assembly of dependencies and packages Spark projects")
+ lazy val assembleDepsTask = TaskKey[Unit]("assemble-deps")
+ lazy val assembleDeps = assembleDepsTask := {
+ println()
+ println("**** NOTE ****")
+ println("'sbt/sbt assemble-deps' is no longer supported.")
+ println("Instead create a normal assembly and:")
+ println(" export SPARK_PREPEND_CLASSES=1 (toggle on)")
+ println(" unset SPARK_PREPEND_CLASSES (toggle off)")
+ println()
+ }
// A configuration to set an alternative publishLocalConfiguration
lazy val MavenCompile = config("m2r") extend(Compile)
@@ -370,6 +383,7 @@ object SparkBuild extends Build {
"net.sf.py4j" % "py4j" % "0.8.1"
),
libraryDependencies ++= maybeAvro,
+ assembleDeps,
previousArtifact := sparkPreviousArtifact("spark-core")
)
@@ -581,9 +595,7 @@ object SparkBuild extends Build {
def assemblyProjSettings = sharedSettings ++ Seq(
name := "spark-assembly",
- assembleDeps in Compile <<= (packageProjects.map(packageBin in Compile in _) ++ Seq(packageDependency in Compile)).dependOn,
- jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" },
- jarName in packageDependency <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + "-deps.jar" }
+ jarName in assembly <<= version map { v => "spark-assembly-" + v + "-hadoop" + hadoopVersion + ".jar" }
) ++ assemblySettings ++ extraAssemblySettings
def extraAssemblySettings() = Seq(
@@ -599,6 +611,17 @@ object SparkBuild extends Build {
}
)
+ def oldDepsSettings() = Defaults.defaultSettings ++ Seq(
+ name := "old-deps",
+ scalaVersion := "2.10.4",
+ retrieveManaged := true,
+ retrievePattern := "[type]s/[artifact](-[revision])(-[classifier]).[ext]",
+ libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq",
+ "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter",
+ "spark-streaming", "spark-mllib", "spark-bagel", "spark-graphx",
+ "spark-core").map(sparkPreviousArtifact(_).get intransitive())
+ )
+
def twitterSettings() = sharedSettings ++ Seq(
name := "spark-streaming-twitter",
previousArtifact := sparkPreviousArtifact("spark-streaming-twitter"),
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index a411a5d5914e0..e609b60a0f968 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -454,7 +454,7 @@ def _squared_distance(v1, v2):
v2 = _convert_vector(v2)
if type(v1) == ndarray and type(v2) == ndarray:
diff = v1 - v2
- return diff.dot(diff)
+ return numpy.dot(diff, diff)
elif type(v1) == ndarray:
return v2.squared_distance(v1)
else:
@@ -469,10 +469,12 @@ def _dot(vec, target):
calling numpy.dot of the two vectors, but for SciPy ones, we
have to transpose them because they're column vectors.
"""
- if type(vec) == ndarray or type(vec) == SparseVector:
+ if type(vec) == ndarray:
+ return numpy.dot(vec, target)
+ elif type(vec) == SparseVector:
return vec.dot(target)
elif type(vec) == list:
- return _convert_vector(vec).dot(target)
+ return numpy.dot(_convert_vector(vec), target)
else:
return vec.transpose().dot(target)[0]
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index da364e19874b1..ddd22850a819c 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -736,7 +736,7 @@ def max(self):
def min(self):
"""
- Find the maximum item in this RDD.
+ Find the minimum item in this RDD.
>>> sc.parallelize([1.0, 5.0, 43.0, 10.0]).min()
1.0
@@ -1221,6 +1221,20 @@ def _mergeCombiners(iterator):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
+
+ def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
+ """
+        Aggregate the values of each key, using given combine functions and a neutral
+        "zero value". This function can return a different result type, U, than the type of
+        the values in this RDD, V. Thus, we need one operation for merging a V into a U and
+        one operation for merging two U's. The former operation is used for merging values
+        within a partition, and the latter is used for merging values between partitions.
+        To avoid memory allocation, both of these functions are allowed to modify and return
+        their first argument instead of creating a new U.
+ """
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: seqFunc(createZero(), v), seqFunc, combFunc, numPartitions)
def foldByKey(self, zeroValue, func, numPartitions=None):
"""
@@ -1234,7 +1248,10 @@ def foldByKey(self, zeroValue, func, numPartitions=None):
>>> rdd.foldByKey(0, add).collect()
[('a', 2), ('b', 1)]
"""
- return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions)
+ def createZero():
+ return copy.deepcopy(zeroValue)
+
+ return self.combineByKey(lambda v: func(createZero(), v), func, func, numPartitions)
# TODO: support variant with custom partitioner
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index b4e9618cc25b5..960d0a82448aa 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -117,7 +117,7 @@ def parquetFile(self, path):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile(parquetFile)
>>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
jschema_rdd = self._ssql_ctx.parquetFile(path)
@@ -141,7 +141,7 @@ def table(self, tableName):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> sqlCtx.registerRDDAsTable(srdd, "table1")
>>> srdd2 = sqlCtx.table("table1")
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
return SchemaRDD(self._ssql_ctx.table(tableName), self)
@@ -293,7 +293,7 @@ def saveAsParquetFile(self, path):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.saveAsParquetFile(parquetFile)
>>> srdd2 = sqlCtx.parquetFile(parquetFile)
- >>> srdd2.collect() == srdd.collect()
+ >>> sorted(srdd2.collect()) == sorted(srdd.collect())
True
"""
self._jschema_rdd.saveAsParquetFile(path)
@@ -307,7 +307,7 @@ def registerAsTable(self, name):
>>> srdd = sqlCtx.inferSchema(rdd)
>>> srdd.registerAsTable("test")
>>> srdd2 = sqlCtx.sql("select * from test")
- >>> srdd.collect() == srdd2.collect()
+ >>> sorted(srdd.collect()) == sorted(srdd2.collect())
True
"""
self._jschema_rdd.registerAsTable(name)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 184ee810b861b..c15bb457759ed 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -188,6 +188,21 @@ def test_deleting_input_files(self):
os.unlink(tempFile.name)
self.assertRaises(Exception, lambda: filtered_data.count())
+ def testAggregateByKey(self):
+ data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
+ def seqOp(x, y):
+ x.add(y)
+ return x
+
+ def combOp(x, y):
+ x |= y
+ return x
+
+ sets = dict(data.aggregateByKey(set(), seqOp, combOp).collect())
+ self.assertEqual(3, len(sets))
+ self.assertEqual(set([1]), sets[1])
+ self.assertEqual(set([2]), sets[3])
+ self.assertEqual(set([1, 3]), sets[5])
class TestIO(PySparkTestCase):
diff --git a/python/run-tests b/python/run-tests
index 36a96121cbc0d..9282aa47e8375 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -32,7 +32,8 @@ rm -f unit-tests.log
rm -rf metastore warehouse
function run_test() {
- SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log
+ echo "Running test: $1"
+ SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
# Fail and exit on the first test failure.
@@ -43,18 +44,23 @@ function run_test() {
echo -en "\033[0m" # No color
exit -1
fi
-
}
+echo "Running PySpark tests. Output is in python/unit-tests.log."
+
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
if [ -n "$_RUN_SQL_TESTS" ]; then
run_test "pyspark/sql.py"
fi
-run_test "-m doctest pyspark/broadcast.py"
-run_test "-m doctest pyspark/accumulators.py"
-run_test "-m doctest pyspark/serializers.py"
+# These tests are included in the module-level docs, and so must
+# be handled at a higher level rather than within the python file.
+export PYSPARK_DOC_TEST=1
+run_test "pyspark/broadcast.py"
+run_test "pyspark/accumulators.py"
+run_test "pyspark/serializers.py"
+unset PYSPARK_DOC_TEST
run_test "pyspark/tests.py"
run_test "pyspark/mllib/_common.py"
run_test "pyspark/mllib/classification.py"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 36758f3114e59..46fcfbb9e26ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -111,6 +111,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val APPROXIMATE = Keyword("APPROXIMATE")
protected val AVG = Keyword("AVG")
protected val BY = Keyword("BY")
+ protected val CACHE = Keyword("CACHE")
protected val CAST = Keyword("CAST")
protected val COUNT = Keyword("COUNT")
protected val DESC = Keyword("DESC")
@@ -149,7 +150,9 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
+ protected val TABLE = Keyword("TABLE")
protected val TRUE = Keyword("TRUE")
+ protected val UNCACHE = Keyword("UNCACHE")
protected val UNION = Keyword("UNION")
protected val WHERE = Keyword("WHERE")
@@ -189,7 +192,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
)
- | insert
+ | insert | cache
)
protected lazy val select: Parser[LogicalPlan] =
@@ -220,6 +223,11 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
InsertIntoTable(r, Map[String, Option[String]](), s, overwrite)
}
+ protected lazy val cache: Parser[LogicalPlan] =
+ (CACHE ^^^ true | UNCACHE ^^^ false) ~ TABLE ~ ident ^^ {
+ case doCache ~ _ ~ tableName => CacheCommand(tableName, doCache)
+ }
+
protected lazy val projections: Parser[Seq[Expression]] = repsep(projection, ",")
protected lazy val projection: Parser[Expression] =
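A small usage sketch for the new cache rule above, with hypothetical names: it assumes a SQLContext called sqlContext whose SQL dialect is this SqlParser, and a table registered as "records".

// The parser rule above turns these statements into CacheCommand(tableName, doCache).
sqlContext.sql("CACHE TABLE records")     // doCache = true
// ... run queries that benefit from the cached table ...
sqlContext.sql("UNCACHE TABLE records")   // doCache = false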
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 3cf163f9a9a75..d177339d40ae5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -175,6 +175,8 @@ package object dsl {
def where(condition: Expression) = Filter(condition, logicalPlan)
+ def limit(limitExpr: Expression) = Limit(limitExpr, logicalPlan)
+
def join(
otherPlan: LogicalPlan,
joinType: JoinType = Inner,
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 420303408451f..c074b7bb01e57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -76,7 +76,8 @@ trait CaseConversionExpression {
type EvaluatedType = Any
def convert(v: String): String
-
+
+ override def foldable: Boolean = child.foldable
def nullable: Boolean = child.nullable
def dataType: DataType = StringType
@@ -142,6 +143,8 @@ case class RLike(left: Expression, right: Expression)
case class Upper(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: String): String = v.toUpperCase()
+
+ override def toString() = s"Upper($child)"
}
/**
@@ -150,4 +153,6 @@ case class Upper(child: Expression) extends UnaryExpression with CaseConversionE
case class Lower(child: Expression) extends UnaryExpression with CaseConversionExpression {
override def convert(v: String): String = v.toLowerCase()
+
+ override def toString() = s"Lower($child)"
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 406ffd6801e98..25a347bec0e4c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,22 +19,29 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
+import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
object Optimizer extends RuleExecutor[LogicalPlan] {
val batches =
+ Batch("Combine Limits", FixedPoint(100),
+ CombineLimits) ::
Batch("ConstantFolding", FixedPoint(100),
NullPropagation,
ConstantFolding,
BooleanSimplification,
SimplifyFilters,
- SimplifyCasts) ::
+ SimplifyCasts,
+ SimplifyCaseConversionExpressions) ::
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughInnerJoin,
+ PushPredicateThroughJoin,
ColumnPruning) :: Nil
}
@@ -100,8 +107,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
object NullPropagation extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsUp {
- case e @ Count(Literal(null, _)) => Literal(0, e.dataType)
- case e @ Sum(Literal(c, _)) if c == 0 => Literal(0, e.dataType)
+ case e @ Count(Literal(null, _)) => Cast(Literal(0L), e.dataType)
+ case e @ Sum(Literal(c, _)) if c == 0 => Cast(Literal(0L), e.dataType)
case e @ Average(Literal(c, _)) if c == 0 => Literal(0.0, e.dataType)
case e @ IsNull(c) if c.nullable == false => Literal(false, BooleanType)
case e @ IsNotNull(c) if c.nullable == false => Literal(true, BooleanType)
@@ -126,18 +133,6 @@ object NullPropagation extends Rule[LogicalPlan] {
case Literal(candidate, _) if candidate == v => true
case _ => false
})) => Literal(true, BooleanType)
- case e: UnaryMinus => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
- case e: Cast => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
- case e: Not => e.child match {
- case Literal(null, _) => Literal(null, e.dataType)
- case _ => e
- }
// Put exceptional cases above if any
case e: BinaryArithmetic => e.children match {
case Literal(null, _) :: right :: Nil => Literal(null, e.dataType)
@@ -254,28 +249,98 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
/**
* Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be
- * evaluated using only the attributes of the left or right side of an inner join. Other
+ * evaluated using only the attributes of the left or right side of a join. Other
* [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the
* [[catalyst.plans.logical.Join Join]].
+ * It also pushes down the join filter when the `condition` can be evaluated using only the
+ * attributes of the left or right side of the corresponding subquery, where applicable.
+ *
+ * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details
*/
-object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+ // Split the condition expressions into three parts:
+ // (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
+ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
+ val (leftEvaluateCondition, rest) =
+ condition.partition(_.references subsetOf left.outputSet)
+ val (rightEvaluateCondition, commonCondition) =
+ rest.partition(_.references subsetOf right.outputSet)
+
+ (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) =>
- val allConditions =
- splitConjunctivePredicates(filterCondition) ++
- joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
-
- // Split the predicates into those that can be evaluated on the left, right, and those that
- // must be evaluated after the join.
- val (rightConditions, leftOrJoinConditions) =
- allConditions.partition(_.references subsetOf right.outputSet)
- val (leftConditions, joinConditions) =
- leftOrJoinConditions.partition(_.references subsetOf left.outputSet)
-
- // Build the new left and right side, optionally with the pushed down filters.
- val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
- val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
- Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And))
+ // push the where condition down into join filter
+ case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
+ val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
+ split(splitConjunctivePredicates(filterCondition), left, right)
+
+ joinType match {
+ case Inner =>
+ // push each single-side `where` condition down into its respective side
+ val newLeft = leftFilterConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightFilterConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, Inner, newJoinCond)
+ case RightOuter =>
+ // only the right-side `where` conditions can be pushed below a right outer join
+ val newLeft = left
+ val newRight = rightFilterConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = joinCondition
+ val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+
+ (leftFilterConditions ++ commonFilterCondition).
+ reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+ case _ @ (LeftOuter | LeftSemi) =>
+ // only the left-side `where` conditions can be pushed below a left outer or left semi join
+ val newLeft = leftFilterConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = right
+ val newJoinCond = joinCondition
+ val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+
+ (rightFilterConditions ++ commonFilterCondition).
+ reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+ case FullOuter => f // do nothing for a full outer join
+ }
+
+ // push parts of the join condition down into the subquery scans where applicable
+ case f @ Join(left, right, joinType, joinCondition) =>
+ val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
+ split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
+
+ joinType match {
+ case Inner =>
+ // push the single-side join conditions down into both subqueries
+ val newLeft = leftJoinConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightJoinConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = commonJoinCondition.reduceLeftOption(And)
+
+ Join(newLeft, newRight, Inner, newJoinCond)
+ case RightOuter =>
+ // push only the left-side join conditions down into the left subquery
+ val newLeft = leftJoinConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = right
+ val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, RightOuter, newJoinCond)
+ case _ @ (LeftOuter | LeftSemi) =>
+ // push only the right-side join conditions down into the right subquery
+ val newLeft = left
+ val newRight = rightJoinConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, joinType, newJoinCond)
+ case FullOuter => f
+ }
}
}
@@ -288,3 +353,29 @@ object SimplifyCasts extends Rule[LogicalPlan] {
case Cast(e, dataType) if e.dataType == dataType => e
}
}
+
+/**
+ * Combines two adjacent [[catalyst.plans.logical.Limit Limit]] operators into one, merging their
+ * limit expressions into a single expression.
+ */
+object CombineLimits extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case ll @ Limit(le, nl @ Limit(ne, grandChild)) =>
+ Limit(If(LessThan(ne, le), ne, le), grandChild)
+ }
+}
+
+/**
+ * Removes nested [[catalyst.expressions.CaseConversionExpression]]s that are unnecessary because
+ * the outer conversion overwrites the result of the inner one.
+ */
+object SimplifyCaseConversionExpressions extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsUp {
+ case Upper(Upper(child)) => Upper(child)
+ case Upper(Lower(child)) => Upper(child)
+ case Lower(Upper(child)) => Lower(child)
+ case Lower(Lower(child)) => Lower(child)
+ }
+ }
+}
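The heart of the new PushPredicateThroughJoin rule is the three-way split of conjuncts by which side of the join can evaluate them. A minimal standalone sketch of that partitioning, with plain Scala sets standing in for attribute references (not the Catalyst API):

// Each predicate is represented only by the set of attribute names it references.
case class Pred(text: String, references: Set[String])

// Mirrors the `split` helper: predicates whose references are contained in the
// left output can be pushed left, those contained in the right output can be
// pushed right, and the remainder must stay in the join condition.
def split(conds: Seq[Pred], leftOutput: Set[String], rightOutput: Set[String])
    : (Seq[Pred], Seq[Pred], Seq[Pred]) = {
  val (leftOnly, rest)    = conds.partition(_.references subsetOf leftOutput)
  val (rightOnly, common) = rest.partition(_.references subsetOf rightOutput)
  (leftOnly, rightOnly, common)
}

// With left output {x.a, x.b} and right output {y.a, y.b}:
val (l, r, both) = split(
  Seq(Pred("x.b = 1", Set("x.b")), Pred("y.b = 2", Set("y.b")), Pred("x.a = y.a", Set("x.a", "y.a"))),
  Set("x.a", "x.b"), Set("y.a", "y.b"))
// l holds the x.b predicate, r the y.b predicate, and both the x.a = y.a predicate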
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7eeb98aea6368..0933a31c362d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.{StringType, StructType}
+import org.apache.spark.sql.catalyst.types.StructType
import org.apache.spark.sql.catalyst.trees
abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
@@ -96,39 +96,6 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
def references = Set.empty
}
-/**
- * A logical node that represents a non-query command to be executed by the system. For example,
- * commands can be used by parsers to represent DDL operations.
- */
-abstract class Command extends LeafNode {
- self: Product =>
- def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this
-}
-
-/**
- * Returned for commands supported by a given parser, but not catalyst. In general these are DDL
- * commands that are passed directly to another system.
- */
-case class NativeCommand(cmd: String) extends Command
-
-/**
- * Commands of the form "SET (key) (= value)".
- */
-case class SetCommand(key: Option[String], value: Option[String]) extends Command {
- override def output = Seq(
- AttributeReference("key", StringType, nullable = false)(),
- AttributeReference("value", StringType, nullable = false)()
- )
-}
-
-/**
- * Returned by a parser when the users only wants to see what query plan would be executed, without
- * actually performing the execution.
- */
-case class ExplainCommand(plan: LogicalPlan) extends Command {
- override def output = Seq(AttributeReference("plan", StringType, nullable = false)())
-}
-
/**
* A logical plan node with single child.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index d3347b622f3d8..b777cf4249196 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -135,9 +135,9 @@ case class Aggregate(
def references = (groupingExpressions ++ aggregateExpressions).flatMap(_.references).toSet
}
-case class Limit(limit: Expression, child: LogicalPlan) extends UnaryNode {
+case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
def output = child.output
- def references = limit.references
+ def references = limitExpr.references
}
case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
new file mode 100644
index 0000000000000..d05c9652753e0
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute}
+import org.apache.spark.sql.catalyst.types.StringType
+
+/**
+ * A logical node that represents a non-query command to be executed by the system. For example,
+ * commands can be used by parsers to represent DDL operations.
+ */
+abstract class Command extends LeafNode {
+ self: Product =>
+ def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this
+}
+
+/**
+ * Returned for commands supported by a given parser but not by Catalyst. In general these are DDL
+ * commands that are passed directly to another system.
+ */
+case class NativeCommand(cmd: String) extends Command
+
+/**
+ * Commands of the form "SET (key) (= value)".
+ */
+case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+ override def output = Seq(
+ AttributeReference("key", StringType, nullable = false)(),
+ AttributeReference("value", StringType, nullable = false)()
+ )
+}
+
+/**
+ * Returned by a parser when the user only wants to see what query plan would be executed, without
+ * actually performing the execution.
+ */
+case class ExplainCommand(plan: LogicalPlan) extends Command {
+ override def output = Seq(AttributeReference("plan", StringType, nullable = false)())
+}
+
+/**
+ * Returned for the "CACHE TABLE tableName" and "UNCACHE TABLE tableName" commands.
+ */
+case class CacheCommand(tableName: String, doCache: Boolean) extends Command
+
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
new file mode 100644
index 0000000000000..714f01843c0f5
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+class CombiningLimitsSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Combine Limit", FixedPoint(2),
+ CombineLimits) ::
+ Batch("Constant Folding", FixedPoint(3),
+ NullPropagation,
+ ConstantFolding,
+ BooleanSimplification) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ test("limits: combines two limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(10)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(5).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("limits: combines three limits") {
+ val originalQuery =
+ testRelation
+ .select('a)
+ .limit(2)
+ .limit(7)
+ .limit(5)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index 20dfba847790c..6efc0e211eb21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.catalyst.types.{DoubleType, IntegerType}
+import org.apache.spark.sql.catalyst.types._
// For implicit conversions
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -173,4 +173,63 @@ class ConstantFoldingSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("Constant folding test: expressions have null literals") {
+ val originalQuery =
+ testRelation
+ .select(
+ IsNull(Literal(null)) as 'c1,
+ IsNotNull(Literal(null)) as 'c2,
+
+ GetItem(Literal(null, ArrayType(IntegerType)), 1) as 'c3,
+ GetItem(Literal(Seq(1), ArrayType(IntegerType)), Literal(null, IntegerType)) as 'c4,
+ GetField(
+ Literal(null, StructType(Seq(StructField("a", IntegerType, true)))),
+ "a") as 'c5,
+
+ UnaryMinus(Literal(null, IntegerType)) as 'c6,
+ Cast(Literal(null), IntegerType) as 'c7,
+ Not(Literal(null, BooleanType)) as 'c8,
+
+ Add(Literal(null, IntegerType), 1) as 'c9,
+ Add(1, Literal(null, IntegerType)) as 'c10,
+
+ Equals(Literal(null, IntegerType), 1) as 'c11,
+ Equals(1, Literal(null, IntegerType)) as 'c12,
+
+ Like(Literal(null, StringType), "abc") as 'c13,
+ Like("abc", Literal(null, StringType)) as 'c14,
+
+ Upper(Literal(null, StringType)) as 'c15)
+
+ val optimized = Optimize(originalQuery.analyze)
+
+ val correctAnswer =
+ testRelation
+ .select(
+ Literal(true) as 'c1,
+ Literal(false) as 'c2,
+
+ Literal(null, IntegerType) as 'c3,
+ Literal(null, IntegerType) as 'c4,
+ Literal(null, IntegerType) as 'c5,
+
+ Literal(null, IntegerType) as 'c6,
+ Literal(null, IntegerType) as 'c7,
+ Literal(null, BooleanType) as 'c8,
+
+ Literal(null, IntegerType) as 'c9,
+ Literal(null, IntegerType) as 'c10,
+
+ Literal(null, BooleanType) as 'c11,
+ Literal(null, BooleanType) as 'c12,
+
+ Literal(null, BooleanType) as 'c13,
+ Literal(null, BooleanType) as 'c14,
+
+ Literal(null, StringType) as 'c15)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ef47850455a37..0cada785b6630 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -20,9 +20,9 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.rules._
-
-/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
@@ -35,7 +35,7 @@ class FilterPushdownSuite extends OptimizerTest {
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughInnerJoin) :: Nil
+ PushPredicateThroughJoin) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -161,6 +161,184 @@ class FilterPushdownSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("joins: push down left outer join #1") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter)
+ .where("x.b".attr === 1 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 1)
+ val correctAnswer =
+ left.join(y, LeftOuter).where("y.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #1") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter)
+ .where("x.b".attr === 1 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ x.join(right, RightOuter).where("x.b".attr === 1).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #2") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("x.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #2") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #3") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+ left.join(right, LeftOuter).where("r.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #3") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #4") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+ left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #4") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.subquery('l)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ left.join(right, RightOuter, Some("r.b".attr === 1)).
+ where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #5") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+ left.join(right, LeftOuter, Some("l.a".attr===3)).
+ where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #5") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('a === 3).subquery('l)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ left.join(right, RightOuter, Some("r.b".attr === 1)).
+ where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
test("joins: can't push down") {
val x = testRelation.subquery('x)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
new file mode 100644
index 0000000000000..df1409fe7baee
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+/* Implicit conversions */
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+
+class SimplifyCaseConversionExpressionsSuite extends OptimizerTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("Simplify CaseConversionExpressions", Once,
+ SimplifyCaseConversionExpressions) :: Nil
+ }
+
+ val testRelation = LocalRelation('a.string)
+
+ test("simplify UPPER(UPPER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Upper(Upper('a)) as 'u)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select(Upper('a) as 'u)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify UPPER(LOWER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Upper(Lower('a)) as 'u)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select(Upper('a) as 'u)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify LOWER(UPPER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Lower(Upper('a)) as 'l)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Lower('a) as 'l)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("simplify LOWER(LOWER(str))") {
+ val originalQuery =
+ testRelation
+ .select(Lower(Lower('a)) as 'l)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer = testRelation
+ .select(Lower('a) as 'l)
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 021e0e8245a0d..264192ed1aa26 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -188,6 +188,15 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
}
+ /** Returns true if the table is currently cached in-memory. */
+ def isCached(tableName: String): Boolean = {
+ val relation = catalog.lookupRelation(None, tableName)
+ EliminateAnalysisOperators(relation) match {
+ case SparkLogicalPlan(_: InMemoryColumnarTableScan) => true
+ case _ => false
+ }
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 8855c4e876917..7ad8edf5a5a6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -178,14 +178,18 @@ class SchemaRDD(
def orderBy(sortExprs: SortOrder*): SchemaRDD =
new SchemaRDD(sqlContext, Sort(sortExprs, logicalPlan))
+ @deprecated("use limit with integer argument", "1.1.0")
+ def limit(limitExpr: Expression): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+
/**
- * Limits the results by the given expressions.
+ * Limits the results by the given integer.
* {{{
* schemaRDD.limit(10)
* }}}
*/
- def limit(limitExpr: Expression): SchemaRDD =
- new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan))
+ def limit(limitNum: Int): SchemaRDD =
+ new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan))
/**
* Performs a grouping followed by an aggregation.
@@ -374,6 +378,8 @@ class SchemaRDD(
override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
+ override def take(num: Int): Array[Row] = limit(num).collect()
+
// =======================================================================
// Base RDD functions that do NOT change schema
// =======================================================================
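Overriding take to go through limit means the row count is applied by the query planner rather than by the generic RDD take path. A toy standalone sketch of the delegation, using a hypothetical MiniRDD rather than SchemaRDD:

// Toy model: take(n) is expressed as limit(n).collect(), so the limit becomes
// part of the (mini) query plan instead of a driver-side partition scan.
class MiniRDD(data: Seq[Int]) {
  def limit(n: Int): MiniRDD = new MiniRDD(data.take(n))  // stands in for Limit(Literal(n), plan)
  def collect(): Array[Int] = data.toArray
  def take(n: Int): Array[Int] = limit(n).collect()
}

// new MiniRDD(1 to 100).take(3).toSeq == Seq(1, 2, 3)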
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 0455748d40eec..f2f95dfe27e69 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -239,10 +239,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.SetCommand(key, value) =>
Seq(execution.SetCommandPhysical(key, value, plan.output)(context))
case logical.ExplainCommand(child) =>
- val qe = context.executePlan(child)
- Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context))
+ val executedPlan = context.executePlan(child).executedPlan
+ Seq(execution.ExplainCommandPhysical(executedPlan, plan.output)(context))
+ case logical.CacheCommand(tableName, cache) =>
+ Seq(execution.CacheCommandPhysical(tableName, cache)(context))
case _ => Nil
}
}
-
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 9364506691f38..be26d19e66862 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -65,3 +65,26 @@ case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute])
override def otherCopyArgs = context :: Nil
}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class CacheCommandPhysical(tableName: String, doCache: Boolean)(@transient context: SQLContext)
+ extends LeafNode {
+
+ lazy val commandSideEffect = {
+ if (doCache) {
+ context.cacheTable(tableName)
+ } else {
+ context.uncacheTable(tableName)
+ }
+ }
+
+ override def execute(): RDD[Row] = {
+ commandSideEffect
+ context.emptyResult
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
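Keeping the cache/uncache call in a lazy val means the side effect runs at most once, no matter how many times execute() is invoked. A standalone sketch of that idiom (ordinary Scala, not the Spark classes):

import scala.collection.mutable

// Standalone model of the lazy-val side-effect idiom used by CacheCommandPhysical.
class CacheNode(table: String, doCache: Boolean, cached: mutable.Set[String]) {
  // Evaluated at most once, on first access.
  lazy val commandSideEffect: Unit =
    if (doCache) cached += table else cached -= table

  def execute(): Seq[Nothing] = {
    commandSideEffect   // first call performs the (un)caching; later calls are no-ops
    Seq.empty
  }
}

val cached = mutable.Set.empty[String]
val node = new CacheNode("src", doCache = true, cached)
node.execute(); node.execute()
// cached == Set("src"); the side effect ran exactly once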
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 88ff3d49a79b3..8d7a5ba59f96a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -169,7 +169,7 @@ case class LeftSemiJoinHash(
def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- val hashTable = new java.util.HashSet[Row]()
+ val hashSet = new java.util.HashSet[Row]()
var currentRow: Row = null
// Create a Hash set of buildKeys
@@ -177,43 +177,17 @@ case class LeftSemiJoinHash(
currentRow = buildIter.next()
val rowKey = buildSideKeyGenerator(currentRow)
if(!rowKey.anyNull) {
- val keyExists = hashTable.contains(rowKey)
+ val keyExists = hashSet.contains(rowKey)
if (!keyExists) {
- hashTable.add(rowKey)
+ hashSet.add(rowKey)
}
}
}
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatched: Boolean = false
-
- private[this] val joinKeys = streamSideKeyGenerator()
-
- override final def hasNext: Boolean =
- streamIter.hasNext && fetchNext()
-
- override final def next() = {
- currentStreamedRow
- }
-
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatched = false
- while (!currentHashMatched && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatched = hashTable.contains(joinKeys.currentValue)
- }
- }
- currentHashMatched
- }
- }
+ val joinKeys = streamSideKeyGenerator()
+ streamIter.filter(current => {
+ !joinKeys(current).anyNull && hashSet.contains(joinKeys.currentValue)
+ })
}
}
}
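The rewritten LeftSemiJoinHash drops the hand-rolled iterator in favor of a plain filter over the streamed side: build a hash set of non-null build-side keys, then keep each streamed row whose key is non-null and present in the set. A minimal standalone model of that logic, using ordinary Scala collections rather than SparkPlan:

import scala.collection.mutable

// Standalone model of a hash left-semi join: a streamed row is kept iff its
// key is non-null and appears on the build side.
def leftSemiJoin[K, R](buildKeys: Iterator[Option[K]], streamed: Iterator[R])
                      (streamKey: R => Option[K]): Iterator[R] = {
  val hashSet = mutable.HashSet.empty[K]
  buildKeys.foreach {
    case Some(k) => hashSet += k   // null (None) keys can never match, so skip them
    case None    =>
  }
  streamed.filter(row => streamKey(row).exists(hashSet.contains))
}

// leftSemiJoin(Iterator(Some(1), None, Some(2)),
//              Iterator(1 -> "a", 3 -> "c", 2 -> "b"))(r => Some(r._1)).toList
// == List((1, "a"), (2, "b"))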
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 0331f90272a99..ebca3adc2ff01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -70,4 +70,20 @@ class CachedTableSuite extends QueryTest {
TestSQLContext.sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key")
TestSQLContext.uncacheTable("testData")
}
+
+ test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") {
+ TestSQLContext.sql("CACHE TABLE testData")
+ TestSQLContext.table("testData").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => // Found evidence of caching
+ case _ => fail(s"Table 'testData' should be cached")
+ }
+ assert(TestSQLContext.isCached("testData"), "Table 'testData' should be cached")
+
+ TestSQLContext.sql("UNCACHE TABLE testData")
+ TestSQLContext.table("testData").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => fail(s"Table 'testData' should not be cached")
+ case _ => // Found evidence of uncaching
+ }
+ assert(!TestSQLContext.isCached("testData"), "Table 'testData' should not be cached")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 5eb73a4eff980..08293f7f0ca30 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -28,6 +28,7 @@ class SQLConfSuite extends QueryTest {
val testVal = "test.val.0"
test("programmatic ways of basic setting and getting") {
+ clear()
assert(getOption(testKey).isEmpty)
assert(getAll.toSet === Set())
@@ -48,6 +49,7 @@ class SQLConfSuite extends QueryTest {
}
test("parse SQL set commands") {
+ clear()
sql(s"set $testKey=$testVal")
assert(get(testKey, testVal + "_") == testVal)
assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index f2d850ad6aa56..c1fc99f077431 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -136,6 +136,12 @@ class SQLQuerySuite extends QueryTest {
2.0)
}
+ test("average overflow") {
+ checkAnswer(
+ sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"),
+ Seq((2147483645.0, 1), (2.0, 2)))
+ }
+
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
@@ -396,6 +402,7 @@ class SQLQuerySuite extends QueryTest {
sql(s"SET $nonexistentKey"),
Seq(Seq(s"$nonexistentKey is undefined"))
)
+ clear()
}
}
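The expected answer in the new "average overflow" test follows from simple arithmetic over the largeAndSmallInts fixture added in TestData.scala below: the three values in group b = 1 sum to more than Int.MaxValue, so the average must be computed with a wider accumulator. A quick check of the numbers:

// Group b = 1 from largeAndSmallInts: the Int sum would overflow,
// but the Long sum and the resulting average are exact.
val group1 = Seq(2147483644L, 2147483645L, 2147483646L)
val sum    = group1.sum                  // 6442450935, greater than Int.MaxValue (2147483647)
val avg    = sum.toDouble / group1.size  // 2.147483645E9, matching the expected row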
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 05de736bbce1b..330b20b315d63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -30,6 +30,17 @@ object TestData {
(1 to 100).map(i => TestData(i, i.toString)))
testData.registerAsTable("testData")
+ case class LargeAndSmallInts(a: Int, b: Int)
+ val largeAndSmallInts: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ LargeAndSmallInts(2147483644, 1) ::
+ LargeAndSmallInts(1, 2) ::
+ LargeAndSmallInts(2147483645, 1) ::
+ LargeAndSmallInts(2, 2) ::
+ LargeAndSmallInts(2147483646, 1) ::
+ LargeAndSmallInts(3, 2) :: Nil)
+ largeAndSmallInts.registerAsTable("largeAndSmallInts")
+
case class TestData2(a: Int, b: Int)
val testData2: SchemaRDD =
TestSQLContext.sparkContext.parallelize(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 4e74d9bc909fa..b745d8ffd8f17 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -218,15 +218,19 @@ private[hive] object HiveQl {
case Array(key, value) => // "set key=value"
SetCommand(Some(key), Some(value))
}
- } else if (sql.toLowerCase.startsWith("add jar")) {
+ } else if (sql.trim.toLowerCase.startsWith("cache table")) {
+ CacheCommand(sql.drop(12).trim, true)
+ } else if (sql.trim.toLowerCase.startsWith("uncache table")) {
+ CacheCommand(sql.drop(14).trim, false)
+ } else if (sql.trim.toLowerCase.startsWith("add jar")) {
AddJar(sql.drop(8))
- } else if (sql.toLowerCase.startsWith("add file")) {
+ } else if (sql.trim.toLowerCase.startsWith("add file")) {
AddFile(sql.drop(9))
- } else if (sql.startsWith("dfs")) {
+ } else if (sql.trim.startsWith("dfs")) {
DfsCommand(sql)
- } else if (sql.startsWith("source")) {
+ } else if (sql.trim.startsWith("source")) {
SourceCommand(sql.split(" ").toSeq match { case Seq("source", filePath) => filePath })
- } else if (sql.startsWith("!")) {
+ } else if (sql.trim.startsWith("!")) {
ShellCommand(sql.drop(1))
} else {
val tree = getAst(sql)
@@ -839,11 +843,11 @@ private[hive] object HiveQl {
case Token("TOK_FUNCTIONDI", Token(SUM(), Nil) :: arg :: Nil) => SumDistinct(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MAX(), Nil) :: arg :: Nil) => Max(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(MIN(), Nil) :: arg :: Nil) => Min(nodeToExpr(arg))
-
+
/* System functions about string operations */
case Token("TOK_FUNCTION", Token(UPPER(), Nil) :: arg :: Nil) => Upper(nodeToExpr(arg))
case Token("TOK_FUNCTION", Token(LOWER(), Nil) :: arg :: Nil) => Lower(nodeToExpr(arg))
-
+
/* Casts */
case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
Cast(nodeToExpr(arg), StringType)
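The CACHE TABLE / UNCACHE TABLE handling above is simple prefix dispatch on the raw SQL string. A standalone sketch of that style of dispatch, using a hypothetical parseCommand helper (it trims before dropping the prefix, which the real code handles slightly differently):

// Hypothetical helper illustrating prefix-based command dispatch.
sealed trait ParsedCommand
case class CacheCmd(table: String, doCache: Boolean) extends ParsedCommand
case class PassThrough(sql: String) extends ParsedCommand

def parseCommand(sql: String): ParsedCommand = {
  val trimmed = sql.trim
  val lower   = trimmed.toLowerCase
  if (lower.startsWith("cache table")) {
    CacheCmd(trimmed.drop("cache table".length).trim, doCache = true)
  } else if (lower.startsWith("uncache table")) {
    CacheCmd(trimmed.drop("uncache table".length).trim, doCache = false)
  } else {
    PassThrough(sql)
  }
}

// parseCommand("  CACHE TABLE src ") == CacheCmd("src", doCache = true)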
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
index 041e813598d1b..9386008d02d51 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
@@ -32,7 +32,7 @@ import org.apache.hadoop.hive.serde2.avro.AvroSerDe
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.catalyst.analysis._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, NativeCommand}
+import org.apache.spark.sql.catalyst.plans.logical.{CacheCommand, LogicalPlan, NativeCommand}
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.hive._
@@ -103,7 +103,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
val inRepoTests = if (System.getProperty("user.dir").endsWith("sql" + File.separator + "hive")) {
new File("src" + File.separator + "test" + File.separator + "resources" + File.separator)
} else {
- new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" +
+ new File("sql" + File.separator + "hive" + File.separator + "src" + File.separator + "test" +
File.separator + "resources")
}
@@ -130,6 +130,7 @@ class TestHiveContext(sc: SparkContext) extends LocalHiveContext(sc) {
override lazy val analyzed = {
val describedTables = logical match {
case NativeCommand(describedTable(tbl)) => tbl :: Nil
+ case CacheCommand(tbl, _) => tbl :: Nil
case _ => Nil
}
diff --git a/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c b/sql/hive/src/test/resources/golden/semijoin-0-1631b71327abf75b96116036b977b26c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c
new file mode 100644
index 0000000000000..2ed47ab83dd02
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-1-e7c99e1f46d502edbb0925d75aab5f0c
@@ -0,0 +1,11 @@
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_8
+9 val_9
+10 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6 b/sql/hive/src/test/resources/golden/semijoin-10-ffd4fb3a903a6725ccb97d5451a3fec6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330 b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330
new file mode 100644
index 0000000000000..a24bd8c6379e3
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-11-246a40dcafe077f02397e30d227c330
@@ -0,0 +1,8 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df b/sql/hive/src/test/resources/golden/semijoin-12-6d93a9d332ba490835b17f261a5467df
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f b/sql/hive/src/test/resources/golden/semijoin-13-18282d38b6efc0017089ab89b661764f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7 b/sql/hive/src/test/resources/golden/semijoin-14-19cfcefb10e1972bec0ffd421cd79de7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589 b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589
new file mode 100644
index 0000000000000..03c61a908b071
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-15-1de6eb3f357bd1c4d02ab4d19d43589
@@ -0,0 +1,11 @@
+val_0
+val_0
+val_0
+val_10
+val_2
+val_4
+val_5
+val_5
+val_5
+val_8
+val_9
diff --git a/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d b/sql/hive/src/test/resources/golden/semijoin-16-d3a72a90515ac4a8d8e9ac923bcda3d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019 b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019
new file mode 100644
index 0000000000000..2dcdfd1217ced
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-17-f0f8720cfd11fd71af17b310dc2d1019
@@ -0,0 +1,3 @@
+0 val_0
+0 val_0
+0 val_0
diff --git a/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765 b/sql/hive/src/test/resources/golden/semijoin-18-f7b2ce472443982e32d954cbb5c96765
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f
new file mode 100644
index 0000000000000..a3670515e8cc2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-19-b1f1c7f701abe81c01e72fb98f0bd13f
@@ -0,0 +1,3 @@
+val_10
+val_8
+val_9
diff --git a/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc b/sql/hive/src/test/resources/golden/semijoin-2-deb9c3286ae8e851b1fdb270085b16bc
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1 b/sql/hive/src/test/resources/golden/semijoin-20-b7a8ebaeb42b2eaba7d97cadc3fd96c1
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6 b/sql/hive/src/test/resources/golden/semijoin-21-480418a0646cf7260b494b9eb4821bb6
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41 b/sql/hive/src/test/resources/golden/semijoin-22-b6aebd98f7636cda7b24e0bf84d7ba41
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda
new file mode 100644
index 0000000000000..72bc6a6a88f6e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-23-ed730ccdf552c07e7a82ba6b7fd3fbda
@@ -0,0 +1,5 @@
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3 b/sql/hive/src/test/resources/golden/semijoin-24-d16b37134de78980b2bf96029e8265c3
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b
new file mode 100644
index 0000000000000..d89ea1757c712
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-25-be2bd011cc80b480b271a08dbf381e9b
@@ -0,0 +1,19 @@
+0
+0
+0
+0
+0
+0
+2
+4
+4
+5
+5
+5
+8
+8
+9
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25 b/sql/hive/src/test/resources/golden/semijoin-26-f1d3bab29f1ebafa148dbe3816e1da25
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d
new file mode 100644
index 0000000000000..dbbdae75a52a4
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-27-391c256254d171973f02e7f33672ce1d
@@ -0,0 +1,4 @@
+0 val_0
+0 val_0
+0 val_0
+8 val_8
diff --git a/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963 b/sql/hive/src/test/resources/golden/semijoin-28-b56400f6d9372f353cf7292a2182e963
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5 b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5
new file mode 100644
index 0000000000000..07c61afb5124b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-29-9efeef3d3c38e22a74d379978178c4f5
@@ -0,0 +1,14 @@
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+0 val_0 0 val_0
+4 val_4 4 val_2
+8 val_8 8 val_4
+10 val_10 10 val_5
+10 val_10 10 val_5
+10 val_10 10 val_5
diff --git a/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389 b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389
new file mode 100644
index 0000000000000..bf51e8f5d9eb5
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-3-b4d4317dd3a10e18502f20f5c5250389
@@ -0,0 +1,11 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_2
+8 val_4
+10 val_5
+10 val_5
+10 val_5
+16 val_8
+18 val_9
+20 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a b/sql/hive/src/test/resources/golden/semijoin-30-dd901d00fce5898b03a57cbc3028a70a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67 b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67
new file mode 100644
index 0000000000000..d6283e34d8ffc
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-31-e5dc4d8185e63e984aa4e3a2e08bd67
@@ -0,0 +1,14 @@
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_8
+9 val_9
+10 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39 b/sql/hive/src/test/resources/golden/semijoin-32-23017c7663f2710265a7e2a4a1606d39
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5 b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5
new file mode 100644
index 0000000000000..080180f9d0f0e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-33-ed499f94c6e6ac847ef5187b3b43bbc5
@@ -0,0 +1,14 @@
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490 b/sql/hive/src/test/resources/golden/semijoin-34-5e1b832090ab73c141c1167d5b25a490
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99 b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99
new file mode 100644
index 0000000000000..4a64d5c625790
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-35-8d5731f26232f6e26dd8012461b08d99
@@ -0,0 +1,26 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d b/sql/hive/src/test/resources/golden/semijoin-36-b1159823dca8025926407f8aa921238d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0 b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0
new file mode 100644
index 0000000000000..1420c786fb228
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-37-a15b9074f999ce836be5329591b968d0
@@ -0,0 +1,29 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba b/sql/hive/src/test/resources/golden/semijoin-38-f37547c73a48ce3ba089531b176e6ba
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8 b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8
new file mode 100644
index 0000000000000..1420c786fb228
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-39-c22a6f11368affcb80a9c80e653a47a8
@@ -0,0 +1,29 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0 b/sql/hive/src/test/resources/golden/semijoin-4-dfdad5a2742f93e8ea888191460809c0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63 b/sql/hive/src/test/resources/golden/semijoin-40-32071a51e2ba6e86b1c5e40de55aae63
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092 b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092
new file mode 100644
index 0000000000000..aef9483bb0bc9
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-41-cf74f73a33b1af8902b7970a9350b092
@@ -0,0 +1,29 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
+16
+18
+20
diff --git a/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0 b/sql/hive/src/test/resources/golden/semijoin-42-6b4257a74fca627785c967c99547f4c0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81 b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81
new file mode 100644
index 0000000000000..0bc413ef2e09e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-43-e8a166ac2e94bf8d1da0fe91b0db2c81
@@ -0,0 +1,31 @@
+NULL
+NULL
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea b/sql/hive/src/test/resources/golden/semijoin-44-945aaa3a24359ef73acab1e99500d5ea
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b
new file mode 100644
index 0000000000000..3131e64446f66
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-45-3fd94ffd4f1eb6cf83dcc064599bf12b
@@ -0,0 +1,42 @@
+NULL
+NULL
+NULL
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+2
+4
+4
+5
+5
+5
+8
+8
+9
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
diff --git a/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0 b/sql/hive/src/test/resources/golden/semijoin-46-620e01f81f6e5254b4bbe8fab4043ec0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c
new file mode 100644
index 0000000000000..ff30bedb81861
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-47-f0140e4ee92508ba241f91c157b7af9c
@@ -0,0 +1,35 @@
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+0
+4
+4
+8
+8
+10
+10
+10
+10
+10
+10
+10
+10
+10
+10
+16
+18
+20
diff --git a/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8 b/sql/hive/src/test/resources/golden/semijoin-48-8a04442e84f99a584c2882d0af8c25d8
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c b/sql/hive/src/test/resources/golden/semijoin-49-df1d6705d3624be72036318a6b42f04c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d
new file mode 100644
index 0000000000000..60f6eacee9b14
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-5-d3c2f84a12374b307c58a69aba4ec70d
@@ -0,0 +1,22 @@
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+0 val_0
+2 val_2
+4 val_2
+4 val_4
+5 val_5
+5 val_5
+5 val_5
+8 val_4
+8 val_8
+9 val_9
+10 val_10
+10 val_5
+10 val_5
+10 val_5
+16 val_8
+18 val_9
+20 val_10
diff --git a/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0 b/sql/hive/src/test/resources/golden/semijoin-6-90bb51b1330230d10a14fb7517457aa0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b b/sql/hive/src/test/resources/golden/semijoin-7-333d72e8bce6d11a35fc7a30418f225b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994 b/sql/hive/src/test/resources/golden/semijoin-8-d46607be851a6f4e27e98cbbefdee994
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2 b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2
new file mode 100644
index 0000000000000..5baaac9bebf6d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/semijoin-9-f7adaf0f77ce6ff8c3a4807f428d8de2
@@ -0,0 +1,6 @@
+0 val_0
+0 val_0
+0 val_0
+4 val_4
+8 val_8
+10 val_10
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index f9a162ef4e3c0..91ac03ca30cd7 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -56,4 +56,20 @@ class CachedTableSuite extends HiveComparisonTest {
TestHive.uncacheTable("src")
}
}
+
+ test("'CACHE TABLE' and 'UNCACHE TABLE' HiveQL statement") {
+ TestHive.hql("CACHE TABLE src")
+ TestHive.table("src").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => // Found evidence of caching
+ case _ => fail(s"Table 'src' should be cached")
+ }
+ assert(TestHive.isCached("src"), "Table 'src' should be cached")
+
+ TestHive.hql("UNCACHE TABLE src")
+ TestHive.table("src").queryExecution.executedPlan match {
+ case _: InMemoryColumnarTableScan => fail(s"Table 'src' should not be cached")
+ case _ => // Found evidence of uncaching
+ }
+ assert(!TestHive.isCached("src"), "Table 'src' should not be cached")
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index fb8f272d5abfe..3581617c269a6 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -597,6 +597,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"select_unquote_and",
"select_unquote_not",
"select_unquote_or",
+ "semijoin",
"serde_regex",
"serde_reported_schema",
"set_variable_sub",
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
index 86753360a07e4..a0aeacbc733bd 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/ContextWaiter.scala
@@ -27,6 +27,7 @@ private[streaming] class ContextWaiter {
}
def notifyStop() = synchronized {
+ stopped = true
notifyAll()
}
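
The one-line fix above sets the stopped flag before waking waiters; without it, a thread calling awaitTermination() after the context has already been stopped would block forever, because notifyAll() only wakes threads that are already waiting. A minimal standalone sketch of the guarded-wait pattern this restores (illustrative names, not the actual ContextWaiter source):

class StopWaiterSketch {
  private var stopped = false

  def notifyStop(): Unit = synchronized {
    stopped = true  // record the state change before waking anyone up
    notifyAll()
  }

  // Returns once notifyStop() has run, even if it ran before we started waiting.
  def waitForStop(): Unit = synchronized {
    while (!stopped) {
      wait(1000L)  // bounded wait; the flag is re-checked on every wakeup
    }
  }
}

The new StreamingContextSuite test below ("awaitTermination after stop") exercises exactly this ordering.
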
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
index 303d149d285e1..d9ac3c91f6e36 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/NetworkReceiverSuite.scala
@@ -29,7 +29,6 @@ import org.scalatest.FunSuite
import org.scalatest.concurrent.Timeouts
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
-import scala.language.postfixOps
/** Testsuite for testing the network receiver behavior */
class NetworkReceiverSuite extends FunSuite with Timeouts {
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index cd86019f63e7e..7b33d3b235466 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -223,6 +223,18 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
}
}
+ test("awaitTermination after stop") {
+ ssc = new StreamingContext(master, appName, batchDuration)
+ val inputStream = addInputStream(ssc)
+ inputStream.map(x => x).register()
+
+ failAfter(10000 millis) {
+ ssc.start()
+ ssc.stop()
+ ssc.awaitTermination()
+ }
+ }
+
test("awaitTermination with error in task") {
ssc = new StreamingContext(master, appName, batchDuration)
val inputStream = addInputStream(ssc)
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
index ef0efa552ceaf..2861f5335ae36 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala
@@ -27,12 +27,12 @@ import org.apache.spark.streaming.dstream.DStream
import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.streaming.scheduler._
-import org.scalatest.matchers.ShouldMatchers
+import org.scalatest.Matchers
import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._
import org.apache.spark.Logging
-class StreamingListenerSuite extends TestSuiteBase with ShouldMatchers {
+class StreamingListenerSuite extends TestSuiteBase with Matchers {
val input = (1 to 4).map(Seq(_)).toSeq
val operation = (d: DStream[Int]) => d.map(x => x)
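
The import swap reflects the move to ScalaTest 2.x, where org.scalatest.matchers.ShouldMatchers is deprecated in favor of org.scalatest.Matchers; the should-style DSL itself is unchanged. A small hedged sketch of the style (hypothetical suite, not part of this patch):

import org.scalatest.{FunSuite, Matchers}

class MatchersStyleSketch extends FunSuite with Matchers {
  test("should-style assertions read the same as before") {
    val batchCounts = Seq(1, 2, 3).map(_ * 2)
    batchCounts should have size 3
    batchCounts.sum should be (12)
  }
}
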
diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
index 6a261e19a35cd..03a73f92b275e 100644
--- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala
@@ -40,74 +40,78 @@ object GenerateMIMAIgnore {
private val classLoader = Thread.currentThread().getContextClassLoader
private val mirror = runtimeMirror(classLoader)
- private def classesPrivateWithin(packageName: String): Set[String] = {
+
+ private def isDeveloperApi(sym: unv.Symbol) =
+ sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi])
+
+ private def isExperimental(sym: unv.Symbol) =
+ sym.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.Experimental])
+
+
+ private def isPackagePrivate(sym: unv.Symbol) =
+ !sym.privateWithin.fullName.startsWith("<none>")
+
+ private def isPackagePrivateModule(moduleSymbol: unv.ModuleSymbol) =
+ !moduleSymbol.privateWithin.fullName.startsWith("<none>")
+
+ /**
+ * For every class checks via scala reflection if the class itself or contained members
+ * have DeveloperApi or Experimental annotations or they are package private.
+ * Returns the tuple of such classes and members.
+ */
+ private def privateWithin(packageName: String): (Set[String], Set[String]) = {
val classes = getClasses(packageName)
val ignoredClasses = mutable.HashSet[String]()
+ val ignoredMembers = mutable.HashSet[String]()
- def isPackagePrivate(className: String) = {
+ for (className <- classes) {
try {
- /* Couldn't figure out if it's possible to determine a-priori whether a given symbol
- is a module or class. */
-
- val privateAsClass = mirror
- .classSymbol(Class.forName(className, false, classLoader))
- .privateWithin
- .fullName
- .startsWith(packageName)
-
- val privateAsModule = mirror
- .staticModule(className)
- .privateWithin
- .fullName
- .startsWith(packageName)
-
- privateAsClass || privateAsModule
- } catch {
- case _: Throwable => {
- println("Error determining visibility: " + className)
- false
+ val classSymbol = mirror.classSymbol(Class.forName(className, false, classLoader))
+ val moduleSymbol = mirror.staticModule(className) // TODO: see if it is necessary.
+ val directlyPrivateSpark =
+ isPackagePrivate(classSymbol) || isPackagePrivateModule(moduleSymbol)
+ val developerApi = isDeveloperApi(classSymbol)
+ val experimental = isExperimental(classSymbol)
+
+ /* Inner classes defined within a private[spark] class or object are effectively
+ invisible, so we account for them as package private. */
+ lazy val indirectlyPrivateSpark = {
+ val maybeOuter = className.toString.takeWhile(_ != '$')
+ if (maybeOuter != className) {
+ isPackagePrivate(mirror.classSymbol(Class.forName(maybeOuter, false, classLoader))) ||
+ isPackagePrivateModule(mirror.staticModule(maybeOuter))
+ } else {
+ false
+ }
+ }
+ if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi || experimental) {
+ ignoredClasses += className
+ } else {
+ // check if this class has package-private/annotated members.
+ ignoredMembers ++= getAnnotatedOrPackagePrivateMembers(classSymbol)
}
- }
- }
- def isDeveloperApi(className: String) = {
- try {
- val clazz = mirror.classSymbol(Class.forName(className, false, classLoader))
- clazz.annotations.exists(_.tpe =:= unv.typeOf[org.apache.spark.annotation.DeveloperApi])
} catch {
- case _: Throwable => {
- println("Error determining Annotations: " + className)
- false
- }
+ case _: Throwable => println("Error instrumenting class:" + className)
}
}
+ (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet)
+ }
- for (className <- classes) {
- val directlyPrivateSpark = isPackagePrivate(className)
- val developerApi = isDeveloperApi(className)
-
- /* Inner classes defined within a private[spark] class or object are effectively
- invisible, so we account for them as package private. */
- val indirectlyPrivateSpark = {
- val maybeOuter = className.toString.takeWhile(_ != '$')
- if (maybeOuter != className) {
- isPackagePrivate(maybeOuter)
- } else {
- false
- }
- }
- if (directlyPrivateSpark || indirectlyPrivateSpark || developerApi) {
- ignoredClasses += className
- }
- }
- ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet
+ private def getAnnotatedOrPackagePrivateMembers(classSymbol: unv.ClassSymbol) = {
+ classSymbol.typeSignature.members
+ .filter(x => isPackagePrivate(x) || isDeveloperApi(x) || isExperimental(x)).map(_.fullName)
}
def main(args: Array[String]) {
- scala.tools.nsc.io.File(".generated-mima-excludes").
- writeAll(classesPrivateWithin("org.apache.spark").mkString("\n"))
- println("Created : .generated-mima-excludes in current directory.")
+ val (privateClasses, privateMembers) = privateWithin("org.apache.spark")
+ scala.tools.nsc.io.File(".generated-mima-class-excludes").
+ writeAll(privateClasses.mkString("\n"))
+ println("Created : .generated-mima-class-excludes in current directory.")
+ scala.tools.nsc.io.File(".generated-mima-member-excludes").
+ writeAll(privateMembers.mkString("\n"))
+ println("Created : .generated-mima-member-excludes in current directory.")
}
@@ -140,10 +144,17 @@ object GenerateMIMAIgnore {
* Get all classes in a package from a jar file.
*/
private def getClassesFromJar(jarPath: String, packageName: String) = {
+ import scala.collection.mutable
val jar = new JarFile(new File(jarPath))
val enums = jar.entries().map(_.getName).filter(_.startsWith(packageName))
- val classes = for (entry <- enums if entry.endsWith(".class"))
- yield Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader)
+ val classes = mutable.HashSet[Class[_]]()
+ for (entry <- enums if entry.endsWith(".class")) {
+ try {
+ classes += Class.forName(entry.replace('/', '.').stripSuffix(".class"), false, classLoader)
+ } catch {
+ case _: Throwable => println("Unable to load:" + entry)
+ }
+ }
classes
}
}
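
The generator relies on Scala 2.10 runtime reflection for all of its checks: qualified-private visibility comes from Symbol.privateWithin (an unqualified symbol yields NoSymbol, whose fullName is "<none>", which is what the startsWith("<none>") tests above key on), and the DeveloperApi/Experimental checks compare annotation types. A standalone, hedged sketch of those two checks, usable against any class on the classpath (object and method names here are illustrative):

import scala.reflect.runtime.{universe => unv}

object VisibilityCheckSketch {
  private val classLoader = Thread.currentThread().getContextClassLoader
  private val mirror = unv.runtimeMirror(classLoader)

  // True if the symbol carries the given annotation type (the generator uses this
  // with DeveloperApi and Experimental).
  def hasAnnotation[T: unv.TypeTag](sym: unv.Symbol): Boolean =
    sym.annotations.exists(_.tpe =:= unv.typeOf[T])

  // True if the symbol is qualified-private (e.g. private[spark]); an unqualified
  // symbol's privateWithin is NoSymbol, whose fullName is "<none>".
  def isQualifiedPrivate(sym: unv.Symbol): Boolean =
    sym.privateWithin != unv.NoSymbol

  def main(args: Array[String]) {
    val className = args.headOption.getOrElse("scala.collection.immutable.List")
    val sym = mirror.classSymbol(Class.forName(className, false, classLoader))
    println(className + ": deprecated=" + hasAnnotation[deprecated](sym) +
      ", qualifiedPrivate=" + isQualifiedPrivate(sym))
  }
}

Pass a fully qualified class name as the first argument to inspect an arbitrary class.
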
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 8f0ecb855718e..1cc9c33cd2d02 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -277,7 +277,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
yarnAllocator.allocateContainers(
math.max(args.numExecutors - yarnAllocator.getNumExecutorsRunning, 0))
ApplicationMaster.incrementAllocatorLoop(1)
- Thread.sleep(100)
+ Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL)
}
} finally {
// In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
@@ -416,6 +416,7 @@ object ApplicationMaster {
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ private val ALLOCATE_HEARTBEAT_INTERVAL = 100
def incrementAllocatorLoop(by: Int) {
val count = yarnAllocatorLoop.getAndAdd(by)
@@ -467,13 +468,22 @@ object ApplicationMaster {
})
}
- // Wait for initialization to complete and atleast 'some' nodes can get allocated.
+ modified
+ }
+
+
+ /**
+ * Returns when we've either
+ * 1) received all the requested executors,
+ * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms,
+ * 3) hit an error that causes us to terminate trying to get containers.
+ */
+ def waitForInitialAllocations() {
yarnAllocatorLoop.synchronized {
while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
yarnAllocatorLoop.wait(1000L)
}
}
- modified
}
def main(argStrings: Array[String]) {
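
In the alpha ApplicationMaster the initial-allocation handshake stays counter based: the allocation loop sleeps ALLOCATE_HEARTBEAT_INTERVAL (100 ms) per iteration and bumps a shared counter, and waitForInitialAllocations() returns once the counter passes ALLOCATOR_LOOP_WAIT_COUNT (30), i.e. after roughly 3 seconds of attempts even if no executor ever registers. A hedged, standalone sketch of that pattern (names illustrative, not the AM source):

import java.util.concurrent.atomic.AtomicInteger

object AllocatorLoopGateSketch {
  private val ALLOCATOR_LOOP_WAIT_COUNT = 30

  private val loopCount = new AtomicInteger(0)

  // Called from the allocation loop once per ALLOCATE_HEARTBEAT_INTERVAL (100 ms).
  def increment(by: Int) {
    val count = loopCount.getAndAdd(by)
    if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
      loopCount.synchronized { loopCount.notifyAll() }  // wake any waiter
    }
  }

  // Blocks until the counter passes the threshold, i.e. roughly 30 * 100 ms = 3 s
  // of allocation attempts, even if no executor ever registers.
  def awaitInitialAllocations() {
    loopCount.synchronized {
      while (loopCount.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
        loopCount.wait(1000L)  // bounded wait guards against a missed notify
      }
    }
  }
}
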
diff --git a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index a3bd91590fc25..b6ecae1e652fe 100644
--- a/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/alpha/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -271,6 +271,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
.asInstanceOf[FinishApplicationMasterRequest]
finishReq.setAppAttemptId(appAttemptId)
finishReq.setFinishApplicationStatus(status)
+ finishReq.setTrackingUrl(sparkConf.get("spark.yarn.historyServer.address", ""))
resourceManager.finishApplicationMaster(finishReq)
}
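
Both ExecutorLauncher variants now pass spark.yarn.historyServer.address (defaulting to an empty string) as the tracking URL when unregistering, so the YARN UI can keep linking to a history server after the application finishes. A hedged usage sketch of setting that property (host and port are made up):

import org.apache.spark.SparkConf

object HistoryServerTrackingUrlSketch {
  def main(args: Array[String]) {
    // The address below is illustrative; point it at a real history server.
    val conf = new SparkConf()
      .setAppName("yarn-history-demo")
      .set("spark.yarn.historyServer.address", "historyhost:18080")
    println(conf.get("spark.yarn.historyServer.address"))
  }
}
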
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index 801e8b381588f..29a35680c0e72 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy.yarn
import java.io.File
import java.net.{InetAddress, UnknownHostException, URI, URISyntaxException}
-import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, ListBuffer, Map}
@@ -37,7 +36,7 @@ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.{Apps, Records}
+import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SparkConf, SparkContext}
/**
@@ -169,14 +168,13 @@ trait ClientBase extends Logging {
destPath
}
- def qualifyForLocal(localURI: URI): Path = {
+ private def qualifyForLocal(localURI: URI): Path = {
var qualifiedURI = localURI
- // If not specified assume these are in the local filesystem to keep behavior like Hadoop
+ // If not specified, assume these are in the local filesystem to keep behavior like Hadoop
if (qualifiedURI.getScheme() == null) {
qualifiedURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(qualifiedURI)).toString)
}
- val qualPath = new Path(qualifiedURI)
- qualPath
+ new Path(qualifiedURI)
}
def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = {
@@ -305,13 +303,13 @@ trait ClientBase extends Logging {
val amMemory = calculateAMMemory(newApp)
- val JAVA_OPTS = ListBuffer[String]()
+ val javaOpts = ListBuffer[String]()
// Add Xmx for AM memory
- JAVA_OPTS += "-Xmx" + amMemory + "m"
+ javaOpts += "-Xmx" + amMemory + "m"
val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
- JAVA_OPTS += "-Djava.io.tmpdir=" + tmpDir
+ javaOpts += "-Djava.io.tmpdir=" + tmpDir
// TODO: Remove once cpuset version is pushed out.
// The context is, default gc for server class machines ends up using all cores to do gc -
@@ -325,11 +323,11 @@ trait ClientBase extends Logging {
if (useConcurrentAndIncrementalGC) {
// In our expts, using (default) throughput collector has severe perf ramifications in
// multi-tenant machines
- JAVA_OPTS += "-XX:+UseConcMarkSweepGC"
- JAVA_OPTS += "-XX:+CMSIncrementalMode"
- JAVA_OPTS += "-XX:+CMSIncrementalPacing"
- JAVA_OPTS += "-XX:CMSIncrementalDutyCycleMin=0"
- JAVA_OPTS += "-XX:CMSIncrementalDutyCycle=10"
+ javaOpts += "-XX:+UseConcMarkSweepGC"
+ javaOpts += "-XX:+CMSIncrementalMode"
+ javaOpts += "-XX:+CMSIncrementalPacing"
+ javaOpts += "-XX:CMSIncrementalDutyCycleMin=0"
+ javaOpts += "-XX:CMSIncrementalDutyCycle=10"
}
// SPARK_JAVA_OPTS is deprecated, but for backwards compatibility:
@@ -344,22 +342,22 @@ trait ClientBase extends Logging {
// If we are being launched in client mode, forward the spark-conf options
// onto the executor launcher
for ((k, v) <- sparkConf.getAll) {
- JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\""
+ javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
}
} else {
// If we are being launched in standalone mode, capture and forward any spark
// system properties (e.g. set by spark-class).
for ((k, v) <- sys.props.filterKeys(_.startsWith("spark"))) {
- JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\""
+ javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\""
}
- sys.props.get("spark.driver.extraJavaOptions").foreach(opts => JAVA_OPTS += opts)
- sys.props.get("spark.driver.libraryPath").foreach(p => JAVA_OPTS += s"-Djava.library.path=$p")
+ sys.props.get("spark.driver.extraJavaOptions").foreach(opts => javaOpts += opts)
+ sys.props.get("spark.driver.libraryPath").foreach(p => javaOpts += s"-Djava.library.path=$p")
}
- JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources)
+ javaOpts += ClientBase.getLog4jConfiguration(localResources)
// Command for the ApplicationMaster
val commands = Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++
- JAVA_OPTS ++
+ javaOpts ++
Seq(args.amClass, "--class", args.userClass, "--jar ", args.userJar,
userArgsToString(args),
"--executor-memory", args.executorMemory.toString,
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
index 32f8861dc9503..43dbb2464f929 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnableUtil.scala
@@ -28,7 +28,7 @@ import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records}
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
import org.apache.spark.{Logging, SparkConf}
@@ -46,19 +46,19 @@ trait ExecutorRunnableUtil extends Logging {
executorCores: Int,
localResources: HashMap[String, LocalResource]): List[String] = {
// Extra options for the JVM
- val JAVA_OPTS = ListBuffer[String]()
+ val javaOpts = ListBuffer[String]()
// Set the JVM memory
val executorMemoryString = executorMemory + "m"
- JAVA_OPTS += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
+ javaOpts += "-Xms" + executorMemoryString + " -Xmx" + executorMemoryString + " "
// Set extra Java options for the executor, if defined
sys.props.get("spark.executor.extraJavaOptions").foreach { opts =>
- JAVA_OPTS += opts
+ javaOpts += opts
}
- JAVA_OPTS += "-Djava.io.tmpdir=" +
+ javaOpts += "-Djava.io.tmpdir=" +
new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR)
- JAVA_OPTS += ClientBase.getLog4jConfiguration(localResources)
+ javaOpts += ClientBase.getLog4jConfiguration(localResources)
// Certain configs need to be passed here because they are needed before the Executor
// registers with the Scheduler and transfers the spark configs. Since the Executor backend
@@ -66,10 +66,10 @@ trait ExecutorRunnableUtil extends Logging {
// authentication settings.
sparkConf.getAll.
filter { case (k, v) => k.startsWith("spark.auth") || k.startsWith("spark.akka") }.
- foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
sparkConf.getAkkaConf.
- foreach { case (k, v) => JAVA_OPTS += "-D" + k + "=" + "\\\"" + v + "\\\"" }
+ foreach { case (k, v) => javaOpts += "-D" + k + "=" + "\\\"" + v + "\\\"" }
// Commenting it out for now - so that people can refer to the properties if required. Remove
// it once cpuset version is pushed out.
@@ -88,11 +88,11 @@ trait ExecutorRunnableUtil extends Logging {
// multi-tennent machines
// The options are based on
// http://www.oracle.com/technetwork/java/gc-tuning-5-138395.html#0.0.0.%20When%20to%20Use%20the%20Concurrent%20Low%20Pause%20Collector|outline
- JAVA_OPTS += " -XX:+UseConcMarkSweepGC "
- JAVA_OPTS += " -XX:+CMSIncrementalMode "
- JAVA_OPTS += " -XX:+CMSIncrementalPacing "
- JAVA_OPTS += " -XX:CMSIncrementalDutyCycleMin=0 "
- JAVA_OPTS += " -XX:CMSIncrementalDutyCycle=10 "
+ javaOpts += " -XX:+UseConcMarkSweepGC "
+ javaOpts += " -XX:+CMSIncrementalMode "
+ javaOpts += " -XX:+CMSIncrementalPacing "
+ javaOpts += " -XX:CMSIncrementalDutyCycleMin=0 "
+ javaOpts += " -XX:CMSIncrementalDutyCycle=10 "
}
*/
@@ -104,7 +104,7 @@ trait ExecutorRunnableUtil extends Logging {
// TODO: If the OOM is not recoverable by rescheduling it on different node, then do
// 'something' to fail job ... akin to blacklisting trackers in mapred ?
"-XX:OnOutOfMemoryError='kill %p'") ++
- JAVA_OPTS ++
+ javaOpts ++
Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
masterAddress.toString,
slaveId.toString,
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index a4638cc863611..39cdd2e8a522b 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -33,10 +33,11 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
def this(sc: SparkContext) = this(sc, new Configuration())
- // Nothing else for now ... initialize application master : which needs sparkContext to determine how to allocate
- // Note that only the first creation of SparkContext influences (and ideally, there must be only one SparkContext, right ?)
- // Subsequent creations are ignored - since nodes are already allocated by then.
-
+ // Nothing else for now ... initialize application master : which needs a SparkContext to
+ // determine how to allocate.
+ // Note that only the first creation of a SparkContext influences (and ideally, there must be
+ // only one SparkContext, right ?). Subsequent creations are ignored since executors are already
+ // allocated by then.
// By default, rack is unknown
override def getRackForHost(hostPort: String): Option[String] = {
@@ -48,6 +49,7 @@ private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
override def postStartHook() {
val sparkContextInitialized = ApplicationMaster.sparkContextInitialized(sc)
if (sparkContextInitialized){
+ ApplicationMaster.waitForInitialAllocations()
// Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
Thread.sleep(3000L)
}
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 33a60d978c586..6244332f23737 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -19,13 +19,12 @@ package org.apache.spark.deploy.yarn
import java.io.IOException
import java.util.concurrent.CopyOnWriteArrayList
-import java.util.concurrent.atomic.{AtomicInteger, AtomicReference}
+import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConversions._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.hadoop.net.NetUtils
import org.apache.hadoop.util.ShutdownHookManager
import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.protocolrecords._
@@ -33,8 +32,7 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.AMRMClient
import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
import org.apache.hadoop.yarn.conf.YarnConfiguration
-import org.apache.hadoop.yarn.ipc.YarnRPC
-import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import org.apache.hadoop.yarn.util.ConverterUtils
import org.apache.hadoop.yarn.webapp.util.WebAppUtils
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
@@ -77,17 +75,18 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// than user specified and /tmp.
System.setProperty("spark.local.dir", getLocalDirs())
- // set the web ui port to be ephemeral for yarn so we don't conflict with
+ // Set the web ui port to be ephemeral for yarn so we don't conflict with
// other spark processes running on the same box
System.setProperty("spark.ui.port", "0")
- // when running the AM, the Spark master is always "yarn-cluster"
+ // When running the AM, the Spark master is always "yarn-cluster"
System.setProperty("spark.master", "yarn-cluster")
- // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using.
+ // Use priority 30 as it's higher than HDFS. It's the same priority MapReduce is using.
ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30)
- appAttemptId = getApplicationAttemptId()
+ appAttemptId = ApplicationMaster.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts
amClient = AMRMClient.createAMRMClient()
amClient.init(yarnConf)
@@ -99,7 +98,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
ApplicationMaster.register(this)
// Call this to force generation of secret so it gets populated into the
- // hadoop UGI. This has to happen before the startUserClass which does a
+ // Hadoop UGI. This has to happen before the startUserClass which does a
// doAs in order for the credentials to be passed on to the executor containers.
val securityMgr = new SecurityManager(sparkConf)
@@ -121,7 +120,10 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
// Allocate all containers
allocateExecutors()
- // Wait for the user class to Finish
+ // Launch thread that will heartbeat to the RM so it won't think the app has died.
+ launchReporterThread()
+
+ // Wait for the user class to finish
userThread.join()
System.exit(0)
@@ -141,7 +143,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
"spark.org.apache.hadoop.yarn.server.webproxy.amfilter.AmIpFilter.params", params)
}
- /** Get the Yarn approved local directories. */
+ // Get the Yarn approved local directories.
private def getLocalDirs(): String = {
// Hadoop 0.23 and 2.x have different Environment variable names for the
// local dirs, so lets check both. We assume one of the 2 is set.
@@ -150,18 +152,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
.orElse(Option(System.getenv("LOCAL_DIRS")))
localDirs match {
- case None => throw new Exception("Yarn Local dirs can't be empty")
+ case None => throw new Exception("Yarn local dirs can't be empty")
case Some(l) => l
}
- }
-
- private def getApplicationAttemptId(): ApplicationAttemptId = {
- val envs = System.getenv()
- val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
- val containerId = ConverterUtils.toContainerId(containerIdString)
- val appAttemptId = containerId.getApplicationAttemptId()
- logInfo("ApplicationAttemptId: " + appAttemptId)
- appAttemptId
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
@@ -173,25 +166,23 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
logInfo("Starting the user JAR in a separate Thread")
val mainMethod = Class.forName(
args.userClass,
- false /* initialize */ ,
+ false,
Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
val t = new Thread {
override def run() {
-
- var successed = false
+ var succeeded = false
try {
// Copy
- var mainArgs: Array[String] = new Array[String](args.userArgs.size)
+ val mainArgs = new Array[String](args.userArgs.size)
args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size)
mainMethod.invoke(null, mainArgs)
- // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR
- // userThread will stop here unless it has uncaught exception thrown out
- // It need shutdown hook to set SUCCEEDED
- successed = true
+ // Some apps have "System.exit(0)" at the end. The user thread will stop here unless
+ // it has an uncaught exception thrown out. It needs a shutdown hook to set SUCCEEDED.
+ succeeded = true
} finally {
- logDebug("finishing main")
+ logDebug("Finishing main")
isLastAMRetry = true
- if (successed) {
+ if (succeeded) {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
} else {
ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.FAILED)
@@ -199,11 +190,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
}
+ t.setName("Driver")
t.start()
t
}
- // This need to happen before allocateExecutors()
+ // This needs to happen before allocateExecutors()
private def waitForSparkContextInitialized() {
logInfo("Waiting for Spark context initialization")
try {
@@ -231,7 +223,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
sparkContext.preferredNodeLocationData,
sparkContext.getConf)
} else {
- logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d".
+ logWarning("Unable to retrieve SparkContext in spite of waiting for %d, maxNumTries = %d".
format(numTries * waitTime, maxNumTries))
this.yarnAllocator = YarnAllocationHandler.newAllocator(
yarnConf,
@@ -242,48 +234,37 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
} finally {
- // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT :
- // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks.
- ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ // In case of exceptions, etc - ensure that the loop in
+ // ApplicationMaster#sparkContextInitialized() breaks.
+ ApplicationMaster.doneWithInitialAllocations()
}
}
private def allocateExecutors() {
try {
- logInfo("Allocating " + args.numExecutors + " executors.")
- // Wait until all containers have finished
+ logInfo("Requesting" + args.numExecutors + " executors.")
+ // Wait until all containers have launched
yarnAllocator.addResourceRequests(args.numExecutors)
yarnAllocator.allocateResources()
// Exits the loop if the user thread exits.
+
+ var iters = 0
while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) {
checkNumExecutorsFailed()
allocateMissingExecutor()
yarnAllocator.allocateResources()
- ApplicationMaster.incrementAllocatorLoop(1)
- Thread.sleep(100)
+ if (iters == ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) {
+ ApplicationMaster.doneWithInitialAllocations()
+ }
+ Thread.sleep(ApplicationMaster.ALLOCATE_HEARTBEAT_INTERVAL)
+ iters += 1
}
} finally {
- // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT,
- // so that the loop in ApplicationMaster#sparkContextInitialized() breaks.
- ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT)
+ // In case of exceptions, etc - ensure that the loop in
+ // ApplicationMaster#sparkContextInitialized() breaks.
+ ApplicationMaster.doneWithInitialAllocations()
}
logInfo("All executors have launched.")
-
- // Launch a progress reporter thread, else the app will get killed after expiration
- // (def: 10mins) timeout.
- if (userThread.isAlive) {
- // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
- val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
-
- // we want to be reasonably responsive without causing too many requests to RM.
- val schedulerInterval =
- sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
-
- // must be <= timeoutInterval / 2.
- val interval = math.min(timeoutInterval / 2, schedulerInterval)
-
- launchReporterThread(interval)
- }
}
private def allocateMissingExecutor() {
@@ -303,47 +284,35 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
- private def launchReporterThread(_sleepTime: Long): Thread = {
- val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
+ private def launchReporterThread(): Thread = {
+ // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses.
+ val expiryInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+
+ // we want to be reasonably responsive without causing too many requests to RM.
+ val schedulerInterval =
+ sparkConf.getLong("spark.yarn.scheduler.heartbeat.interval-ms", 5000)
+
+ // must be <= expiryInterval / 2.
+ val interval = math.max(0, math.min(expiryInterval / 2, schedulerInterval))
val t = new Thread {
override def run() {
while (userThread.isAlive) {
checkNumExecutorsFailed()
allocateMissingExecutor()
- sendProgress()
- Thread.sleep(sleepTime)
+ logDebug("Sending progress")
+ yarnAllocator.allocateResources()
+ Thread.sleep(interval)
}
}
}
// Setting to daemon status, though this is usually not a good idea.
t.setDaemon(true)
t.start()
- logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ logInfo("Started progress reporter thread - heartbeat interval : " + interval)
t
}
- private def sendProgress() {
- logDebug("Sending progress")
- // Simulated with an allocate request with no nodes requested.
- yarnAllocator.allocateResources()
- }
-
- /*
- def printContainers(containers: List[Container]) = {
- for (container <- containers) {
- logInfo("Launching shell command on a new container."
- + ", containerId=" + container.getId()
- + ", containerNode=" + container.getNodeId().getHost()
- + ":" + container.getNodeId().getPort()
- + ", containerNodeURI=" + container.getNodeHttpAddress()
- + ", containerState" + container.getState()
- + ", containerResourceMemory"
- + container.getResource().getMemory())
- }
- }
- */
-
def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") {
synchronized {
if (isFinished) {
@@ -351,7 +320,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
isFinished = true
- logInfo("finishApplicationMaster with " + status)
+ logInfo("Unregistering ApplicationMaster with " + status)
if (registered) {
val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
amClient.unregisterApplicationMaster(status, diagnostics, trackingUrl)
@@ -386,7 +355,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
def run() {
logInfo("AppMaster received a signal.")
- // we need to clean up staging dir before HDFS is shut down
+ // We need to clean up staging dir before HDFS is shut down
// make sure we don't delete it until this is the last AM
if (appMaster.isLastAMRetry) appMaster.cleanupStagingDir()
}
@@ -401,21 +370,24 @@ object ApplicationMaster {
// TODO: Currently, task to container is computed once (TaskSetManager) - which need not be
// optimal as more containers are available. Might need to handle this better.
private val ALLOCATOR_LOOP_WAIT_COUNT = 30
+ private val ALLOCATE_HEARTBEAT_INTERVAL = 100
private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]()
val sparkContextRef: AtomicReference[SparkContext] =
- new AtomicReference[SparkContext](null /* initialValue */)
+ new AtomicReference[SparkContext](null)
- val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0)
+ // Variable used to notify the YarnClusterScheduler that it should stop waiting
+ // for the initial set of executors to be started and get on with its business.
+ val doneWithInitialAllocationsMonitor = new Object()
- def incrementAllocatorLoop(by: Int) {
- val count = yarnAllocatorLoop.getAndAdd(by)
- if (count >= ALLOCATOR_LOOP_WAIT_COUNT) {
- yarnAllocatorLoop.synchronized {
- // to wake threads off wait ...
- yarnAllocatorLoop.notifyAll()
- }
+ @volatile var isDoneWithInitialAllocations = false
+
+ def doneWithInitialAllocations() {
+ isDoneWithInitialAllocations = true
+ doneWithInitialAllocationsMonitor.synchronized {
+ // to wake threads off wait ...
+ doneWithInitialAllocationsMonitor.notifyAll()
}
}
@@ -423,7 +395,10 @@ object ApplicationMaster {
applicationMasters.add(master)
}
- // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm...
+ /**
+ * Called from YarnClusterScheduler to notify the AM code that a SparkContext has been
+ * initialized in the user code.
+ */
def sparkContextInitialized(sc: SparkContext): Boolean = {
var modified = false
sparkContextRef.synchronized {
@@ -431,7 +406,7 @@ object ApplicationMaster {
sparkContextRef.notifyAll()
}
- // Add a shutdown hook - as a best case effort in case users do not call sc.stop or do
+ // Add a shutdown hook - as a best effort in case users do not call sc.stop or do
// System.exit.
// Should not really have to do this, but it helps YARN to evict resources earlier.
// Not to mention, prevent the Client from declaring failure even though we exited properly.
@@ -454,13 +429,29 @@ object ApplicationMaster {
})
}
- // Wait for initialization to complete and atleast 'some' nodes can get allocated.
- yarnAllocatorLoop.synchronized {
- while (yarnAllocatorLoop.get() <= ALLOCATOR_LOOP_WAIT_COUNT) {
- yarnAllocatorLoop.wait(1000L)
+ // Wait for initialization to complete and at least 'some' nodes to get allocated.
+ modified
+ }
+
+ /**
+ * Returns when we've either
+ * 1) received all the requested executors,
+ * 2) waited ALLOCATOR_LOOP_WAIT_COUNT * ALLOCATE_HEARTBEAT_INTERVAL ms,
+ * 3) hit an error that causes us to terminate trying to get containers.
+ */
+ def waitForInitialAllocations() {
+ doneWithInitialAllocationsMonitor.synchronized {
+ while (!isDoneWithInitialAllocations) {
+ doneWithInitialAllocationsMonitor.wait(1000L)
}
}
- modified
+ }
+
+ def getApplicationAttemptId(): ApplicationAttemptId = {
+ val containerIdString = System.getenv(ApplicationConstants.Environment.CONTAINER_ID.name())
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ appAttemptId
}
def main(argStrings: Array[String]) {
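
In the stable ApplicationMaster the AtomicInteger counter is replaced by a volatile flag plus a dedicated monitor object: doneWithInitialAllocations() flips the flag and notifies, while waitForInitialAllocations() loops on the flag with a bounded wait. A hedged, standalone sketch of that handshake (illustrative names, not the AM source):

object InitialAllocationGateSketch {
  private val monitor = new Object()
  @volatile private var done = false

  // Called by the allocation code once the initial requests have been serviced
  // (or an error makes further waiting pointless).
  def doneWithInitialAllocations() {
    done = true
    monitor.synchronized {
      monitor.notifyAll()  // wake anything parked in waitForInitialAllocations()
    }
  }

  // Called by the scheduler; returns once doneWithInitialAllocations() has run.
  def waitForInitialAllocations() {
    monitor.synchronized {
      while (!done) {
        monitor.wait(1000L)  // re-check periodically, guarding against spurious wakeups
      }
    }
  }
}
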
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 393edd1f2d670..24027618c1f35 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -21,14 +21,12 @@ import java.nio.ByteBuffer
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.DataOutputBuffer
-import org.apache.hadoop.yarn.api._
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
-import org.apache.hadoop.yarn.util.{Apps, Records}
+import org.apache.hadoop.yarn.util.Records
import org.apache.spark.{Logging, SparkConf}
@@ -102,7 +100,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def logClusterResourceDetails() {
val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics
- logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
+ logInfo("Got Cluster metric info from ResourceManager, number of NodeManagers: " +
clusterMetrics.getNumNodeManagers)
val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue)
@@ -133,7 +131,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def submitApp(appContext: ApplicationSubmissionContext) = {
// Submit the application to the applications manager.
- logInfo("Submitting application to ASM")
+ logInfo("Submitting application to ResourceManager")
yarnClient.submitApplication(appContext)
}
@@ -149,7 +147,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
Thread.sleep(interval)
val report = yarnClient.getApplicationReport(appId)
- logInfo("Application report from ASM: \n" +
+ logInfo("Application report from ResourceManager: \n" +
"\t application identifier: " + appId.toString() + "\n" +
"\t appId: " + appId.getId() + "\n" +
"\t clientToAMToken: " + report.getClientToAMToken() + "\n" +
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index d93e5bb0225d5..f71ad036ce0f2 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -72,8 +72,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
override def preStart() {
logInfo("Listen to driver: " + driverUrl)
driver = context.actorSelection(driverUrl)
- // Send a hello message thus the connection is actually established,
- // thus we can monitor Lifecycle Events.
+ // Send a hello message to establish the connection, after which
+ // we can monitor Lifecycle Events.
driver ! "Hello"
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
@@ -95,7 +95,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
amClient.init(yarnConf)
amClient.start()
- appAttemptId = getApplicationAttemptId()
+ appAttemptId = ApplicationMaster.getApplicationAttemptId()
registerApplicationMaster()
waitForSparkMaster()
@@ -115,7 +115,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
val interval = math.min(timeoutInterval / 2, schedulerInterval)
reporterThread = launchReporterThread(interval)
-
+
// Wait for the reporter thread to Finish.
reporterThread.join()
@@ -134,25 +134,16 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
// LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X
val localDirs = Option(System.getenv("YARN_LOCAL_DIRS"))
.orElse(Option(System.getenv("LOCAL_DIRS")))
-
+
localDirs match {
case None => throw new Exception("Yarn Local dirs can't be empty")
case Some(l) => l
}
- }
-
- private def getApplicationAttemptId(): ApplicationAttemptId = {
- val envs = System.getenv()
- val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name())
- val containerId = ConverterUtils.toContainerId(containerIdString)
- val appAttemptId = containerId.getApplicationAttemptId()
- logInfo("ApplicationAttemptId: " + appAttemptId)
- appAttemptId
}
private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
logInfo("Registering the ApplicationMaster")
- // TODO:(Raymond) Find out Spark UI address and fill in here?
+ // TODO: Find out client's Spark UI address and fill in here?
amClient.registerApplicationMaster(Utils.localHostName(), 0, "")
}
@@ -185,8 +176,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
private def allocateExecutors() {
-
- // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ // TODO: should get preferredNodeLocationData from SparkContext, just fake an empty one for now.
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
scala.collection.immutable.Map()
@@ -198,8 +188,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
preferredNodeLocationData,
sparkConf)
- logInfo("Allocating " + args.numExecutors + " executors.")
- // Wait until all containers have finished
+ logInfo("Requesting " + args.numExecutors + " executors.")
+ // Wait until all containers have launched
yarnAllocator.addResourceRequests(args.numExecutors)
yarnAllocator.allocateResources()
while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) {
@@ -221,7 +211,6 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
}
}
- // TODO: We might want to extend this to allocate more containers in case they die !
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
@@ -229,7 +218,8 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
override def run() {
while (!driverClosed) {
allocateMissingExecutor()
- sendProgress()
+ logDebug("Sending progress")
+ yarnAllocator.allocateResources()
Thread.sleep(sleepTime)
}
}
@@ -241,20 +231,14 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
t
}
- private def sendProgress() {
- logDebug("Sending progress")
- // simulated with an allocate request with no nodes requested ...
- yarnAllocator.allocateResources()
- }
-
def finishApplicationMaster(status: FinalApplicationStatus) {
- logInfo("finish ApplicationMaster with " + status)
- amClient.unregisterApplicationMaster(status, "" /* appMessage */ , "" /* appTrackingUrl */)
+ logInfo("Unregistering ApplicationMaster with " + status)
+ val trackingUrl = sparkConf.get("spark.yarn.historyServer.address", "")
+ amClient.unregisterApplicationMaster(status, "" /* appMessage */ , trackingUrl)
}
}
-
object ExecutorLauncher {
def main(argStrings: Array[String]) {
val args = new ApplicationMasterArguments(argStrings)