diff --git a/.gitignore b/.gitignore
index 3b9086c7187dc..9757054a50f9e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,16 +8,19 @@
*.pyc
.idea/
.idea_modules/
-sbt/*.jar
+build/*.jar
.settings
.cache
+cache
.generated-mima*
-/build/
work/
out/
.DS_Store
third_party/libmesos.so
third_party/libmesos.dylib
+build/apache-maven*
+build/zinc*
+build/scala*
conf/java-opts
conf/*.sh
conf/*.cmd
@@ -51,10 +54,11 @@ checkpoint
derby.log
dist/
dev/create-release/*txt
-dev/create-release/*new
+dev/create-release/*final
spark-*-bin-*.tgz
unit-tests.log
/lib/
+ec2/lib/
rat-results.txt
scalastyle.txt
scalastyle-output.xml
diff --git a/.rat-excludes b/.rat-excludes
index d8bee1f8e49c9..769defbac11b7 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -64,3 +64,4 @@ dist/*
logs
.*scalastyle-output.xml
.*dependency-reduced-pom.xml
+known_translations
diff --git a/LICENSE b/LICENSE
index 4f2f0e7a7006a..d0bb2ae15bc1a 100644
--- a/LICENSE
+++ b/LICENSE
@@ -646,7 +646,8 @@ THE SOFTWARE.
========================================================================
For Scala Interpreter classes (all .scala files in repl/src/main/scala
-except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala):
+except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala),
+and for SerializableMapWrapper in JavaUtils.scala:
========================================================================
Copyright (c) 2002-2013 EPFL
diff --git a/README.md b/README.md
index 8d57d50da96c9..af02339578195 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ To build Spark and its example programs, run:
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
-["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
+["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
## Interactive Scala Shell
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 4e2b773e7d2f3..301ff69c2ae3b 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -39,16 +39,10 @@
spark
/usr/share/spark
root
- 744
+ 755
-
-
- com.google.guava
- guava
- compile
-
org.apache.spark
spark-core_${scala.binary.version}
@@ -133,20 +127,6 @@
shade
-
-
- com.google
- org.spark-project.guava
-
- com.google.common.**
-
-
- com/google/common/base/Absent*
- com/google/common/base/Optional*
- com/google/common/base/Present*
-
-
-
@@ -169,16 +149,6 @@
-
- yarn-alpha
-
-
- org.apache.spark
- spark-yarn-alpha_${scala.binary.version}
- ${project.version}
-
-
-
yarn
@@ -310,7 +280,7 @@
${deb.user}
${deb.user}
${deb.install.path}/conf
- 744
+ ${deb.bin.filemode}
@@ -332,7 +302,7 @@
${deb.user}
${deb.user}
${deb.install.path}/sbin
- 744
+ ${deb.bin.filemode}
@@ -343,7 +313,7 @@
${deb.user}
${deb.user}
${deb.install.path}/python
- 744
+ ${deb.bin.filemode}
@@ -364,5 +334,25 @@
+
+
+
+ hadoop-provided
+
+ provided
+
+
+
+ hive-provided
+
+ provided
+
+
+
+ parquet-provided
+
+ provided
+
+
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 0327ffa402671..510e92640eff8 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -40,15 +40,6 @@
spark-core_${scala.binary.version}
${project.version}
-
- org.eclipse.jetty
- jetty-server
-
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
org.scalacheck
scalacheck_${scala.binary.version}
@@ -58,11 +49,5 @@
target/scala-${scala.binary.version}/classes
target/scala-${scala.binary.version}/test-classes
-
-
- org.scalatest
- scalatest-maven-plugin
-
-
diff --git a/bagel/src/test/resources/log4j.properties b/bagel/src/test/resources/log4j.properties
index 789869f72e3b0..853ef0ed2986f 100644
--- a/bagel/src/test/resources/log4j.properties
+++ b/bagel/src/test/resources/log4j.properties
@@ -15,10 +15,10 @@
# limitations under the License.
#
-# Set everything to be logged to the file bagel/target/unit-tests.log
+# Set everything to be logged to the file target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
-log4j.appender.file.append=false
+log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
log4j.appender.file.layout=org.apache.log4j.PatternLayout
log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index a4c099fb45b14..088f993954d9e 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -109,6 +109,13 @@ if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
:no_yarn_conf_dir
+rem To allow for distributions to append needed libraries to the classpath (e.g. when
+rem using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
+rem append it to tbe final classpath.
+if not "x%$SPARK_DIST_CLASSPATH%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_DIST_CLASSPATH%
+)
+
rem A bit of a hack to allow calling this script within run2.cmd without seeing output
if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 298641f2684de..a8c344b1ca594 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -25,7 +25,11 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
. "$FWDIR"/bin/load-spark-env.sh
-CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+if [ -n "$SPARK_CLASSPATH" ]; then
+ CLASSPATH="$SPARK_CLASSPATH:$SPARK_SUBMIT_CLASSPATH"
+else
+ CLASSPATH="$SPARK_SUBMIT_CLASSPATH"
+fi
# Build up classpath
if [ -n "$SPARK_CONF_DIR" ]; then
@@ -46,8 +50,8 @@ fi
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\
"classes ahead of assembly." >&2
+ # Spark classes
CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*"
CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes"
@@ -59,6 +63,8 @@ if [ -n "$SPARK_PREPEND_CLASSES" ]; then
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes"
+ # Jars for shaded deps in their original form (copied here during build)
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*"
fi
# Use spark-assembly jar from either RELEASE or assembly directory
@@ -68,22 +74,25 @@ else
assembly_folder="$ASSEMBLY_DIR"
fi
-num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)"
-if [ "$num_jars" -eq "0" ]; then
- echo "Failed to find Spark assembly in $assembly_folder"
- echo "You need to build Spark before running this program."
- exit 1
-fi
+num_jars=0
+
+for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark assembly in $assembly_folder" 1>&2
+ echo "You need to build Spark before running this program." 1>&2
+ exit 1
+ fi
+ ASSEMBLY_JAR="$f"
+ num_jars=$((num_jars+1))
+done
+
if [ "$num_jars" -gt "1" ]; then
- jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar")
- echo "Found multiple Spark assembly jars in $assembly_folder:"
- echo "$jars_list"
- echo "Please remove all but one jar."
+ echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2
+ ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
-ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)"
-
# Verify that versions of java used to build the jars and run Spark are compatible
jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1)
if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
@@ -108,7 +117,7 @@ else
datanucleus_dir="$FWDIR"/lib_managed/jars
fi
-datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar")"
+datanucleus_jars="$(find "$datanucleus_dir" 2>/dev/null | grep "datanucleus-.*\\.jar$")"
datanucleus_jars="$(echo "$datanucleus_jars" | tr "\n" : | sed s/:$//g)"
if [ -n "$datanucleus_jars" ]; then
@@ -142,4 +151,11 @@ if [ -n "$YARN_CONF_DIR" ]; then
CLASSPATH="$CLASSPATH:$YARN_CONF_DIR"
fi
+# To allow for distributions to append needed libraries to the classpath (e.g. when
+# using the "hadoop-provided" profile to build Spark), check SPARK_DIST_CLASSPATH and
+# append it to tbe final classpath.
+if [ -n "$SPARK_DIST_CLASSPATH" ]; then
+ CLASSPATH="$CLASSPATH:$SPARK_DIST_CLASSPATH"
+fi
+
echo "$CLASSPATH"
diff --git a/bin/run-example b/bin/run-example
index 3d932509426fc..c567acf9a6b5c 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -35,17 +35,32 @@ else
fi
if [ -f "$FWDIR/RELEASE" ]; then
- export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`"
-elif [ -e "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
- export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-examples-*hadoop*.jar`"
+ JAR_PATH="${FWDIR}/lib"
+else
+ JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}"
fi
-if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then
- echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
- echo "You need to build Spark before running this program" 1>&2
+JAR_COUNT=0
+
+for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
+ echo "You need to build Spark before running this program" 1>&2
+ exit 1
+ fi
+ SPARK_EXAMPLES_JAR="$f"
+ JAR_COUNT=$((JAR_COUNT+1))
+done
+
+if [ "$JAR_COUNT" -gt "1" ]; then
+ echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2
+ ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
+export SPARK_EXAMPLES_JAR
+
EXAMPLE_MASTER=${MASTER:-"local[*]"}
if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then
diff --git a/bin/spark-class b/bin/spark-class
index 0d58d95c1aee3..2f0441bb3c1c2 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -29,6 +29,7 @@ FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -71,6 +72,8 @@ case "$1" in
'org.apache.spark.executor.MesosExecutorBackend')
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_EXECUTOR_OPTS"
OUR_JAVA_MEM=${SPARK_EXECUTOR_MEMORY:-$DEFAULT_MEM}
+ export PYTHONPATH="$FWDIR/python:$PYTHONPATH"
+ export PYTHONPATH="$FWDIR/python/lib/py4j-0.8.2.1-src.zip:$PYTHONPATH"
;;
# Spark submit uses SPARK_JAVA_OPTS + SPARK_SUBMIT_OPTS +
@@ -118,8 +121,8 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
@@ -148,7 +151,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2
- echo "You need to build Spark before running $1." 1>&2
+ echo "You need to run \"build/sbt tools/package\" before running $1." 1>&2
exit 1
fi
CLASSPATH="$CLASSPATH:$SPARK_TOOLS_JAR"
diff --git a/bin/spark-shell b/bin/spark-shell
index 4a0670fc6c8aa..cca5aa0676123 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -45,6 +45,13 @@ source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
+# SPARK-4161: scala does not assume use of the java classpath,
+# so we need to add the "-Dscala.usejavacp=true" flag mnually. We
+# do this specifically for the Spark shell because the scala REPL
+# has its own class loader, and any additional classpath specified
+# through spark.driver.extraClassPath is not automatically propagated.
+SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Dscala.usejavacp=true"
+
function main() {
if $cygwin; then
# Workaround for issue involving JLine and Cygwin
diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd
old mode 100755
new mode 100644
diff --git a/bin/spark-shell2.cmd b/bin/spark-shell2.cmd
index 2ee60b4e2a2b3..1d1a40da315eb 100644
--- a/bin/spark-shell2.cmd
+++ b/bin/spark-shell2.cmd
@@ -19,4 +19,23 @@ rem
set SPARK_HOME=%~dp0..
-cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %* spark-shell
+echo "%*" | findstr " --help -h" >nul
+if %ERRORLEVEL% equ 0 (
+ call :usage
+ exit /b 0
+)
+
+call %SPARK_HOME%\bin\windows-utils.cmd %*
+if %ERRORLEVEL% equ 1 (
+ call :usage
+ exit /b 1
+)
+
+cmd /V /E /C %SPARK_HOME%\bin\spark-submit.cmd --class org.apache.spark.repl.Main %SUBMISSION_OPTS% spark-shell %APPLICATION_OPTS%
+
+exit /b 0
+
+:usage
+echo "Usage: .\bin\spark-shell.cmd [options]" >&2
+%SPARK_HOME%\bin\spark-submit --help 2>&1 | findstr /V "Usage" 1>&2
+exit /b 0
diff --git a/bin/spark-submit b/bin/spark-submit
index f92d90c3a66b0..3e5cbdbb24394 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -38,11 +38,19 @@ while (($#)); do
export SPARK_SUBMIT_CLASSPATH=$2
elif [ "$1" = "--driver-java-options" ]; then
export SPARK_SUBMIT_OPTS=$2
+ elif [ "$1" = "--master" ]; then
+ export MASTER=$2
fi
shift
done
-DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf"
+if [ -z "$SPARK_CONF_DIR" ]; then
+ export SPARK_CONF_DIR="$SPARK_HOME/conf"
+fi
+DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf"
+if [ "$MASTER" == "yarn-cluster" ]; then
+ SPARK_SUBMIT_DEPLOY_MODE=cluster
+fi
export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"}
export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"}
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index cf6046d1547ad..446cbc74b74f9 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,7 +24,11 @@ set ORIG_ARGS=%*
rem Reset the values of all variables used
set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+
+if [%SPARK_CONF_DIR%] == [] (
+ set SPARK_CONF_DIR=%SPARK_HOME%\conf
+)
+set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf
set SPARK_SUBMIT_DRIVER_MEMORY=
set SPARK_SUBMIT_LIBRARY_PATH=
set SPARK_SUBMIT_CLASSPATH=
@@ -45,11 +49,17 @@ if [%1] == [] goto continue
set SPARK_SUBMIT_CLASSPATH=%2
) else if [%1] == [--driver-java-options] (
set SPARK_SUBMIT_OPTS=%2
+ ) else if [%1] == [--master] (
+ set MASTER=%2
)
shift
goto loop
:continue
+if [%MASTER%] == [yarn-cluster] (
+ set SPARK_SUBMIT_DEPLOY_MODE=cluster
+)
+
rem For client mode, the driver will be launched in the same JVM that launches
rem SparkSubmit, so we may need to read the properties file for any extra class
rem paths, library paths, java options and memory early on. Otherwise, it will
diff --git a/bin/utils.sh b/bin/utils.sh
index 22ea2b9a6d586..2241200082018 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -26,14 +26,14 @@ function gatherSparkSubmitOpts() {
exit 1
fi
- # NOTE: If you add or remove spark-sumbmit options,
+ # NOTE: If you add or remove spark-submit options,
# modify NOT ONLY this script but also SparkSubmitArgument.scala
SUBMISSION_OPTS=()
APPLICATION_OPTS=()
while (($#)); do
case "$1" in
- --master | --deploy-mode | --class | --name | --jars | --py-files | --files | \
- --conf | --properties-file | --driver-memory | --driver-java-options | \
+ --master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \
+ --conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
--driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
--total-executor-cores | --executor-cores | --queue | --num-executors | --archives)
if [[ $# -lt 2 ]]; then
diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd
new file mode 100644
index 0000000000000..567b8733f7f77
--- /dev/null
+++ b/bin/windows-utils.cmd
@@ -0,0 +1,59 @@
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem Gather all spark-submit options into SUBMISSION_OPTS
+
+set SUBMISSION_OPTS=
+set APPLICATION_OPTS=
+
+rem NOTE: If you add or remove spark-sumbmit options,
+rem modify NOT ONLY this script but also SparkSubmitArgument.scala
+
+:OptsLoop
+if "x%1"=="x" (
+ goto :OptsLoopEnd
+)
+
+SET opts="\<--master\> \<--deploy-mode\> \<--class\> \<--name\> \<--jars\> \<--py-files\> \<--files\>"
+SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--driver-java-options\>"
+SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
+SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
+SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"
+
+echo %1 | findstr %opts% >nul
+if %ERRORLEVEL% equ 0 (
+ if "x%2"=="x" (
+ echo "%1" requires an argument. >&2
+ exit /b 1
+ )
+ set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1 %2
+ shift
+ shift
+ goto :OptsLoop
+)
+echo %1 | findstr "\<--verbose\> \<-v\> \<--supervise\>" >nul
+if %ERRORLEVEL% equ 0 (
+ set SUBMISSION_OPTS=%SUBMISSION_OPTS% %1
+ shift
+ goto :OptsLoop
+)
+set APPLICATION_OPTS=%APPLICATION_OPTS% %1
+shift
+goto :OptsLoop
+
+:OptsLoopEnd
+exit /b 0
diff --git a/build/mvn b/build/mvn
new file mode 100755
index 0000000000000..53babf54debb6
--- /dev/null
+++ b/build/mvn
@@ -0,0 +1,149 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Determine the current working directory
+_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
+# Preserve the calling directory
+_CALLING_DIR="$(pwd)"
+
+# Installs any application tarball given a URL, the expected tarball name,
+# and, optionally, a checkable binary path to determine if the binary has
+# already been installed
+## Arg1 - URL
+## Arg2 - Tarball Name
+## Arg3 - Checkable Binary
+install_app() {
+ local remote_tarball="$1/$2"
+ local local_tarball="${_DIR}/$2"
+ local binary="${_DIR}/$3"
+
+ # setup `curl` and `wget` silent options if we're running on Jenkins
+ local curl_opts="-L"
+ local wget_opts=""
+ if [ -n "$AMPLAB_JENKINS" ]; then
+ curl_opts="-s ${curl_opts}"
+ wget_opts="--quiet ${wget_opts}"
+ else
+ curl_opts="--progress-bar ${curl_opts}"
+ wget_opts="--progress=bar:force ${wget_opts}"
+ fi
+
+ if [ -z "$3" -o ! -f "$binary" ]; then
+ # check if we already have the tarball
+ # check if we have curl installed
+ # download application
+ [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \
+ echo "exec: curl ${curl_opts} ${remote_tarball}" && \
+ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}"
+ # if the file still doesn't exist, lets try `wget` and cross our fingers
+ [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \
+ echo "exec: wget ${wget_opts} ${remote_tarball}" && \
+ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}"
+ # if both were unsuccessful, exit
+ [ ! -f "${local_tarball}" ] && \
+ echo -n "ERROR: Cannot download $2 with cURL or wget; " && \
+ echo "please install manually and try again." && \
+ exit 2
+ cd "${_DIR}" && tar -xzf "$2"
+ rm -rf "$local_tarball"
+ fi
+}
+
+# Install maven under the build/ folder
+install_mvn() {
+ install_app \
+ "http://archive.apache.org/dist/maven/maven-3/3.2.5/binaries" \
+ "apache-maven-3.2.5-bin.tar.gz" \
+ "apache-maven-3.2.5/bin/mvn"
+ MVN_BIN="${_DIR}/apache-maven-3.2.5/bin/mvn"
+}
+
+# Install zinc under the build/ folder
+install_zinc() {
+ local zinc_path="zinc-0.3.5.3/bin/zinc"
+ [ ! -f "${zinc_path}" ] && ZINC_INSTALL_FLAG=1
+ install_app \
+ "http://downloads.typesafe.com/zinc/0.3.5.3" \
+ "zinc-0.3.5.3.tgz" \
+ "${zinc_path}"
+ ZINC_BIN="${_DIR}/${zinc_path}"
+}
+
+# Determine the Scala version from the root pom.xml file, set the Scala URL,
+# and, with that, download the specific version of Scala necessary under
+# the build/ folder
+install_scala() {
+ # determine the Scala version used in Spark
+ local scala_version=`grep "scala.version" "${_DIR}/../pom.xml" | \
+ head -1 | cut -f2 -d'>' | cut -f1 -d'<'`
+ local scala_bin="${_DIR}/scala-${scala_version}/bin/scala"
+
+ install_app \
+ "http://downloads.typesafe.com/scala/${scala_version}" \
+ "scala-${scala_version}.tgz" \
+ "scala-${scala_version}/bin/scala"
+
+ SCALA_COMPILER="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-compiler.jar"
+ SCALA_LIBRARY="$(cd "$(dirname ${scala_bin})/../lib" && pwd)/scala-library.jar"
+}
+
+# Determines if a given application is already installed. If not, will attempt
+# to install
+## Arg1 - application name
+## Arg2 - Alternate path to local install under build/ dir
+check_and_install_app() {
+ # create the local environment variable in uppercase
+ local app_bin="`echo $1 | awk '{print toupper(\$0)}'`_BIN"
+ # some black magic to set the generated app variable (i.e. MVN_BIN) into the
+ # environment
+ eval "${app_bin}=`which $1 2>/dev/null`"
+
+ if [ -z "`which $1 2>/dev/null`" ]; then
+ install_$1
+ fi
+}
+
+# Setup healthy defaults for the Zinc port if none were provided from
+# the environment
+ZINC_PORT=${ZINC_PORT:-"3030"}
+
+# Check and install all applications necessary to build Spark
+check_and_install_app "mvn"
+
+# Install the proper version of Scala and Zinc for the build
+install_zinc
+install_scala
+
+# Reset the current working directory
+cd "${_CALLING_DIR}"
+
+# Now that zinc is ensured to be installed, check its status and, if its
+# not running or just installed, start it
+if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then
+ ${ZINC_BIN} -shutdown
+ ${ZINC_BIN} -start -port ${ZINC_PORT} \
+ -scala-compiler "${SCALA_COMPILER}" \
+ -scala-library "${SCALA_LIBRARY}" &>/dev/null
+fi
+
+# Set any `mvn` options if not already present
+export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"}
+
+# Last, call the `mvn` command as usual
+${MVN_BIN} "$@"
diff --git a/build/sbt b/build/sbt
new file mode 100755
index 0000000000000..28ebb64f7197c
--- /dev/null
+++ b/build/sbt
@@ -0,0 +1,128 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# When creating new tests for Spark SQL Hive, the HADOOP_CLASSPATH must contain the hive jars so
+# that we can run Hive to generate the golden answer. This is not required for normal development
+# or testing.
+for i in "$HIVE_HOME"/lib/*
+do HADOOP_CLASSPATH="$HADOOP_CLASSPATH:$i"
+done
+export HADOOP_CLASSPATH
+
+realpath () {
+(
+ TARGET_FILE="$1"
+
+ cd "$(dirname "$TARGET_FILE")"
+ TARGET_FILE="$(basename "$TARGET_FILE")"
+
+ COUNT=0
+ while [ -L "$TARGET_FILE" -a $COUNT -lt 100 ]
+ do
+ TARGET_FILE="$(readlink "$TARGET_FILE")"
+ cd $(dirname "$TARGET_FILE")
+ TARGET_FILE="$(basename $TARGET_FILE)"
+ COUNT=$(($COUNT + 1))
+ done
+
+ echo "$(pwd -P)/"$TARGET_FILE""
+)
+}
+
+. "$(dirname "$(realpath "$0")")"/sbt-launch-lib.bash
+
+
+declare -r noshare_opts="-Dsbt.global.base=project/.sbtboot -Dsbt.boot.directory=project/.boot -Dsbt.ivy.home=project/.ivy"
+declare -r sbt_opts_file=".sbtopts"
+declare -r etc_sbt_opts_file="/etc/sbt/sbtopts"
+
+usage() {
+ cat < path to global settings/plugins directory (default: ~/.sbt)
+ -sbt-boot path to shared boot directory (default: ~/.sbt/boot in 0.11 series)
+ -ivy path to local Ivy repository (default: ~/.ivy2)
+ -mem set memory options (default: $sbt_mem, which is $(get_mem_opts $sbt_mem))
+ -no-share use all local caches; no sharing
+ -no-global uses global caches, but does not use global ~/.sbt directory.
+ -jvm-debug Turn on JVM debugging, open at the given port.
+ -batch Disable interactive mode
+
+ # sbt version (default: from project/build.properties if present, else latest release)
+ -sbt-version use the specified version of sbt
+ -sbt-jar use the specified jar as the sbt launcher
+ -sbt-rc use an RC version of sbt
+ -sbt-snapshot use a snapshot version of sbt
+
+ # java version (default: java from PATH, currently $(java -version 2>&1 | grep version))
+ -java-home alternate JAVA_HOME
+
+ # jvm options and output control
+ JAVA_OPTS environment variable, if unset uses "$java_opts"
+ SBT_OPTS environment variable, if unset uses "$default_sbt_opts"
+ .sbtopts if this file exists in the current directory, it is
+ prepended to the runner args
+ /etc/sbt/sbtopts if this file exists, it is prepended to the runner args
+ -Dkey=val pass -Dkey=val directly to the java runtime
+ -J-X pass option -X directly to the java runtime
+ (-J is stripped)
+ -S-X add -X to sbt's scalacOptions (-S is stripped)
+ -PmavenProfiles Enable a maven profile for the build.
+
+In the case of duplicated or conflicting options, the order above
+shows precedence: JAVA_OPTS lowest, command line options highest.
+EOM
+}
+
+process_my_args () {
+ while [[ $# -gt 0 ]]; do
+ case "$1" in
+ -no-colors) addJava "-Dsbt.log.noformat=true" && shift ;;
+ -no-share) addJava "$noshare_opts" && shift ;;
+ -no-global) addJava "-Dsbt.global.base=$(pwd)/project/.sbtboot" && shift ;;
+ -sbt-boot) require_arg path "$1" "$2" && addJava "-Dsbt.boot.directory=$2" && shift 2 ;;
+ -sbt-dir) require_arg path "$1" "$2" && addJava "-Dsbt.global.base=$2" && shift 2 ;;
+ -debug-inc) addJava "-Dxsbt.inc.debug=true" && shift ;;
+ -batch) exec /dev/null; then
+ if [ $(command -v curl) ]; then
(curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
- elif hash wget 2>/dev/null; then
+ elif [ $(command -v wget) ]; then
(wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}"
else
printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n"
@@ -150,7 +150,7 @@ process_args () {
-java-home) require_arg path "$1" "$2" && java_cmd="$2/bin/java" && export JAVA_HOME=$2 && shift 2 ;;
-D*) addJava "$1" && shift ;;
- -J*) addJava "${1:2}" && shift ;;
+ -J*) addJava "${1:2}" && shift ;;
-P*) enableProfile "$1" && shift ;;
*) addResidual "$1" && shift ;;
esac
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 30bcab0c93302..464c14457e53f 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -77,8 +77,8 @@
# sample false Whether to show entire set of samples for histograms ('false' or 'true')
#
# * Default path is /metrics/json for all instances except the master. The master has two paths:
-# /metrics/aplications/json # App information
-# /metrics/master/json # Master information
+# /metrics/applications/json # App information
+# /metrics/master/json # Master information
# org.apache.spark.metrics.sink.GraphiteSink
# Name: Default: Description:
@@ -87,6 +87,7 @@
# period 10 Poll period
# unit seconds Units of poll period
# prefix EMPTY STRING Prefix to prepend to metric name
+# protocol tcp Protocol ("tcp" or "udp") to use
## Examples
# Enable JmxSink for all instances by class name
diff --git a/core/pom.xml b/core/pom.xml
index 1feb00b3a7fb8..66180035e61f1 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -34,6 +34,10 @@
Spark Project Core
http://spark.apache.org/
+
+ com.google.guava
+ guava
+
com.twitter
chill_${scala.binary.version}
@@ -90,32 +94,52 @@
org.apache.curator
curator-recipes
+
+
org.eclipse.jetty
jetty-plus
+ compile
org.eclipse.jetty
jetty-security
+ compile
org.eclipse.jetty
jetty-util
+ compile
org.eclipse.jetty
jetty-server
+ compile
-
- com.google.guava
- guava
+ org.eclipse.jetty
+ jetty-http
compile
+
+ org.eclipse.jetty
+ jetty-continuation
+ compile
+
+
+ org.eclipse.jetty
+ jetty-servlet
+ compile
+
+
+
+ org.eclipse.jetty.orbit
+ javax.servlet
+ ${orbit.version}
+
+
org.apache.commons
commons-lang3
@@ -204,26 +228,45 @@
stream
- com.codahale.metrics
+ io.dropwizard.metrics
metrics-core
- com.codahale.metrics
+ io.dropwizard.metrics
metrics-jvm
- com.codahale.metrics
+ io.dropwizard.metrics
metrics-json
- com.codahale.metrics
+ io.dropwizard.metrics
metrics-graphite
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+ com.fasterxml.jackson.module
+ jackson-module-scala_2.10
+
org.apache.derby
derby
test
+
+ org.apache.ivy
+ ivy
+ ${ivy.version}
+
+
+ oro
+
+ oro
+ ${oro.version}
+
org.tachyonproject
tachyon-client
@@ -276,11 +319,6 @@
selenium-java
test
-
- org.scalatest
- scalatest_${scala.binary.version}
- test
-
org.mockito
mockito-all
@@ -326,19 +364,6 @@
target/scala-${scala.binary.version}/classes
target/scala-${scala.binary.version}/test-classes
-
- org.scalatest
- scalatest-maven-plugin
-
-
- test
-
- test
-
-
-
-
-
org.apache.maven.plugins
@@ -352,9 +377,9 @@
-
+
-
+
@@ -368,59 +393,28 @@
true
-
- org.apache.maven.plugins
- maven-shade-plugin
-
-
- package
-
- shade
-
-
- false
-
-
- com.google.guava:guava
-
-
-
-
-
- com.google.guava:guava
-
- com/google/common/base/Absent*
- com/google/common/base/Optional*
- com/google/common/base/Present*
-
-
-
-
-
-
-
-
org.apache.maven.plugins
maven-dependency-plugin
+
copy-dependencies
package
copy-dependencies
-
+
${project.build.directory}
false
false
true
true
- guava
+
+ guava,jetty-io,jetty-servlet,jetty-continuation,jetty-http,jetty-plus,jetty-util,jetty-server
+
true
diff --git a/core/src/main/java/org/apache/spark/JavaSparkListener.java b/core/src/main/java/org/apache/spark/JavaSparkListener.java
new file mode 100644
index 0000000000000..646496f313507
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JavaSparkListener.java
@@ -0,0 +1,97 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerApplicationEnd;
+import org.apache.spark.scheduler.SparkListenerApplicationStart;
+import org.apache.spark.scheduler.SparkListenerBlockManagerAdded;
+import org.apache.spark.scheduler.SparkListenerBlockManagerRemoved;
+import org.apache.spark.scheduler.SparkListenerEnvironmentUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorAdded;
+import org.apache.spark.scheduler.SparkListenerExecutorMetricsUpdate;
+import org.apache.spark.scheduler.SparkListenerExecutorRemoved;
+import org.apache.spark.scheduler.SparkListenerJobEnd;
+import org.apache.spark.scheduler.SparkListenerJobStart;
+import org.apache.spark.scheduler.SparkListenerStageCompleted;
+import org.apache.spark.scheduler.SparkListenerStageSubmitted;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
+import org.apache.spark.scheduler.SparkListenerTaskGettingResult;
+import org.apache.spark.scheduler.SparkListenerTaskStart;
+import org.apache.spark.scheduler.SparkListenerUnpersistRDD;
+
+/**
+ * Java clients should extend this class instead of implementing
+ * SparkListener directly. This is to prevent java clients
+ * from breaking when new events are added to the SparkListener
+ * trait.
+ *
+ * This is a concrete class instead of abstract to enforce
+ * new events get added to both the SparkListener and this adapter
+ * in lockstep.
+ */
+public class JavaSparkListener implements SparkListener {
+
+ @Override
+ public void onStageCompleted(SparkListenerStageCompleted stageCompleted) { }
+
+ @Override
+ public void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) { }
+
+ @Override
+ public void onTaskStart(SparkListenerTaskStart taskStart) { }
+
+ @Override
+ public void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) { }
+
+ @Override
+ public void onTaskEnd(SparkListenerTaskEnd taskEnd) { }
+
+ @Override
+ public void onJobStart(SparkListenerJobStart jobStart) { }
+
+ @Override
+ public void onJobEnd(SparkListenerJobEnd jobEnd) { }
+
+ @Override
+ public void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) { }
+
+ @Override
+ public void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) { }
+
+ @Override
+ public void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) { }
+
+ @Override
+ public void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) { }
+
+ @Override
+ public void onApplicationStart(SparkListenerApplicationStart applicationStart) { }
+
+ @Override
+ public void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) { }
+
+ @Override
+ public void onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate executorMetricsUpdate) { }
+
+ @Override
+ public void onExecutorAdded(SparkListenerExecutorAdded executorAdded) { }
+
+ @Override
+ public void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) { }
+}
diff --git a/core/src/main/java/org/apache/spark/SparkFirehoseListener.java b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
new file mode 100644
index 0000000000000..fbc5666959055
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkFirehoseListener.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+import org.apache.spark.scheduler.*;
+
+/**
+ * Class that allows users to receive all SparkListener events.
+ * Users should override the onEvent method.
+ *
+ * This is a concrete Java class in order to ensure that we don't forget to update it when adding
+ * new methods to SparkListener: forgetting to add a method will result in a compilation error (if
+ * this was a concrete Scala class, default implementations of new event handlers would be inherited
+ * from the SparkListener trait).
+ */
+public class SparkFirehoseListener implements SparkListener {
+
+ public void onEvent(SparkListenerEvent event) { }
+
+ @Override
+ public final void onStageCompleted(SparkListenerStageCompleted stageCompleted) {
+ onEvent(stageCompleted);
+ }
+
+ @Override
+ public final void onStageSubmitted(SparkListenerStageSubmitted stageSubmitted) {
+ onEvent(stageSubmitted);
+ }
+
+ @Override
+ public final void onTaskStart(SparkListenerTaskStart taskStart) {
+ onEvent(taskStart);
+ }
+
+ @Override
+ public final void onTaskGettingResult(SparkListenerTaskGettingResult taskGettingResult) {
+ onEvent(taskGettingResult);
+ }
+
+ @Override
+ public final void onTaskEnd(SparkListenerTaskEnd taskEnd) {
+ onEvent(taskEnd);
+ }
+
+ @Override
+ public final void onJobStart(SparkListenerJobStart jobStart) {
+ onEvent(jobStart);
+ }
+
+ @Override
+ public final void onJobEnd(SparkListenerJobEnd jobEnd) {
+ onEvent(jobEnd);
+ }
+
+ @Override
+ public final void onEnvironmentUpdate(SparkListenerEnvironmentUpdate environmentUpdate) {
+ onEvent(environmentUpdate);
+ }
+
+ @Override
+ public final void onBlockManagerAdded(SparkListenerBlockManagerAdded blockManagerAdded) {
+ onEvent(blockManagerAdded);
+ }
+
+ @Override
+ public final void onBlockManagerRemoved(SparkListenerBlockManagerRemoved blockManagerRemoved) {
+ onEvent(blockManagerRemoved);
+ }
+
+ @Override
+ public final void onUnpersistRDD(SparkListenerUnpersistRDD unpersistRDD) {
+ onEvent(unpersistRDD);
+ }
+
+ @Override
+ public final void onApplicationStart(SparkListenerApplicationStart applicationStart) {
+ onEvent(applicationStart);
+ }
+
+ @Override
+ public final void onApplicationEnd(SparkListenerApplicationEnd applicationEnd) {
+ onEvent(applicationEnd);
+ }
+
+ @Override
+ public final void onExecutorMetricsUpdate(
+ SparkListenerExecutorMetricsUpdate executorMetricsUpdate) {
+ onEvent(executorMetricsUpdate);
+ }
+
+ @Override
+ public final void onExecutorAdded(SparkListenerExecutorAdded executorAdded) {
+ onEvent(executorAdded);
+ }
+
+ @Override
+ public final void onExecutorRemoved(SparkListenerExecutorRemoved executorRemoved) {
+ onEvent(executorRemoved);
+ }
+}
diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java
index 4e3c983b1170a..e31c4401632a6 100644
--- a/core/src/main/java/org/apache/spark/SparkJobInfo.java
+++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java
@@ -17,13 +17,15 @@
package org.apache.spark;
+import java.io.Serializable;
+
/**
* Exposes information about Spark Jobs.
*
* This interface is not designed to be implemented outside of Spark. We may add additional methods
* which may break binary compatibility with outside implementations.
*/
-public interface SparkJobInfo {
+public interface SparkJobInfo extends Serializable {
int jobId();
int[] stageIds();
JobExecutionStatus status();
diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java
index fd74321093658..b7d462abd72d6 100644
--- a/core/src/main/java/org/apache/spark/SparkStageInfo.java
+++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java
@@ -17,13 +17,15 @@
package org.apache.spark;
+import java.io.Serializable;
+
/**
* Exposes information about Spark Stages.
*
* This interface is not designed to be implemented outside of Spark. We may add additional methods
* which may break binary compatibility with outside implementations.
*/
-public interface SparkStageInfo {
+public interface SparkStageInfo extends Serializable {
int stageId();
int currentAttemptId();
long submissionTime();
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
deleted file mode 100644
index 0d6973203eba1..0000000000000
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ /dev/null
@@ -1,106 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark;
-
-import java.io.Serializable;
-
-import scala.Function0;
-import scala.Function1;
-import scala.Unit;
-
-import org.apache.spark.annotation.DeveloperApi;
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.util.TaskCompletionListener;
-
-/**
- * Contextual information about a task which can be read or mutated during
- * execution. To access the TaskContext for a running task use
- * TaskContext.get().
- */
-public abstract class TaskContext implements Serializable {
- /**
- * Return the currently active TaskContext. This can be called inside of
- * user functions to access contextual information about running tasks.
- */
- public static TaskContext get() {
- return taskContext.get();
- }
-
- private static ThreadLocal taskContext =
- new ThreadLocal();
-
- static void setTaskContext(TaskContext tc) {
- taskContext.set(tc);
- }
-
- static void unset() {
- taskContext.remove();
- }
-
- /**
- * Whether the task has completed.
- */
- public abstract boolean isCompleted();
-
- /**
- * Whether the task has been killed.
- */
- public abstract boolean isInterrupted();
-
- /** @deprecated: use isRunningLocally() */
- @Deprecated
- public abstract boolean runningLocally();
-
- public abstract boolean isRunningLocally();
-
- /**
- * Add a (Java friendly) listener to be executed on task completion.
- * This will be called in all situation - success, failure, or cancellation.
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
-
- /**
- * Add a listener in the form of a Scala closure to be executed on task completion.
- * This will be called in all situations - success, failure, or cancellation.
- * An example use is for HadoopRDD to register a callback to close the input stream.
- */
- public abstract TaskContext addTaskCompletionListener(final Function1 f);
-
- /**
- * Add a callback function to be executed on task completion. An example use
- * is for HadoopRDD to register a callback to close the input stream.
- * Will be called in any situation - success, failure, or cancellation.
- *
- * @deprecated: use addTaskCompletionListener
- *
- * @param f Callback function.
- */
- @Deprecated
- public abstract void addOnCompleteCallback(final Function0 f);
-
- public abstract int stageId();
-
- public abstract int partitionId();
-
- public abstract long attemptId();
-
- /** ::DeveloperApi:: */
- @DeveloperApi
- public abstract TaskMetrics taskMetrics();
-}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index cdf85bfbf326f..68b33b5f0d7c7 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -19,6 +19,7 @@
height: 50px;
font-size: 15px;
margin-bottom: 15px;
+ min-width: 1200px
}
.navbar .navbar-inner {
@@ -39,12 +40,12 @@
.navbar .nav > li a {
height: 30px;
- line-height: 30px;
+ line-height: 2;
}
.navbar-text {
height: 50px;
- line-height: 50px;
+ line-height: 3.3;
}
table.sortable thead {
@@ -102,6 +103,12 @@ span.expand-details {
float: right;
}
+span.rest-uri {
+ font-size: 10pt;
+ font-style: italic;
+ color: gray;
+}
+
pre {
font-size: 0.8em;
}
@@ -120,6 +127,14 @@ pre {
border: none;
}
+.description-input {
+ overflow: hidden;
+ text-overflow: ellipsis;
+ width: 100%;
+ white-space: nowrap;
+ display: block;
+}
+
.stacktrace-details {
max-height: 300px;
overflow-y: auto;
@@ -169,8 +184,19 @@ span.additional-metric-title {
display: inline-block;
}
+.version {
+ line-height: 2.5;
+ vertical-align: bottom;
+ font-size: 12px;
+ padding: 0;
+ margin: 0;
+ font-weight: bold;
+ color: #777;
+}
+
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
-.scheduler_delay, .gc_time, .deserialization_time, .serialization_time, .getting_result_time {
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
+.getting_result_time {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 000bbd6b532ad..5f31bfba3f8d6 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
import java.util.concurrent.atomic.AtomicLong
+import java.lang.ThreadLocal
import scala.collection.generic.Growable
import scala.collection.mutable.Map
@@ -278,10 +279,12 @@ object AccumulatorParam {
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
-private object Accumulators {
+private[spark] object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulable[_, _]]()
- val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+ val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+ override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+ }
var lastId: Long = 0
def newId(): Long = synchronized {
@@ -293,22 +296,21 @@ private object Accumulators {
if (original) {
originals(a.id) = a
} else {
- val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
- accums(a.id) = a
+ localAccums.get()(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
- localAccums.remove(Thread.currentThread)
+ localAccums.get.clear
}
}
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
- for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+ for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
}
return ret
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 79c9c451d273d..3b684bbeceaf2 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,7 +34,9 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
+ private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] =
@@ -42,7 +44,7 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
@@ -59,8 +61,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
@@ -71,9 +73,9 @@ case class Aggregator[K, V, C] (
combineCombinersByKey(iter, null)
def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
- : Iterator[(K, C)] =
+ : Iterator[(K, C)] =
{
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
val update = (hadValue: Boolean, oldValue: C) => {
@@ -93,8 +95,8 @@ case class Aggregator[K, V, C] (
// Update task metrics if context is not null
// TODO: Make context non-optional in a future release
Option(context).foreach { c =>
- c.taskMetrics.memoryBytesSpilled += combiners.memoryBytesSpilled
- c.taskMetrics.diskBytesSpilled += combiners.diskBytesSpilled
+ c.taskMetrics.incMemoryBytesSpilled(combiners.memoryBytesSpilled)
+ c.taskMetrics.incDiskBytesSpilled(combiners.diskBytesSpilled)
}
combiners.iterator
}
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index 80da62c44edc5..a96d754744a05 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -44,9 +44,18 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
blockManager.get(key) match {
case Some(blockResult) =>
// Partition is already materialized, so just return its values
- context.taskMetrics.inputMetrics = Some(blockResult.inputMetrics)
- new InterruptibleIterator(context, blockResult.data.asInstanceOf[Iterator[T]])
-
+ val inputMetrics = blockResult.inputMetrics
+ val existingMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(inputMetrics.readMethod)
+ existingMetrics.incBytesRead(inputMetrics.bytesRead)
+
+ val iter = blockResult.data.asInstanceOf[Iterator[T]]
+ new InterruptibleIterator[T](context, iter) {
+ override def next(): T = {
+ existingMetrics.incRecordsRead(1)
+ delegate.next()
+ }
+ }
case None =>
// Acquire a lock for loading this partition
// If another thread already holds the lock, wait for it to finish return its results
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index ab2594cfc02eb..9a7cd4523e5ab 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -60,6 +60,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
* @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
* the default serializer, as specified by `spark.serializer` config option, will
* be used.
+ * @param keyOrdering key ordering for RDD's shuffles
+ * @param aggregator map/reduce-side aggregator for RDD's shuffle
+ * @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
*/
@DeveloperApi
class ShuffleDependency[K, V, C](
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
new file mode 100644
index 0000000000000..a46a81eabd965
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * A client that communicates with the cluster manager to request or kill executors.
+ */
+private[spark] trait ExecutorAllocationClient {
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def requestExecutors(numAdditionalExecutors: Int): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutors(executorIds: Seq[String]): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executor.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 88adb892998af..02d54bf3b53cc 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -49,6 +49,7 @@ import org.apache.spark.scheduler._
* spark.dynamicAllocation.enabled - Whether this feature is enabled
* spark.dynamicAllocation.minExecutors - Lower bound on the number of executors
* spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors
+ * spark.dynamicAllocation.initialExecutors - Number of executors to start with
*
* spark.dynamicAllocation.schedulerBacklogTimeout (M) -
* If there are backlogged tasks for this duration, add new executors
@@ -60,24 +61,30 @@ import org.apache.spark.scheduler._
* spark.dynamicAllocation.executorIdleTimeout (K) -
* If an executor has been idle for this duration, remove it
*/
-private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging {
- import ExecutorAllocationManager._
+private[spark] class ExecutorAllocationManager(
+ client: ExecutorAllocationClient,
+ listenerBus: LiveListenerBus,
+ conf: SparkConf)
+ extends Logging {
+
+ allocationManager =>
- private val conf = sc.conf
+ import ExecutorAllocationManager._
- // Lower and upper bounds on the number of executors. These are required.
- private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
- private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
+ // Lower and upper bounds on the number of executors.
+ private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", 0)
+ private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors",
+ Integer.MAX_VALUE)
- // How long there must be backlogged tasks for before an addition is triggered
+ // How long there must be backlogged tasks for before an addition is triggered (seconds)
private val schedulerBacklogTimeout = conf.getLong(
- "spark.dynamicAllocation.schedulerBacklogTimeout", 60)
+ "spark.dynamicAllocation.schedulerBacklogTimeout", 5)
// Same as above, but used only after `schedulerBacklogTimeout` is exceeded
private val sustainedSchedulerBacklogTimeout = conf.getLong(
"spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
- // How long an executor must be idle for before it is removed
+ // How long an executor must be idle for before it is removed (seconds)
private val executorIdleTimeout = conf.getLong(
"spark.dynamicAllocation.executorIdleTimeout", 600)
@@ -119,7 +126,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
private var clock: Clock = new RealClock
// Listener for Spark events that impact the allocation policy
- private val listener = new ExecutorAllocationListener(this)
+ private val listener = new ExecutorAllocationListener
/**
* Verify that the settings specified through the config are valid.
@@ -127,10 +134,10 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
*/
private def validateSettings(): Unit = {
if (minNumExecutors < 0 || maxNumExecutors < 0) {
- throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be positive!")
}
- if (minNumExecutors == 0 || maxNumExecutors == 0) {
- throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!")
+ if (maxNumExecutors == 0) {
+ throw new SparkException("spark.dynamicAllocation.maxExecutors cannot be 0!")
}
if (minNumExecutors > maxNumExecutors) {
throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
@@ -153,7 +160,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
"shuffle service. You may enable this through spark.shuffle.service.enabled.")
}
if (tasksPerExecutor == 0) {
- throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores")
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.")
}
}
@@ -168,7 +175,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* Register for scheduler callbacks to decide when to add and remove executors.
*/
def start(): Unit = {
- sc.addSparkListener(listener)
+ listenerBus.addListener(listener)
startPolling()
}
@@ -207,11 +214,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
addTime += sustainedSchedulerBacklogTimeout * 1000
}
- removeTimes.foreach { case (executorId, expireTime) =>
- if (now >= expireTime) {
+ removeTimes.retain { case (executorId, expireTime) =>
+ val expired = now >= expireTime
+ if (expired) {
removeExecutor(executorId)
- removeTimes.remove(executorId)
}
+ !expired
}
}
@@ -253,7 +261,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
- val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd)
+ val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd)
if (addRequestAcknowledged) {
logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
s"tasks are backlogged (new desired total will be $newTotalExecutors)")
@@ -289,13 +297,13 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Do not kill the executor if we have already reached the lower bound
val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
if (numExistingExecutors - 1 < minNumExecutors) {
- logInfo(s"Not removing idle executor $executorId because there are only " +
+ logDebug(s"Not removing idle executor $executorId because there are only " +
s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
return false
}
// Send a request to the backend to kill this executor
- val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
+ val removeRequestAcknowledged = testing || client.killExecutor(executorId)
if (removeRequestAcknowledged) {
logInfo(s"Removing executor $executorId because it has been idle for " +
s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
@@ -313,7 +321,11 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
private def onExecutorAdded(executorId: String): Unit = synchronized {
if (!executorIds.contains(executorId)) {
executorIds.add(executorId)
- executorIds.foreach(onExecutorIdle)
+ // If an executor (call this executor X) is not removed because the lower bound
+ // has been reached, it will no longer be marked as idle. When new executors join,
+ // however, we are no longer at the lower bound, and so we must mark executor X
+ // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951)
+ executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
if (numExecutorsPending > 0) {
numExecutorsPending -= 1
@@ -371,10 +383,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* the executor is not already marked as idle.
*/
private def onExecutorIdle(executorId: String): Unit = synchronized {
- if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
- logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
- removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ if (executorIds.contains(executorId)) {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ }
+ } else {
+ logWarning(s"Attempted to mark unknown executor $executorId idle")
}
}
@@ -394,25 +410,24 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* and consistency of events returned by the listener. For simplicity, it does not account
* for speculated tasks.
*/
- private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager)
- extends SparkListener {
+ private class ExecutorAllocationListener extends SparkListener {
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
- synchronized {
- val stageId = stageSubmitted.stageInfo.stageId
- val numTasks = stageSubmitted.stageInfo.numTasks
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ allocationManager.synchronized {
stageIdToNumTasks(stageId) = numTasks
allocationManager.onSchedulerBacklogged()
}
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
- synchronized {
- val stageId = stageCompleted.stageInfo.stageId
+ val stageId = stageCompleted.stageInfo.stageId
+ allocationManager.synchronized {
stageIdToNumTasks -= stageId
stageIdToTaskIndices -= stageId
@@ -424,64 +439,89 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
}
}
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val stageId = taskStart.stageId
val taskId = taskStart.taskInfo.taskId
val taskIndex = taskStart.taskInfo.index
val executorId = taskStart.taskInfo.executorId
- // If this is the last pending task, mark the scheduler queue as empty
- stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
- val numTasksScheduled = stageIdToTaskIndices(stageId).size
- val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
- if (numTasksScheduled == numTasksTotal) {
- // No more pending tasks for this stage
- stageIdToNumTasks -= stageId
- if (stageIdToNumTasks.isEmpty) {
- allocationManager.onSchedulerQueueEmpty()
+ allocationManager.synchronized {
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
}
- }
- // Mark the executor on which this task is scheduled as busy
- executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
- allocationManager.onExecutorBusy(executorId)
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
-
- // If the executor is no longer running scheduled any tasks, mark it as idle
- if (executorIdToTaskIds.contains(executorId)) {
- executorIdToTaskIds(executorId) -= taskId
- if (executorIdToTaskIds(executorId).isEmpty) {
- executorIdToTaskIds -= executorId
- allocationManager.onExecutorIdle(executorId)
+ allocationManager.synchronized {
+ // If the executor is no longer running scheduled any tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
}
}
}
- override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
- val executorId = blockManagerAdded.blockManagerId.executorId
+ override def onExecutorAdded(executorAdded: SparkListenerExecutorAdded): Unit = {
+ val executorId = executorAdded.executorId
if (executorId != SparkContext.DRIVER_IDENTIFIER) {
- allocationManager.onExecutorAdded(executorId)
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
}
}
- override def onBlockManagerRemoved(
- blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
- allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
+ override def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved): Unit = {
+ allocationManager.onExecutorRemoved(executorRemoved.executorId)
}
/**
* An estimate of the total number of pending tasks remaining for currently running stages. Does
* not account for tasks which may have failed and been resubmitted.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
*/
def totalPendingTasks(): Int = {
stageIdToNumTasks.map { case (stageId, numTasks) =>
numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
}.sum
}
+
+ /**
+ * Return true if an executor is not currently running a task, and false otherwise.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */
+ def isExecutorIdle(executorId: String): Boolean = {
+ !executorIdToTaskIds.contains(executorId)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index edc3889c9ae51..3f33332a81eaf 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -24,6 +24,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
private[spark] class HttpFileServer(
+ conf: SparkConf,
securityManager: SecurityManager,
requestedPort: Int = 0)
extends Logging {
@@ -35,13 +36,13 @@ private[spark] class HttpFileServer(
var serverUri : String = null
def initialize() {
- baseDir = Utils.createTempDir()
+ baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd")
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
+ httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server")
httpServer.start()
serverUri = httpServer.uri
logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 912558d0cab7d..09a9ccc226721 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -19,6 +19,7 @@ package org.apache.spark
import java.io.File
+import org.eclipse.jetty.server.ssl.SslSocketConnector
import org.eclipse.jetty.util.security.{Constraint, Password}
import org.eclipse.jetty.security.authentication.DigestAuthenticator
import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService}
@@ -42,6 +43,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* around a Jetty server.
*/
private[spark] class HttpServer(
+ conf: SparkConf,
resourceBase: File,
securityManager: SecurityManager,
requestedPort: Int = 0,
@@ -57,7 +59,7 @@ private[spark] class HttpServer(
} else {
logInfo("Starting HTTP Server")
val (actualServer, actualPort) =
- Utils.startServiceOnPort[Server](requestedPort, doStart, serverName)
+ Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName)
server = actualServer
port = actualPort
}
@@ -71,7 +73,10 @@ private[spark] class HttpServer(
*/
private def doStart(startPort: Int): (Server, Int) = {
val server = new Server()
- val connector = new SocketConnector
+
+ val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory()
+ .map(new SslSocketConnector(_)).getOrElse(new SocketConnector)
+
connector.setMaxIdleTime(60 * 1000)
connector.setSoLingerTime(-1)
connector.setPort(startPort)
@@ -148,13 +153,14 @@ private[spark] class HttpServer(
}
/**
- * Get the URI of this HTTP server (http://host:port)
+ * Get the URI of this HTTP server (http://host:port or https://host:port)
*/
def uri: String = {
if (server == null) {
throw new ServerStateException("Server is not started")
} else {
- "http://" + Utils.localIpAddress + ":" + port
+ val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http"
+ s"$scheme://${Utils.localIpAddress}:$port"
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7d96962c4acd7..6e4edc7c80d7a 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -72,20 +72,22 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
/**
* Class that keeps track of the location of the map output of
* a stage. This is abstract because different versions of MapOutputTracker
- * (driver and worker) use different HashMap to store its metadata.
+ * (driver and executor) use different HashMap to store its metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
+ private val retryAttempts = AkkaUtils.numRetries(conf)
+ private val retryIntervalMs = AkkaUtils.retryWaitMs(conf)
/** Set to the MapOutputTrackerActor living on the driver. */
var trackerActor: ActorRef = _
/**
- * This HashMap has different behavior for the master and the workers.
+ * This HashMap has different behavior for the driver and the executors.
*
- * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
- * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
- * master's corresponding HashMap.
+ * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
+ * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
+ * driver's corresponding HashMap.
*
* Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
* thread-safe map.
@@ -99,7 +101,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
protected var epoch: Long = 0
protected val epochLock = new AnyRef
- /** Remembers which map output locations are currently being fetched on a worker. */
+ /** Remembers which map output locations are currently being fetched on an executor. */
private val fetching = new HashSet[Int]
/**
@@ -108,8 +110,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
*/
protected def askTracker(message: Any): Any = {
try {
- val future = trackerActor.ask(message)(timeout)
- Await.result(future, timeout)
+ AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout)
} catch {
case e: Exception =>
logError("Error communicating with MapOutputTracker", e)
@@ -136,14 +137,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
- if (fetching.contains(shuffleId)) {
- // Someone else is fetching it; wait for them to be done
- while (fetching.contains(shuffleId)) {
- try {
- fetching.wait()
- } catch {
- case e: InterruptedException =>
- }
+ // Someone else is fetching it; wait for them to be done
+ while (fetching.contains(shuffleId)) {
+ try {
+ fetching.wait()
+ } catch {
+ case e: InterruptedException =>
}
}
@@ -198,8 +197,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* Called from executors to update the epoch number, potentially clearing old outputs
- * because of a fetch failure. Each worker task calls this with the latest epoch
- * number on the master at the time it was created.
+ * because of a fetch failure. Each executor task calls this with the latest epoch
+ * number on the driver at the time it was created.
*/
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
@@ -231,7 +230,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
private var cacheEpoch = epoch
/**
- * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+ * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver,
* so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
@@ -341,7 +340,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
/**
- * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTracker for the executors, which fetches map output information from the driver's
* MapOutputTrackerMaster.
*/
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index 27892dbd2a0bc..dd3f28e4197e3 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -18,11 +18,11 @@
package org.apache.spark
/**
- * A partition of an RDD.
+ * An identifier for a partition in an RDD.
*/
trait Partition extends Serializable {
/**
- * Get the split's index within its parent RDD
+ * Get the partition's index within its parent RDD
*/
def index: Int
diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala
new file mode 100644
index 0000000000000..2cdc167f85af0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+
+import com.typesafe.config.{Config, ConfigFactory, ConfigValueFactory}
+import org.eclipse.jetty.util.ssl.SslContextFactory
+
+/**
+ * SSLOptions class is a common container for SSL configuration options. It offers methods to
+ * generate specific objects to configure SSL for different communication protocols.
+ *
+ * SSLOptions is intended to provide the maximum common set of SSL settings, which are supported
+ * by the protocol, which it can generate the configuration for. Since Akka doesn't support client
+ * authentication with SSL, SSLOptions cannot support it either.
+ *
+ * @param enabled enables or disables SSL; if it is set to false, the rest of the
+ * settings are disregarded
+ * @param keyStore a path to the key-store file
+ * @param keyStorePassword a password to access the key-store file
+ * @param keyPassword a password to access the private key in the key-store
+ * @param trustStore a path to the trust-store file
+ * @param trustStorePassword a password to access the trust-store file
+ * @param protocol SSL protocol (remember that SSLv3 was compromised) supported by Java
+ * @param enabledAlgorithms a set of encryption algorithms to use
+ */
+private[spark] case class SSLOptions(
+ enabled: Boolean = false,
+ keyStore: Option[File] = None,
+ keyStorePassword: Option[String] = None,
+ keyPassword: Option[String] = None,
+ trustStore: Option[File] = None,
+ trustStorePassword: Option[String] = None,
+ protocol: Option[String] = None,
+ enabledAlgorithms: Set[String] = Set.empty) {
+
+ /**
+ * Creates a Jetty SSL context factory according to the SSL settings represented by this object.
+ */
+ def createJettySslContextFactory(): Option[SslContextFactory] = {
+ if (enabled) {
+ val sslContextFactory = new SslContextFactory()
+
+ keyStore.foreach(file => sslContextFactory.setKeyStorePath(file.getAbsolutePath))
+ trustStore.foreach(file => sslContextFactory.setTrustStore(file.getAbsolutePath))
+ keyStorePassword.foreach(sslContextFactory.setKeyStorePassword)
+ trustStorePassword.foreach(sslContextFactory.setTrustStorePassword)
+ keyPassword.foreach(sslContextFactory.setKeyManagerPassword)
+ protocol.foreach(sslContextFactory.setProtocol)
+ sslContextFactory.setIncludeCipherSuites(enabledAlgorithms.toSeq: _*)
+
+ Some(sslContextFactory)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Creates an Akka configuration object which contains all the SSL settings represented by this
+ * object. It can be used then to compose the ultimate Akka configuration.
+ */
+ def createAkkaConfig: Option[Config] = {
+ import scala.collection.JavaConversions._
+ if (enabled) {
+ Some(ConfigFactory.empty()
+ .withValue("akka.remote.netty.tcp.security.key-store",
+ ConfigValueFactory.fromAnyRef(keyStore.map(_.getAbsolutePath).getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.key-store-password",
+ ConfigValueFactory.fromAnyRef(keyStorePassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.trust-store",
+ ConfigValueFactory.fromAnyRef(trustStore.map(_.getAbsolutePath).getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.trust-store-password",
+ ConfigValueFactory.fromAnyRef(trustStorePassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.key-password",
+ ConfigValueFactory.fromAnyRef(keyPassword.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.random-number-generator",
+ ConfigValueFactory.fromAnyRef(""))
+ .withValue("akka.remote.netty.tcp.security.protocol",
+ ConfigValueFactory.fromAnyRef(protocol.getOrElse("")))
+ .withValue("akka.remote.netty.tcp.security.enabled-algorithms",
+ ConfigValueFactory.fromIterable(enabledAlgorithms.toSeq))
+ .withValue("akka.remote.netty.tcp.enable-ssl",
+ ConfigValueFactory.fromAnyRef(true)))
+ } else {
+ None
+ }
+ }
+
+ /** Returns a string representation of this SSLOptions with all the passwords masked. */
+ override def toString: String = s"SSLOptions{enabled=$enabled, " +
+ s"keyStore=$keyStore, keyStorePassword=${keyStorePassword.map(_ => "xxx")}, " +
+ s"trustStore=$trustStore, trustStorePassword=${trustStorePassword.map(_ => "xxx")}, " +
+ s"protocol=$protocol, enabledAlgorithms=$enabledAlgorithms}"
+
+}
+
+private[spark] object SSLOptions extends Logging {
+
+ /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace.
+ *
+ * The following settings are allowed:
+ * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively
+ * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory
+ * $ - `[ns].keyStorePassword` - a password to the key-store file
+ * $ - `[ns].keyPassword` - a password to the private key
+ * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current
+ * directory
+ * $ - `[ns].trustStorePassword` - a password to the trust-store file
+ * $ - `[ns].protocol` - a protocol name supported by a particular Java version
+ * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers
+ *
+ * For a list of protocols and ciphers supported by particular Java versions, you may go to
+ * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle
+ * blog page]].
+ *
+ * You can optionally specify the default configuration. If you do, for each setting which is
+ * missing in SparkConf, the corresponding setting is used from the default configuration.
+ *
+ * @param conf Spark configuration object where the settings are collected from
+ * @param ns the namespace name
+ * @param defaults the default configuration
+ * @return [[org.apache.spark.SSLOptions]] object
+ */
+ def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = {
+ val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled))
+
+ val keyStore = conf.getOption(s"$ns.keyStore").map(new File(_))
+ .orElse(defaults.flatMap(_.keyStore))
+
+ val keyStorePassword = conf.getOption(s"$ns.keyStorePassword")
+ .orElse(defaults.flatMap(_.keyStorePassword))
+
+ val keyPassword = conf.getOption(s"$ns.keyPassword")
+ .orElse(defaults.flatMap(_.keyPassword))
+
+ val trustStore = conf.getOption(s"$ns.trustStore").map(new File(_))
+ .orElse(defaults.flatMap(_.trustStore))
+
+ val trustStorePassword = conf.getOption(s"$ns.trustStorePassword")
+ .orElse(defaults.flatMap(_.trustStorePassword))
+
+ val protocol = conf.getOption(s"$ns.protocol")
+ .orElse(defaults.flatMap(_.protocol))
+
+ val enabledAlgorithms = conf.getOption(s"$ns.enabledAlgorithms")
+ .map(_.split(",").map(_.trim).filter(_.nonEmpty).toSet)
+ .orElse(defaults.map(_.enabledAlgorithms))
+ .getOrElse(Set.empty)
+
+ new SSLOptions(
+ enabled,
+ keyStore,
+ keyStorePassword,
+ keyPassword,
+ trustStore,
+ trustStorePassword,
+ protocol,
+ enabledAlgorithms)
+ }
+
+}
+
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index dbff9d12b5ad7..88d35a4bacc6e 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -18,7 +18,11 @@
package org.apache.spark
import java.net.{Authenticator, PasswordAuthentication}
+import java.security.KeyStore
+import java.security.cert.X509Certificate
+import javax.net.ssl._
+import com.google.common.io.Files
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
@@ -55,7 +59,7 @@ import org.apache.spark.network.sasl.SecretKeyHolder
* Spark also has a set of admin acls (`spark.admin.acls`) which is a set of users/administrators
* who always have permission to view or modify the Spark application.
*
- * Spark does not currently support encryption after authentication.
+ * Starting from version 1.3, Spark has partial support for encrypted connections with SSL.
*
* At this point spark has multiple communication protocols that need to be secured and
* different underlying mechanisms are used depending on the protocol:
@@ -67,8 +71,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder
* to connect to the server. There is no control of the underlying
* authentication mechanism so its not clear if the password is passed in
* plaintext or uses DIGEST-MD5 or some other mechanism.
- * Akka also has an option to turn on SSL, this option is not currently supported
- * but we could add a configuration option in the future.
+ *
+ * Akka also has an option to turn on SSL, this option is currently supported (see
+ * the details below).
*
* - HTTP for broadcast and file server (via HttpServer) -> Spark currently uses Jetty
* for the HttpServer. Jetty supports multiple authentication mechanisms -
@@ -77,8 +82,9 @@ import org.apache.spark.network.sasl.SecretKeyHolder
* to authenticate using DIGEST-MD5 via a single user and the shared secret.
* Since we are using DIGEST-MD5, the shared secret is not passed on the wire
* in plaintext.
- * We currently do not support SSL (https), but Jetty can be configured to use it
- * so we could add a configuration option for this in the future.
+ *
+ * We currently support SSL (https) for this communication protocol (see the details
+ * below).
*
* The Spark HttpServer installs the HashLoginServer and configures it to DIGEST-MD5.
* Any clients must specify the user and password. There is a default
@@ -93,19 +99,19 @@ import org.apache.spark.network.sasl.SecretKeyHolder
* Note that SASL is pluggable as to what mechanism it uses. We currently use
* DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
* Spark currently supports "auth" for the quality of protection, which means
- * the connection is not supporting integrity or privacy protection (encryption)
+ * the connection does not support integrity or privacy protection (encryption)
* after authentication. SASL also supports "auth-int" and "auth-conf" which
- * SPARK could be support in the future to allow the user to specify the quality
+ * SPARK could support in the future to allow the user to specify the quality
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
* Since the NioBlockTransferService does asynchronous messages passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
- * and a Server, so for a particular connection is has to determine what to do.
+ * and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake before sending
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId,
+ * waits for the response from the server, and does the handshake before sending
* the real message.
*
* The NettyBlockTransferService ensures that SASL authentication is performed
@@ -114,14 +120,14 @@ import org.apache.spark.network.sasl.SecretKeyHolder
*
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
- * properly. For non-Yarn deployments, users can write a filter to go through a
- * companies normal login service. If an authentication filter is in place then the
+ * properly. For non-Yarn deployments, users can write a filter to go through their
+ * organization's normal login service. If an authentication filter is in place then the
* SparkUI can be configured to check the logged in user against the list of users who
* have view acls to see if that user is authorized.
* The filters can also be used for many different purposes. For instance filters
* could be used for logging, encryption, or compression.
*
- * The exact mechanisms used to generate/distributed the shared secret is deployment specific.
+ * The exact mechanisms used to generate/distribute the shared secret are deployment-specific.
*
* For Yarn deployments, the secret is automatically generated using the Akka remote
* Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
@@ -138,21 +144,52 @@ import org.apache.spark.network.sasl.SecretKeyHolder
* All the nodes (Master and Workers) and the applications need to have the same shared secret.
* This again is not ideal as one user could potentially affect another users application.
* This should be enhanced in the future to provide better protection.
- * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * If the UI needs to be secure, the user needs to install a javax servlet filter to do the
* authentication. Spark will then use that user to compare against the view acls to do
* authorization. If not filter is in place the user is generally null and no authorization
* can take place.
+ *
+ * Connection encryption (SSL) configuration is organized hierarchically. The user can configure
+ * the default SSL settings which will be used for all the supported communication protocols unless
+ * they are overwritten by protocol specific settings. This way the user can easily provide the
+ * common settings for all the protocols without disabling the ability to configure each one
+ * individually.
+ *
+ * All the SSL settings like `spark.ssl.xxx` where `xxx` is a particular configuration property,
+ * denote the global configuration for all the supported protocols. In order to override the global
+ * configuration for the particular protocol, the properties must be overwritten in the
+ * protocol-specific namespace. Use `spark.ssl.yyy.xxx` settings to overwrite the global
+ * configuration for particular protocol denoted by `yyy`. Currently `yyy` can be either `akka` for
+ * Akka based connections or `fs` for broadcast and file server.
+ *
+ * Refer to [[org.apache.spark.SSLOptions]] documentation for the list of
+ * options that can be specified.
+ *
+ * SecurityManager initializes SSLOptions objects for different protocols separately. SSLOptions
+ * object parses Spark configuration at a given namespace and builds the common representation
+ * of SSL settings. SSLOptions is then used to provide protocol-specific configuration like
+ * TypeSafe configuration for Akka or SSLContextFactory for Jetty.
+ *
+ * SSL must be configured on each node and configured for each component involved in
+ * communication using the particular protocol. In YARN clusters, the key-store can be prepared on
+ * the client side then distributed and used by the executors as the part of the application
+ * (YARN allows the user to deploy files before the application is started).
+ * In standalone deployment, the user needs to provide key-stores and configuration
+ * options for master and workers. In this mode, the user may allow the executors to use the SSL
+ * settings inherited from the worker which spawned that executor. It can be accomplished by
+ * setting `spark.ssl.useNodeLocalConf` to `true`.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
+private[spark] class SecurityManager(sparkConf: SparkConf)
+ extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
private val authOn = sparkConf.getBoolean("spark.authenticate", false)
// keep spark.ui.acls.enable for backwards compatibility with 1.0
- private var aclsOn = sparkConf.getOption("spark.acls.enable").getOrElse(
- sparkConf.get("spark.ui.acls.enable", "false")).toBoolean
+ private var aclsOn =
+ sparkConf.getBoolean("spark.acls.enable", sparkConf.getBoolean("spark.ui.acls.enable", false))
// admin acls should be set before view or modify acls
private var adminAcls: Set[String] =
@@ -196,6 +233,57 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with
)
}
+ // the default SSL configuration - it will be used by all communication layers unless overwritten
+ private val defaultSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl", defaults = None)
+
+ // SSL configuration for different communication layers - they can override the default
+ // configuration at a specified namespace. The namespace *must* start with spark.ssl.
+ val fileServerSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.fs", Some(defaultSSLOptions))
+ val akkaSSLOptions = SSLOptions.parse(sparkConf, "spark.ssl.akka", Some(defaultSSLOptions))
+
+ logDebug(s"SSLConfiguration for file server: $fileServerSSLOptions")
+ logDebug(s"SSLConfiguration for Akka: $akkaSSLOptions")
+
+ val (sslSocketFactory, hostnameVerifier) = if (fileServerSSLOptions.enabled) {
+ val trustStoreManagers =
+ for (trustStore <- fileServerSSLOptions.trustStore) yield {
+ val input = Files.asByteSource(fileServerSSLOptions.trustStore.get).openStream()
+
+ try {
+ val ks = KeyStore.getInstance(KeyStore.getDefaultType)
+ ks.load(input, fileServerSSLOptions.trustStorePassword.get.toCharArray)
+
+ val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
+ tmf.init(ks)
+ tmf.getTrustManagers
+ } finally {
+ input.close()
+ }
+ }
+
+ lazy val credulousTrustStoreManagers = Array({
+ logWarning("Using 'accept-all' trust manager for SSL connections.")
+ new X509TrustManager {
+ override def getAcceptedIssuers: Array[X509Certificate] = null
+
+ override def checkClientTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+
+ override def checkServerTrusted(x509Certificates: Array[X509Certificate], s: String) {}
+ }: TrustManager
+ })
+
+ val sslContext = SSLContext.getInstance(fileServerSSLOptions.protocol.getOrElse("Default"))
+ sslContext.init(null, trustStoreManagers.getOrElse(credulousTrustStoreManagers), null)
+
+ val hostVerifier = new HostnameVerifier {
+ override def verify(s: String, sslSession: SSLSession): Boolean = true
+ }
+
+ (Some(sslContext.getSocketFactory), Some(hostVerifier))
+ } else {
+ (None, None)
+ }
+
/**
* Split a comma separated String, filter out any empty items, and return a Set of strings
*/
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index c14764f773982..13aa9960ac33a 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,9 +17,13 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.util.Utils
/**
* Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
@@ -46,12 +50,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
- for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ for ((key, value) <- Utils.getSystemProperties if key.startsWith("spark.")) {
+ set(key, value)
}
}
@@ -63,7 +67,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -129,15 +133,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -163,21 +165,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -224,11 +228,11 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -240,7 +244,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -265,7 +269,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -345,7 +349,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
@@ -366,11 +370,14 @@ private[spark] object SparkConf {
isAkkaConf(name) ||
name.startsWith("spark.akka") ||
name.startsWith("spark.auth") ||
+ name.startsWith("spark.ssl") ||
isSparkPortConf(name)
}
/**
- * Return whether the given config is a Spark port config.
+ * Return true if the given config matches either `spark.*.port` or `spark.port.*`.
*/
- def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port")
+ def isSparkPortConf(name: String): Boolean = {
+ (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index aded7c12e274e..71bdbc9b38ddb 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -20,33 +20,42 @@ package org.apache.spark
import scala.language.implicitConversions
import java.io._
+import java.lang.reflect.Constructor
import java.net.URI
import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
import java.util.UUID.randomUUID
+
import scala.collection.{Map, Set}
import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
+
+import akka.actor.Props
+
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable, FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
-import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat, TextInputFormat}
+import org.apache.hadoop.io.{ArrayWritable, BooleanWritable, BytesWritable, DoubleWritable,
+ FloatWritable, IntWritable, LongWritable, NullWritable, Text, Writable}
+import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, SequenceFileInputFormat,
+ TextInputFormat}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob}
import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
+
import org.apache.mesos.MesosNativeLibrary
-import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
import org.apache.spark.executor.TriggerThreadDump
-import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat,
+ FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
-import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkDeploySchedulerBackend, SimrSchedulerBackend}
+import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend,
+ SparkDeploySchedulerBackend, SimrSchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
@@ -64,7 +73,7 @@ import org.apache.spark.util._
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
*/
-class SparkContext(config: SparkConf) extends Logging {
+class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient {
// The call site where this SparkContext was constructed.
private val creationSite: CallSite = Utils.getCallSite()
@@ -85,6 +94,14 @@ class SparkContext(config: SparkConf) extends Logging {
val startTime = System.currentTimeMillis()
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -172,6 +189,9 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] def this(master: String, appName: String, sparkHome: String, jars: Seq[String]) =
this(master, appName, sparkHome, jars, Map(), Map())
+ // log out Spark Version in Spark driver log
+ logInfo(s"Running Spark version $SPARK_VERSION")
+
private[spark] val conf = config.clone()
conf.validateSettings()
@@ -226,7 +246,7 @@ class SparkContext(config: SparkConf) extends Logging {
// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus
- conf.set("spark.executor.id", "driver")
+ conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
// Create the Spark execution environment (cache, map output tracker, etc)
private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
@@ -268,7 +288,12 @@ class SparkContext(config: SparkConf) extends Logging {
// the bound port to the cluster manager properly
ui.foreach(_.bind())
- /** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
+ /**
+ * A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse.
+ *
+ * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you
+ * plan to set some global configurations for all Hadoop RDDs.
+ */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
// Add each JAR given through the constructor
@@ -326,8 +351,13 @@ class SparkContext(config: SparkConf) extends Logging {
try {
dagScheduler = new DAGScheduler(this)
} catch {
- case e: Exception => throw
- new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage))
+ case e: Exception => {
+ try {
+ stop()
+ } finally {
+ throw new SparkException("Error while constructing DAGScheduler", e)
+ }
+ }
}
// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
@@ -344,6 +374,8 @@ class SparkContext(config: SparkConf) extends Logging {
// The metrics system for Driver need to be set spark.app.id to app ID.
// So it should start after we get app ID from the task scheduler and set spark.app.id.
metricsSystem.start()
+ // Attach the driver metrics servlet handler to the web ui after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler)))
// Optionally log Spark events
private[spark] val eventLogger: Option[EventLoggingListener] = {
@@ -357,17 +389,18 @@ class SparkContext(config: SparkConf) extends Logging {
}
// Optionally scale number of executors dynamically based on workload. Exposed for testing.
+ private val dynamicAllocationEnabled = conf.getBoolean("spark.dynamicAllocation.enabled", false)
+ private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false)
private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
- if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
- Some(new ExecutorAllocationManager(this))
+ if (dynamicAllocationEnabled) {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Dynamic allocation of executors is currently only supported in YARN mode")
+ Some(new ExecutorAllocationManager(this, listenerBus, conf))
} else {
None
}
executorAllocationManager.foreach(_.start())
- // At this point, all relevant SparkListeners have been registered, so begin releasing events
- listenerBus.start()
-
private[spark] val cleaner: Option[ContextCleaner] = {
if (conf.getBoolean("spark.cleaner.referenceTracking", true)) {
Some(new ContextCleaner(this))
@@ -377,6 +410,7 @@ class SparkContext(config: SparkConf) extends Logging {
}
cleaner.foreach(_.start())
+ setupAndStartListenerBus()
postEnvironmentUpdate()
postApplicationStart()
@@ -444,7 +478,6 @@ class SparkContext(config: SparkConf) extends Logging {
Option(localProperties.get).map(_.getProperty(key)).getOrElse(null)
/** Set a human readable description of the current job. */
- @deprecated("use setJobGroup", "0.8.1")
def setJobDescription(value: String) {
setLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION, value)
}
@@ -507,12 +540,12 @@ class SparkContext(config: SparkConf) extends Logging {
/** Distribute a local Scala collection to form an RDD.
*
- * @note Parallelize acts lazily. If `seq` is a mutable collection and is
- * altered after the call to parallelize and before the first action on the
- * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
- * the argument to avoid this.
+ * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
+ * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
+ * modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -528,6 +561,7 @@ class SparkContext(config: SparkConf) extends Logging {
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -537,6 +571,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -570,6 +605,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -615,6 +651,7 @@ class SparkContext(config: SparkConf) extends Logging {
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -632,6 +669,9 @@ class SparkContext(config: SparkConf) extends Logging {
*
* Load data from a flat binary file, assuming the length of each record is constant.
*
+ * '''Note:''' We ensure that the byte array for each record in the resulting RDD
+ * has the provided record length.
+ *
* @param path Directory to the input data files
* @param recordLength The length at which to split the records
* @return An RDD of data with values, represented as byte arrays
@@ -639,13 +679,18 @@ class SparkContext(config: SparkConf) extends Logging {
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
classOf[LongWritable],
classOf[BytesWritable],
conf=conf)
- val data = br.map{ case (k, v) => v.getBytes}
+ val data = br.map { case (k, v) =>
+ val bytes = v.getBytes
+ assert(bytes.length == recordLength, "Byte array does not have correct length")
+ bytes
+ }
data
}
@@ -654,16 +699,20 @@ class SparkContext(config: SparkConf) extends Logging {
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
* using the older MapReduce API (`org.apache.hadoop.mapred`).
*
- * @param conf JobConf for setting up the dataset
+ * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast.
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make
+ * sure you won't modify the conf. A safe approach is always creating a new conf for
+ * a new RDD.
* @param inputFormatClass Class of the InputFormat
* @param keyClass Class of the keys
* @param valueClass Class of the values
* @param minPartitions Minimum number of Hadoop Splits to generate.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopRDD[K, V](
conf: JobConf,
@@ -672,18 +721,20 @@ class SparkContext(config: SparkConf) extends Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
}
/** Get an RDD for a Hadoop file with an arbitrary InputFormat
- *
- * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
- * */
+ *
+ * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
+ */
def hadoopFile[K, V](
path: String,
inputFormatClass: Class[_ <: InputFormat[K, V]],
@@ -691,6 +742,7 @@ class SparkContext(config: SparkConf) extends Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -713,9 +765,10 @@ class SparkContext(config: SparkConf) extends Logging {
* }}}
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]]
(path: String, minPartitions: Int)
@@ -736,9 +789,10 @@ class SparkContext(config: SparkConf) extends Logging {
* }}}
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
@@ -760,9 +814,10 @@ class SparkContext(config: SparkConf) extends Logging {
* and extra configuration options to pass to the input format.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
path: String,
@@ -770,6 +825,9 @@ class SparkContext(config: SparkConf) extends Logging {
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
+ // The call to new NewHadoopJob automatically adds security credentials to conf,
+ // so we don't need to explicitly add them ourselves
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -780,31 +838,46 @@ class SparkContext(config: SparkConf) extends Logging {
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*
+ * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast.
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make
+ * sure you won't modify the conf. A safe approach is always creating a new conf for
+ * a new RDD.
+ * @param fClass Class of the InputFormat
+ * @param kClass Class of the keys
+ * @param vClass Class of the values
+ *
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
conf: Configuration = hadoopConfiguration,
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
- new NewHadoopRDD(this, fClass, kClass, vClass, conf)
+ assertNotStopped()
+ // Add necessary security credentials to the JobConf. Required to access secure HDFS.
+ val jconf = new JobConf(conf)
+ SparkHadoopUtil.get.addCredentials(jconf)
+ new NewHadoopRDD(this, fClass, kClass, vClass, jconf)
}
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def sequenceFile[K, V](path: String,
keyClass: Class[K],
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -812,13 +885,15 @@ class SparkContext(config: SparkConf) extends Logging {
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -837,15 +912,17 @@ class SparkContext(config: SparkConf) extends Logging {
* allow it to figure out the Writable class to use in the subclass case.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
- * record, directly caching the returned RDD will create many references to the same object.
- * If you plan to directly cache Hadoop writable objects, you should first copy them using
- * a `map` function.
+ * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle
+ * operation will create many references to the same object.
+ * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first
+ * copy them using a `map` function.
*/
def sequenceFile[K, V]
(path: String, minPartitions: Int = defaultMinPartitions)
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -867,6 +944,7 @@ class SparkContext(config: SparkConf) extends Logging {
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -942,6 +1020,13 @@ class SparkContext(config: SparkConf) extends Logging {
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -955,12 +1040,48 @@ class SparkContext(config: SparkConf) extends Logging {
* filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
* use `SparkFiles.get(fileName)` to find its download location.
*/
- def addFile(path: String) {
+ def addFile(path: String): Unit = {
+ addFile(path, false)
+ }
+
+ /**
+ * Add a file to be downloaded with this Spark job on every node.
+ * The `path` passed can be either a local file, a file in HDFS (or other Hadoop-supported
+ * filesystems), or an HTTP, HTTPS or FTP URI. To access the file in Spark jobs,
+ * use `SparkFiles.get(fileName)` to find its download location.
+ *
+ * A directory can be given if the recursive option is set to true. Currently directories are only
+ * supported for Hadoop-supported filesystems.
+ */
+ def addFile(path: String, recursive: Boolean): Unit = {
val uri = new URI(path)
- val key = uri.getScheme match {
- case null | "file" => env.httpFileServer.addFile(new File(uri.getPath))
- case "local" => "file:" + uri.getPath
- case _ => path
+ val schemeCorrectedPath = uri.getScheme match {
+ case null | "local" => "file:" + uri.getPath
+ case _ => path
+ }
+
+ val hadoopPath = new Path(schemeCorrectedPath)
+ val scheme = new URI(schemeCorrectedPath).getScheme
+ if (!Array("http", "https", "ftp").contains(scheme)) {
+ val fs = hadoopPath.getFileSystem(hadoopConfiguration)
+ if (!fs.exists(hadoopPath)) {
+ throw new FileNotFoundException(s"Added file $hadoopPath does not exist.")
+ }
+ val isDir = fs.isDirectory(hadoopPath)
+ if (!isLocal && scheme == "file" && isDir) {
+ throw new SparkException(s"addFile does not support local directories when not running " +
+ "local mode.")
+ }
+ if (!recursive && isDir) {
+ throw new SparkException(s"Added file $hadoopPath is a directory and recursive is not " +
+ "turned on.")
+ }
+ }
+
+ val key = if (!isLocal && scheme == "file") {
+ env.httpFileServer.addFile(new File(uri.getPath))
+ } else {
+ schemeCorrectedPath
}
val timestamp = System.currentTimeMillis
addedFiles(key) = timestamp
@@ -988,7 +1109,9 @@ class SparkContext(config: SparkConf) extends Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Requesting executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestExecutors(numAdditionalExecutors)
@@ -1004,7 +1127,9 @@ class SparkContext(config: SparkConf) extends Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def killExecutors(executorIds: Seq[String]): Boolean = {
+ override def killExecutors(executorIds: Seq[String]): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Killing executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(executorIds)
@@ -1020,7 +1145,7 @@ class SparkContext(config: SparkConf) extends Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+ override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
/** The version of Spark on which this application is running. */
def version = SPARK_VERSION
@@ -1030,6 +1155,7 @@ class SparkContext(config: SparkConf) extends Logging {
* memory available for caching.
*/
def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
(blockManagerId.host + ":" + blockManagerId.port, mem)
}
@@ -1042,6 +1168,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
rddInfos.filter(_.isCached)
@@ -1059,6 +1186,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
env.blockManager.master.getStorageStatus
}
@@ -1068,6 +1196,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
// TODO(xiajunluan): We should take nested pools into account
taskScheduler.rootPool.schedulableQueue.toSeq
}
@@ -1078,6 +1207,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
@DeveloperApi
def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
}
@@ -1085,6 +1215,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Return current scheduling mode
*/
def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
taskScheduler.schedulingMode
}
@@ -1159,7 +1290,19 @@ class SparkContext(config: SparkConf) extends Logging {
null
}
} else {
- env.httpFileServer.addJar(new File(uri.getPath))
+ try {
+ env.httpFileServer.addJar(new File(uri.getPath))
+ } catch {
+ case exc: FileNotFoundException =>
+ logError(s"Jar not found at $path")
+ null
+ case e: Exception =>
+ // For now just log an error but allow to go through so spark examples work.
+ // The spark examples don't really need the jar distributed since its also
+ // the app jar.
+ logError("Error adding jar (" + e + "), was the --addJars option used?")
+ null
+ }
}
// A JAR file which exists locally on every worker node
case "local" =>
@@ -1190,16 +1333,14 @@ class SparkContext(config: SparkConf) extends Logging {
SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
postApplicationEnd()
ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
+ if (!stopped) {
+ stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
+ dagScheduler.stop()
+ dagScheduler = null
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
@@ -1273,8 +1414,8 @@ class SparkContext(config: SparkConf) extends Logging {
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
@@ -1361,6 +1502,7 @@ class SparkContext(config: SparkConf) extends Logging {
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1383,6 +1525,7 @@ class SparkContext(config: SparkConf) extends Logging {
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1401,11 +1544,13 @@ class SparkContext(config: SparkConf) extends Logging {
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1452,13 +1597,20 @@ class SparkContext(config: SparkConf) extends Logging {
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /** Default min number of partitions for Hadoop RDDs when not given by user */
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
+ * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
+ * The reasons for this are discussed in https://github.com/mesos/spark/pull/718
+ */
def defaultMinPartitions: Int = math.min(defaultParallelism, 2)
private val nextShuffleId = new AtomicInteger(0)
@@ -1470,6 +1622,58 @@ class SparkContext(config: SparkConf) extends Logging {
/** Register a new RDD, returning its RDD ID */
private[spark] def newRddId(): Int = nextRddId.getAndIncrement()
+ /**
+ * Registers listeners specified in spark.extraListeners, then starts the listener bus.
+ * This should be called after all internal listeners have been registered with the listener bus
+ * (e.g. after the web UI and event logging listeners have been registered).
+ */
+ private def setupAndStartListenerBus(): Unit = {
+ // Use reflection to instantiate listeners specified via `spark.extraListeners`
+ try {
+ val listenerClassNames: Seq[String] =
+ conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "")
+ for (className <- listenerClassNames) {
+ // Use reflection to find the right constructor
+ val constructors = {
+ val listenerClass = Class.forName(className)
+ listenerClass.getConstructors.asInstanceOf[Array[Constructor[_ <: SparkListener]]]
+ }
+ val constructorTakingSparkConf = constructors.find { c =>
+ c.getParameterTypes.sameElements(Array(classOf[SparkConf]))
+ }
+ lazy val zeroArgumentConstructor = constructors.find { c =>
+ c.getParameterTypes.isEmpty
+ }
+ val listener: SparkListener = {
+ if (constructorTakingSparkConf.isDefined) {
+ constructorTakingSparkConf.get.newInstance(conf)
+ } else if (zeroArgumentConstructor.isDefined) {
+ zeroArgumentConstructor.get.newInstance()
+ } else {
+ throw new SparkException(
+ s"$className did not have a zero-argument constructor or a" +
+ " single-argument constructor that accepts SparkConf. Note: if the class is" +
+ " defined inside of another Scala class, then its constructors may accept an" +
+ " implicit parameter that references the enclosing class; in this case, you must" +
+ " define the listener as a top-level class in order to prevent this extra" +
+ " parameter from breaking Spark's ability to find a valid constructor.")
+ }
+ }
+ listenerBus.addListener(listener)
+ logInfo(s"Registered listener $className")
+ }
+ } catch {
+ case e: Exception =>
+ try {
+ stop()
+ } finally {
+ throw new SparkException(s"Exception when registering SparkListener", e)
+ }
+ }
+
+ listenerBus.start()
+ }
+
/** Post the application start event */
private def postApplicationStart() {
// Note: this code assumes that the task scheduler has been initialized and has contacted
@@ -1489,8 +1693,8 @@ class SparkContext(config: SparkConf) extends Logging {
val schedulingMode = getSchedulingMode.toString
val addedJarPaths = addedJars.keys.toSeq
val addedFilePaths = addedFiles.keys.toSeq
- val environmentDetails =
- SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths, addedFilePaths)
+ val environmentDetails = SparkEnv.environmentDetails(conf, schedulingMode, addedJarPaths,
+ addedFilePaths)
val environmentUpdate = SparkListenerEnvironmentUpdate(environmentDetails)
listenerBus.post(environmentUpdate)
}
@@ -1675,8 +1879,14 @@ object SparkContext extends Logging {
@deprecated("Replaced by implicit functions in the RDD companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
- rdd: RDD[(K, V)]) =
+ rdd: RDD[(K, V)]) = {
+ val kf = implicitly[K => Writable]
+ val vf = implicitly[V => Writable]
+ // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it
+ implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf)
+ implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf)
RDD.rddToSequenceFileRDDFunctions(rdd)
+ }
@deprecated("Replaced by implicit functions in the RDD companion object. This is " +
"kept here only for backward compatibility.", "1.3.0")
@@ -1693,21 +1903,36 @@ object SparkContext extends Logging {
def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]) =
RDD.numericRDDToDoubleRDDFunctions(rdd)
- // Implicit conversions to common Writable types, for saveAsSequenceFile
+ // The following deprecated functions have already been moved to `object WritableFactory` to
+ // make the compiler find them automatically. They are still kept here for backward compatibility.
- implicit def intToIntWritable(i: Int) = new IntWritable(i)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i)
- implicit def longToLongWritable(l: Long) = new LongWritable(l)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l)
- implicit def floatToFloatWritable(f: Float) = new FloatWritable(f)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f)
- implicit def doubleToDoubleWritable(d: Double) = new DoubleWritable(d)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d)
- implicit def boolToBoolWritable (b: Boolean) = new BooleanWritable(b)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b)
- implicit def bytesToBytesWritable (aob: Array[Byte]) = new BytesWritable(aob)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob)
- implicit def stringToText(s: String) = new Text(s)
+ @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " +
+ "kept here only for backward compatibility.", "1.3.0")
+ implicit def stringToText(s: String): Text = new Text(s)
private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T])
: ArrayWritable = {
@@ -1885,7 +2110,7 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc)
val localCluster = new LocalSparkCluster(
- numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt, sc.conf)
val masterUrls = localCluster.start()
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
@@ -1926,7 +2151,7 @@ object SparkContext extends Logging {
case "yarn-client" =>
val scheduler = try {
val clazz =
- Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler")
val cons = clazz.getConstructor(classOf[SparkContext])
cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl]
@@ -1996,7 +2221,7 @@ object WritableConverter {
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
- // The following implicit functions were in SparkContext before 1.2 and users had to
+ // The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
// them automatically. However, we still keep the old functions in SparkContext for backward
// compatibility and forward to the following functions directly.
@@ -2029,3 +2254,46 @@ object WritableConverter {
implicit def writableWritableConverter[T <: Writable](): WritableConverter[T] =
new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
}
+
+/**
+ * A class encapsulating how to convert some type T to Writable. It stores both the Writable class
+ * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
+ * The Writable class will be used in `SequenceFileRDDFunctions`.
+ */
+private[spark] class WritableFactory[T](
+ val writableClass: ClassTag[T] => Class[_ <: Writable],
+ val convert: T => Writable) extends Serializable
+
+object WritableFactory {
+
+ private[spark] def simpleWritableFactory[T: ClassTag, W <: Writable : ClassTag](convert: T => W)
+ : WritableFactory[T] = {
+ val writableClass = implicitly[ClassTag[W]].runtimeClass.asInstanceOf[Class[W]]
+ new WritableFactory[T](_ => writableClass, convert)
+ }
+
+ implicit def intWritableFactory: WritableFactory[Int] =
+ simpleWritableFactory(new IntWritable(_))
+
+ implicit def longWritableFactory: WritableFactory[Long] =
+ simpleWritableFactory(new LongWritable(_))
+
+ implicit def floatWritableFactory: WritableFactory[Float] =
+ simpleWritableFactory(new FloatWritable(_))
+
+ implicit def doubleWritableFactory: WritableFactory[Double] =
+ simpleWritableFactory(new DoubleWritable(_))
+
+ implicit def booleanWritableFactory: WritableFactory[Boolean] =
+ simpleWritableFactory(new BooleanWritable(_))
+
+ implicit def bytesWritableFactory: WritableFactory[Array[Byte]] =
+ simpleWritableFactory(new BytesWritable(_))
+
+ implicit def stringWritableFactory: WritableFactory[String] =
+ simpleWritableFactory(new Text(_))
+
+ implicit def writableWritableFactory[T <: Writable: ClassTag]: WritableFactory[T] =
+ simpleWritableFactory(w => w)
+
+}
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e464b32e61dd6..f25db7f8de565 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -156,7 +156,15 @@ object SparkEnv extends Logging {
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
val port = conf.get("spark.driver.port").toInt
- create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus)
+ create(
+ conf,
+ SparkContext.DRIVER_IDENTIFIER,
+ hostname,
+ port,
+ isDriver = true,
+ isLocal = isLocal,
+ listenerBus = listenerBus
+ )
}
/**
@@ -169,10 +177,18 @@ object SparkEnv extends Logging {
hostname: String,
port: Int,
numCores: Int,
- isLocal: Boolean,
- actorSystem: ActorSystem = null): SparkEnv = {
- create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem,
- numUsableCores = numCores)
+ isLocal: Boolean): SparkEnv = {
+ val env = create(
+ conf,
+ executorId,
+ hostname,
+ port,
+ isDriver = false,
+ isLocal = isLocal,
+ numUsableCores = numCores
+ )
+ SparkEnv.set(env)
+ env
}
/**
@@ -186,7 +202,6 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean,
listenerBus: LiveListenerBus = null,
- defaultActorSystem: ActorSystem = null,
numUsableCores: Int = 0): SparkEnv = {
// Listener bus is only used on the driver
@@ -196,20 +211,17 @@ object SparkEnv extends Logging {
val securityManager = new SecurityManager(conf)
- // If an existing actor system is already provided, use it.
- // This is the case when an executor is launched in coarse-grained mode.
- val (actorSystem, boundPort) =
- Option(defaultActorSystem) match {
- case Some(as) => (as, port)
- case None =>
- val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
- }
+ // Create the ActorSystem for Akka and get the port it binds to.
+ val (actorSystem, boundPort) = {
+ val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
- // This is so that we tell the executors the correct port to connect to.
if (isDriver) {
conf.set("spark.driver.port", boundPort.toString)
+ } else {
+ conf.set("spark.executor.port", boundPort.toString)
}
// Create an instance of the class with the given name, possibly initializing it with our conf
@@ -300,7 +312,7 @@ object SparkEnv extends Logging {
val httpFileServer =
if (isDriver) {
val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(securityManager, fileServerPort)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
server.initialize()
conf.set("spark.fileserver.uri", server.serverUri)
server
@@ -314,6 +326,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
@@ -323,7 +339,7 @@ object SparkEnv extends Logging {
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
- Utils.createTempDir().getAbsolutePath
+ Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
} else {
"."
}
@@ -383,7 +399,7 @@ object SparkEnv extends Logging {
val sparkProperties = (conf.getAll ++ schedulerMode).sorted
// System properties that are not java classpaths
- val systemProperties = System.getProperties.iterator.toSeq
+ val systemProperties = Utils.getSystemProperties.toSeq
val otherProperties = systemProperties.filter { case (k, _) =>
k != "java.class.path" && !k.startsWith("spark.")
}.sorted
diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala
new file mode 100644
index 0000000000000..7d7fe1a446313
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/TaskContext.scala
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.Serializable
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.util.TaskCompletionListener
+
+
+object TaskContext {
+ /**
+ * Return the currently active TaskContext. This can be called inside of
+ * user functions to access contextual information about running tasks.
+ */
+ def get(): TaskContext = taskContext.get
+
+ private val taskContext: ThreadLocal[TaskContext] = new ThreadLocal[TaskContext]
+
+ // Note: protected[spark] instead of private[spark] to prevent the following two from
+ // showing up in JavaDoc.
+ /**
+ * Set the thread local TaskContext. Internal to Spark.
+ */
+ protected[spark] def setTaskContext(tc: TaskContext): Unit = taskContext.set(tc)
+
+ /**
+ * Unset the thread local TaskContext. Internal to Spark.
+ */
+ protected[spark] def unset(): Unit = taskContext.remove()
+}
+
+
+/**
+ * Contextual information about a task which can be read or mutated during
+ * execution. To access the TaskContext for a running task, use:
+ * {{{
+ * org.apache.spark.TaskContext.get()
+ * }}}
+ */
+abstract class TaskContext extends Serializable {
+ // Note: TaskContext must NOT define a get method. Otherwise it will prevent the Scala compiler
+ // from generating a static get method (based on the companion object's get method).
+
+ // Note: Update JavaTaskContextCompileCheck when new methods are added to this class.
+
+ // Note: getters in this class are defined with parentheses to maintain backward compatibility.
+
+ /**
+ * Returns true if the task has completed.
+ */
+ def isCompleted(): Boolean
+
+ /**
+ * Returns true if the task has been killed.
+ */
+ def isInterrupted(): Boolean
+
+ @deprecated("use isRunningLocally", "1.2.0")
+ def runningLocally(): Boolean
+
+ /**
+ * Returns true if the task is running locally in the driver program.
+ * @return
+ */
+ def isRunningLocally(): Boolean
+
+ /**
+ * Adds a (Java friendly) listener to be executed on task completion.
+ * This will be called in all situation - success, failure, or cancellation.
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(listener: TaskCompletionListener): TaskContext
+
+ /**
+ * Adds a listener in the form of a Scala closure to be executed on task completion.
+ * This will be called in all situations - success, failure, or cancellation.
+ * An example use is for HadoopRDD to register a callback to close the input stream.
+ */
+ def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext
+
+ /**
+ * Adds a callback function to be executed on task completion. An example use
+ * is for HadoopRDD to register a callback to close the input stream.
+ * Will be called in any situation - success, failure, or cancellation.
+ *
+ * @param f Callback function.
+ */
+ @deprecated("use addTaskCompletionListener", "1.2.0")
+ def addOnCompleteCallback(f: () => Unit)
+
+ /**
+ * The ID of the stage that this task belong to.
+ */
+ def stageId(): Int
+
+ /**
+ * The ID of the RDD partition that is computed by this task.
+ */
+ def partitionId(): Int
+
+ /**
+ * How many times this task has been attempted. The first task attempt will be assigned
+ * attemptNumber = 0, and subsequent attempts will have increasing attempt numbers.
+ */
+ def attemptNumber(): Int
+
+ @deprecated("use attemptNumber", "1.3.0")
+ def attemptId(): Long
+
+ /**
+ * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts
+ * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID.
+ */
+ def taskAttemptId(): Long
+
+ /** ::DeveloperApi:: */
+ @DeveloperApi
+ def taskMetrics(): TaskMetrics
+}
diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
index afd2b85d33a77..337c8e4ebebcd 100644
--- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
+++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala
@@ -22,14 +22,19 @@ import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerExce
import scala.collection.mutable.ArrayBuffer
-private[spark] class TaskContextImpl(val stageId: Int,
+private[spark] class TaskContextImpl(
+ val stageId: Int,
val partitionId: Int,
- val attemptId: Long,
+ override val taskAttemptId: Long,
+ override val attemptNumber: Int,
val runningLocally: Boolean = false,
val taskMetrics: TaskMetrics = TaskMetrics.empty)
extends TaskContext
with Logging {
+ // For backwards-compatibility; this method is now deprecated as of 1.3.0.
+ override def attemptId(): Long = taskAttemptId
+
// List of callback functions to execute when the task completes.
@transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener]
@@ -82,10 +87,10 @@ private[spark] class TaskContextImpl(val stageId: Int,
interrupted = true
}
- override def isCompleted: Boolean = completed
+ override def isCompleted(): Boolean = completed
- override def isRunningLocally: Boolean = runningLocally
+ override def isRunningLocally(): Boolean = runningLocally
- override def isInterrupted: Boolean = interrupted
+ override def isInterrupted(): Boolean = interrupted
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
similarity index 76%
rename from sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java
rename to core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
index 6d5ecdf46e551..9df61062e1f85 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/api/java/NullType.java
+++ b/core/src/main/scala/org/apache/spark/TaskNotSerializableException.scala
@@ -15,13 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.sql.api.java;
+package org.apache.spark
+
+import org.apache.spark.annotation.DeveloperApi
/**
- * The data type representing null and NULL values.
- *
- * {@code NullType} is represented by the singleton object {@link DataType#NullType}.
+ * Exception thrown when a task cannot be serialized.
*/
-public class NullType extends DataType {
- protected NullType() {}
-}
+private[spark] class TaskNotSerializableException(error: Throwable) extends Exception(error)
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index 34078142f5385..be081c3825566 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -43,11 +43,20 @@ private[spark] object TestUtils {
* Note: if this is used during class loader tests, class names should be unique
* in order to avoid interference between tests.
*/
- def createJarWithClasses(classNames: Seq[String], value: String = ""): URL = {
+ def createJarWithClasses(
+ classNames: Seq[String],
+ toStringValue: String = "",
+ classNamesWithBase: Seq[(String, String)] = Seq(),
+ classpathUrls: Seq[URL] = Seq()): URL = {
val tempDir = Utils.createTempDir()
- val files = for (name <- classNames) yield createCompiledClass(name, tempDir, value)
+ val files1 = for (name <- classNames) yield {
+ createCompiledClass(name, tempDir, toStringValue, classpathUrls = classpathUrls)
+ }
+ val files2 = for ((childName, baseName) <- classNamesWithBase) yield {
+ createCompiledClass(childName, tempDir, toStringValue, baseName, classpathUrls)
+ }
val jarFile = new File(tempDir, "testJar-%s.jar".format(System.currentTimeMillis()))
- createJar(files, jarFile)
+ createJar(files1 ++ files2, jarFile)
}
@@ -85,15 +94,26 @@ private[spark] object TestUtils {
}
/** Creates a compiled class with the given name. Class file will be placed in destDir. */
- def createCompiledClass(className: String, destDir: File, value: String = ""): File = {
+ def createCompiledClass(
+ className: String,
+ destDir: File,
+ toStringValue: String = "",
+ baseClass: String = null,
+ classpathUrls: Seq[URL] = Seq()): File = {
val compiler = ToolProvider.getSystemJavaCompiler
+ val extendsText = Option(baseClass).map { c => s" extends ${c}" }.getOrElse("")
val sourceFile = new JavaSourceFromString(className,
- "public class " + className + " implements java.io.Serializable {" +
- " @Override public String toString() { return \"" + value + "\"; }}")
+ "public class " + className + extendsText + " implements java.io.Serializable {" +
+ " @Override public String toString() { return \"" + toStringValue + "\"; }}")
// Calling this outputs a class file in pwd. It's easier to just rename the file than
// build a custom FileManager that controls the output location.
- compiler.getTask(null, null, null, null, null, Seq(sourceFile)).call()
+ val options = if (classpathUrls.nonEmpty) {
+ Seq("-classpath", classpathUrls.map { _.getFile }.mkString(File.pathSeparator))
+ } else {
+ Seq()
+ }
+ compiler.getTask(null, null, null, options, null, Seq(sourceFile)).call()
val fileName = className + ".class"
val result = new File(fileName)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index bd451634e53d2..0f91c942ecd50 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -38,6 +38,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
+/**
+ * Defines operations common to several Java RDD implementations.
+ * Note that this trait is not intended to be implemented by user code.
+ */
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
@@ -344,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
+ /**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]]
+ */
+ def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth)
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2.
+ */
+ def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2)
+
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -365,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
combOp: JFunction2[U, U, U]): U =
rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U])
+ /**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree
+ * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]]
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U],
+ depth: Int): U = {
+ rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U])
+ }
+
+ /**
+ * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2.
+ */
+ def treeAggregate[U](
+ zeroValue: U,
+ seqOp: JFunction2[U, T, U],
+ combOp: JFunction2[U, U, U]): U = {
+ treeAggregate(zeroValue, seqOp, combOp, 2)
+ }
+
/**
* Return the number of elements in the RDD.
*/
@@ -435,6 +476,12 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
*/
def first(): T = rdd.first()
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = rdd.isEmpty()
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 97f5c9f257e09..6d6ed693be752 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -373,6 +373,15 @@ class JavaSparkContext(val sc: SparkContext)
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
* etc).
*
+ * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast.
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make
+ * sure you won't modify the conf. A safe approach is always creating a new conf for
+ * a new RDD.
+ * @param inputFormatClass Class of the InputFormat
+ * @param keyClass Class of the keys
+ * @param valueClass Class of the values
+ * @param minPartitions Minimum number of Hadoop Splits to generate.
+ *
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
* record, directly caching the returned RDD will create many references to the same object.
* If you plan to directly cache Hadoop writable objects, you should first copy them using
@@ -395,6 +404,14 @@ class JavaSparkContext(val sc: SparkContext)
* Get an RDD for a Hadoop-readable dataset from a Hadooop JobConf giving its InputFormat and any
* other necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable,
*
+ * @param conf JobConf for setting up the dataset. Note: This will be put into a Broadcast.
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make
+ * sure you won't modify the conf. A safe approach is always creating a new conf for
+ * a new RDD.
+ * @param inputFormatClass Class of the InputFormat
+ * @param keyClass Class of the keys
+ * @param valueClass Class of the values
+ *
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
* record, directly caching the returned RDD will create many references to the same object.
* If you plan to directly cache Hadoop writable objects, you should first copy them using
@@ -476,6 +493,14 @@ class JavaSparkContext(val sc: SparkContext)
* Get an RDD for a given Hadoop file with an arbitrary new API InputFormat
* and extra configuration options to pass to the input format.
*
+ * @param conf Configuration for setting up the dataset. Note: This will be put into a Broadcast.
+ * Therefore if you plan to reuse this conf to create multiple RDDs, you need to make
+ * sure you won't modify the conf. A safe approach is always creating a new conf for
+ * a new RDD.
+ * @param fClass Class of the InputFormat
+ * @param kClass Class of the keys
+ * @param vClass Class of the values
+ *
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
* record, directly caching the returned RDD will create many references to the same object.
* If you plan to directly cache Hadoop writable objects, you should first copy them using
@@ -675,6 +700,9 @@ class JavaSparkContext(val sc: SparkContext)
/**
* Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse.
+ *
+ * '''Note:''' As it will be reused in all Hadoop RDDs, it's better not to modify it unless you
+ * plan to set some global configurations for all Hadoop RDDs.
*/
def hadoopConfiguration(): Configuration = {
sc.hadoopConfiguration
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index b52d0a5028e84..71b26737b8c02 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,7 +19,8 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
-import scala.collection.convert.Wrappers.MapWrapper
+import java.{util => ju}
+import scala.collection.mutable
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
@@ -32,7 +33,64 @@ private[spark] object JavaUtils {
def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
new SerializableMapWrapper(underlying)
+ // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper,
+ // but implements java.io.Serializable. It can't just be subclassed to make it
+ // Serializable since the MapWrapper class has no no-arg constructor. This class
+ // doesn't need a no-arg constructor though.
class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
- extends MapWrapper(underlying) with java.io.Serializable
+ extends ju.AbstractMap[A, B] with java.io.Serializable { self =>
+ override def size = underlying.size
+
+ override def get(key: AnyRef): B = try {
+ underlying get key.asInstanceOf[A] match {
+ case None => null.asInstanceOf[B]
+ case Some(v) => v
+ }
+ } catch {
+ case ex: ClassCastException => null.asInstanceOf[B]
+ }
+
+ override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] {
+ def size = self.size
+
+ def iterator = new ju.Iterator[ju.Map.Entry[A, B]] {
+ val ui = underlying.iterator
+ var prev : Option[A] = None
+
+ def hasNext = ui.hasNext
+
+ def next() = {
+ val (k, v) = ui.next
+ prev = Some(k)
+ new ju.Map.Entry[A, B] {
+ import scala.util.hashing.byteswap32
+ def getKey = k
+ def getValue = v
+ def setValue(v1 : B) = self.put(k, v1)
+ override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16)
+ override def equals(other: Any) = other match {
+ case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue
+ case _ => false
+ }
+ }
+ }
+
+ def remove() {
+ prev match {
+ case Some(k) =>
+ underlying match {
+ case mm: mutable.Map[A, _] =>
+ mm remove k
+ prev = None
+ case _ =>
+ throw new UnsupportedOperationException("remove")
+ }
+ case _ =>
+ throw new IllegalStateException("next must be called at least once before remove")
+ }
+ }
+ }
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 5ba66178e2b78..c9181a29d4756 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] {
mapWritable.put(convertToWritable(k), convertToWritable(v))
}
mapWritable
+ case array: Array[Any] => {
+ val arrayWriteable = new ArrayWritable(classOf[Writable])
+ arrayWriteable.set(array.map(convertToWritable(_)))
+ arrayWriteable
+ }
case other => throw new SparkException(
s"Data of type ${other.getClass.getName} cannot be used")
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index bad40e6529f74..b89effc16d36d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -67,17 +67,16 @@ private[spark] class PythonRDD(
envVars += ("SPARK_REUSE_WORKER" -> "1")
}
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
+ // Whether is the worker released into idle pool
+ @volatile var released = false
// Start a thread to feed the process input from our parent's iterator
val writerThread = new WriterThread(env, worker, split, context)
- var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
writerThread.join()
- if (reuse_worker && complete_cleanly) {
- env.releasePythonWorker(pythonExec, envVars.toMap, worker)
- } else {
+ if (!reuse_worker || !released) {
try {
worker.close()
} catch {
@@ -125,8 +124,8 @@ private[spark] class PythonRDD(
init, finish))
val memoryBytesSpilled = stream.readLong()
val diskBytesSpilled = stream.readLong()
- context.taskMetrics.memoryBytesSpilled += memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
read()
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
@@ -145,8 +144,12 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
+ // Check whether the worker is ready to be re-used.
if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
- complete_cleanly = true
+ if (reuse_worker) {
+ env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+ released = true
+ }
}
null
}
@@ -313,6 +316,7 @@ private object SpecialLengths {
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
val END_OF_STREAM = -4
+ val NULL = -5
}
private[spark] object PythonRDD extends Logging {
@@ -371,54 +375,25 @@ private[spark] object PythonRDD extends Logging {
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
- // The right way to implement this would be to use TypeTags to get the full
- // type of T. Since I don't want to introduce breaking changes throughout the
- // entire Spark API, I have to use this hacky approach:
- if (iter.hasNext) {
- val first = iter.next()
- val newIter = Seq(first).iterator ++ iter
- first match {
- case arr: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case string: String =>
- newIter.asInstanceOf[Iterator[String]].foreach { str =>
- writeUTF(str, dataOut)
- }
- case stream: PortableDataStream =>
- newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
- val bytes = stream.toArray()
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case (key: String, stream: PortableDataStream) =>
- newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
- case (key, stream) =>
- writeUTF(key, dataOut)
- val bytes = stream.toArray()
- dataOut.writeInt(bytes.length)
- dataOut.write(bytes)
- }
- case (key: String, value: String) =>
- newIter.asInstanceOf[Iterator[(String, String)]].foreach {
- case (key, value) =>
- writeUTF(key, dataOut)
- writeUTF(value, dataOut)
- }
- case (key: Array[Byte], value: Array[Byte]) =>
- newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
- case (key, value) =>
- dataOut.writeInt(key.length)
- dataOut.write(key)
- dataOut.writeInt(value.length)
- dataOut.write(value)
- }
- case other =>
- throw new SparkException("Unexpected element type " + first.getClass)
- }
+
+ def write(obj: Any): Unit = obj match {
+ case null =>
+ dataOut.writeInt(SpecialLengths.NULL)
+ case arr: Array[Byte] =>
+ dataOut.writeInt(arr.length)
+ dataOut.write(arr)
+ case str: String =>
+ writeUTF(str, dataOut)
+ case stream: PortableDataStream =>
+ write(stream.toArray())
+ case (key, value) =>
+ write(key)
+ write(value)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
}
+
+ iter.foreach(write)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
index be5ebfa9219d3..acbaba6791850 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala
@@ -17,11 +17,14 @@
package org.apache.spark.api.python
-import java.io.{File, InputStream, IOException, OutputStream}
+import java.io.{File}
+import java.util.{List => JList}
+import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
private[spark] object PythonUtils {
/** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +42,15 @@ private[spark] object PythonUtils {
def mergePythonPaths(paths: String*): String = {
paths.filter(_ != "").mkString(File.pathSeparator)
}
+
+ def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+ sc.parallelize(List("a", null, "b"))
+ }
+
+ /**
+ * Convert list of T into seq of T (for calling API with varargs)
+ */
+ def toSeq[T](cols: JList[T]): Seq[T] = {
+ cols.toList.toSeq
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index a4153aaa926f8..fb52a960e0765 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging {
iter.flatMap { row =>
val obj = unpickle.loads(row)
if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
+ obj match {
+ case array: Array[Any] => array.toSeq
+ case _ => obj.asInstanceOf[JArrayList[_]].asScala
+ }
} else {
Seq(obj)
}
@@ -199,7 +202,10 @@ private[spark] object SerDeUtil extends Logging {
* representation is serialized
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
- val (keyFailed, valueFailed) = checkPickle(rdd.first())
+ val (keyFailed, valueFailed) = rdd.take(1) match {
+ case Array() => (false, false)
+ case Array(first) => checkPickle(first)
+ }
rdd.mapPartitions { iter =>
val cleaned = iter.map { case (k, v) =>
@@ -226,10 +232,12 @@ private[spark] object SerDeUtil extends Logging {
}
val rdd = pythonToJava(pyRDD, batched).rdd
- rdd.first match {
- case obj if isPair(obj) =>
+ rdd.take(1) match {
+ case Array(obj) if isPair(obj) =>
// we only accept (K, V)
- case other => throw new SparkException(
+ case Array() =>
+ // we also accept empty collections
+ case Array(other) => throw new SparkException(
s"RDD element of type ${other.getClass.getName} cannot be used")
}
rdd.map { obj =>
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index c0cbd28a845be..cf289fb3ae39f 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -107,7 +107,6 @@ private[python] class WritableToDoubleArrayConverter extends Converter[Any, Arra
* given directory (probably a temp directory)
*/
object WriteInputFormatTestDataGenerator {
- import SparkContext._
def main(args: Array[String]) {
val path = args(0)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 31f0a462f84d8..1444c0dd3d2d6 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -151,9 +151,10 @@ private[broadcast] object HttpBroadcast extends Logging {
}
private def createServer(conf: SparkConf) {
- broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
val broadcastPort = conf.getInt("spark.broadcast.port", 0)
- server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+ server =
+ new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -198,6 +199,7 @@ private[broadcast] object HttpBroadcast extends Logging {
uc = new URL(url).openConnection()
uc.setConnectTimeout(httpReadTimeout)
}
+ Utils.setupSecureURLConnection(uc, securityManager)
val in = {
uc.setReadTimeout(httpReadTimeout)
diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
index 65a1a8fd7e929..ae55b4ff40b74 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala
@@ -28,5 +28,14 @@ private[spark] class ApplicationDescription(
val user = System.getProperty("user.name", "")
+ def copy(
+ name: String = name,
+ maxCores: Option[Int] = maxCores,
+ memoryPerSlave: Int = memoryPerSlave,
+ command: Command = command,
+ appUiUrl: String = appUiUrl,
+ eventLogDir: Option[String] = eventLogDir): ApplicationDescription =
+ new ApplicationDescription(name, maxCores, memoryPerSlave, command, appUiUrl, eventLogDir)
+
override def toString: String = "ApplicationDescription(" + name + ")"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index f2687ce6b42b4..38b3da0b13756 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -39,7 +39,8 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
val timeout = AkkaUtils.askTimeout(conf)
override def preStart() = {
- masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master))
+ masterActor = context.actorSelection(
+ Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(context.system)))
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -160,6 +161,8 @@ object Client {
val (actorSystem, _) = AkkaUtils.createActorSystem(
"driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf))
+ // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely
+ Master.toAkkaUrl(driverArgs.master, AkkaUtils.protocol(actorSystem))
actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf))
actorSystem.awaitTermination()
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 2e1e52906ceeb..415bd50591692 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -23,14 +23,13 @@ import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
-import org.apache.spark.util.MemoryParam
+import org.apache.spark.util.{IntParam, MemoryParam}
/**
* Command-line parser for the driver client.
*/
private[spark] class ClientArguments(args: Array[String]) {
- val defaultCores = 1
- val defaultMemory = 512
+ import ClientArguments._
var cmd: String = "" // 'launch' or 'kill'
var logLevel = Level.WARN
@@ -39,9 +38,9 @@ private[spark] class ClientArguments(args: Array[String]) {
var master: String = ""
var jarUrl: String = ""
var mainClass: String = ""
- var supervise: Boolean = false
- var memory: Int = defaultMemory
- var cores: Int = defaultCores
+ var supervise: Boolean = DEFAULT_SUPERVISE
+ var memory: Int = DEFAULT_MEMORY
+ var cores: Int = DEFAULT_CORES
private var _driverOptions = ListBuffer[String]()
def driverOptions = _driverOptions.toSeq
@@ -50,9 +49,9 @@ private[spark] class ClientArguments(args: Array[String]) {
parse(args.toList)
- def parse(args: List[String]): Unit = args match {
- case ("--cores" | "-c") :: value :: tail =>
- cores = value.toInt
+ private def parse(args: List[String]): Unit = args match {
+ case ("--cores" | "-c") :: IntParam(value) :: tail =>
+ cores = value
parse(tail)
case ("--memory" | "-m") :: MemoryParam(value) :: tail =>
@@ -106,9 +105,10 @@ private[spark] class ClientArguments(args: Array[String]) {
|Usage: DriverClient kill
|
|Options:
- | -c CORES, --cores CORES Number of cores to request (default: $defaultCores)
- | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $defaultMemory)
+ | -c CORES, --cores CORES Number of cores to request (default: $DEFAULT_CORES)
+ | -m MEMORY, --memory MEMORY Megabytes of memory to request (default: $DEFAULT_MEMORY)
| -s, --supervise Whether to restart the driver on failure
+ | (default: $DEFAULT_SUPERVISE)
| -v, --verbose Print more debugging output
""".stripMargin
System.err.println(usage)
@@ -117,6 +117,10 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
+ private[spark] val DEFAULT_CORES = 1
+ private[spark] val DEFAULT_MEMORY = 512 // MB
+ private[spark] val DEFAULT_SUPERVISE = false
+
def isValidJarUrl(s: String): Boolean = {
try {
val uri = new URI(s)
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index c46f84de8444a..7f600d89604a2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -88,6 +88,8 @@ private[deploy] object DeployMessages {
case class KillDriver(driverId: String) extends DeployMessage
+ case class ApplicationFinished(id: String)
+
// Worker internal
case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
@@ -146,15 +148,22 @@ private[deploy] object DeployMessages {
// Master to MasterWebUI
- case class MasterStateResponse(host: String, port: Int, workers: Array[WorkerInfo],
- activeApps: Array[ApplicationInfo], completedApps: Array[ApplicationInfo],
- activeDrivers: Array[DriverInfo], completedDrivers: Array[DriverInfo],
- status: MasterState) {
+ case class MasterStateResponse(
+ host: String,
+ port: Int,
+ restPort: Option[Int],
+ workers: Array[WorkerInfo],
+ activeApps: Array[ApplicationInfo],
+ completedApps: Array[ApplicationInfo],
+ activeDrivers: Array[DriverInfo],
+ completedDrivers: Array[DriverInfo],
+ status: MasterState) {
Utils.checkHost(host, "Required hostname")
assert (port > 0)
def uri = "spark://" + host + ":" + port
+ def restUri: Option[String] = restPort.map { p => "spark://" + host + ":" + p }
}
// WorkerWebUI to Worker
@@ -175,4 +184,5 @@ private[deploy] object DeployMessages {
// Liveness checks in various places
case object SendHeartbeat
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
index 58c95dc4f9116..b056a19ce6598 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DriverDescription.scala
@@ -25,5 +25,13 @@ private[spark] class DriverDescription(
val command: Command)
extends Serializable {
+ def copy(
+ jarUrl: String = jarUrl,
+ mem: Int = mem,
+ cores: Int = cores,
+ supervise: Boolean = supervise,
+ command: Command = command): DriverDescription =
+ new DriverDescription(jarUrl, mem, cores, supervise, command)
+
override def toString: String = s"DriverDescription (${command.mainClass})"
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index 9a7a113c95715..0401b15446a7b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -33,7 +33,11 @@ import org.apache.spark.util.Utils
* fault recovery without spinning up a lot of processes.
*/
private[spark]
-class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int)
+class LocalSparkCluster(
+ numWorkers: Int,
+ coresPerWorker: Int,
+ memoryPerWorker: Int,
+ conf: SparkConf)
extends Logging {
private val localHostname = Utils.localHostName()
@@ -43,9 +47,11 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
+ // Disable REST server on Master in this mode unless otherwise specified
+ val _conf = conf.clone().setIfMissing("spark.master.rest.enabled", "false")
+
/* Start the Master */
- val conf = new SparkConf(false)
- val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf)
+ val (masterSystem, masterPort, _, _) = Master.startSystemAndActor(localHostname, 0, 0, _conf)
masterActorSystems += masterSystem
val masterUrl = "spark://" + localHostname + ":" + masterPort
val masters = Array(masterUrl)
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 039c8719e2867..53e18c4bcec23 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.api.python.PythonUtils
import org.apache.spark.util.{RedirectThread, Utils}
/**
- * A main class used by spark-submit to launch Python applications. It executes python as a
+ * A main class used to launch Python applications. It executes python as a
* subprocess and then has it connect back to the JVM to access system properties, etc.
*/
object PythonRunner {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 60ee115e393ce..03238e9fa0088 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -21,9 +21,10 @@ import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
@@ -132,16 +133,15 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
Some(() => f() - baselineBytesRead)
} catch {
- case e: NoSuchMethodException => {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
None
}
@@ -155,26 +155,23 @@ class SparkHadoopUtil extends Logging {
* statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
* Returns None if the required method can't be found.
*/
- private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
- : Option[() => Long] = {
+ private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = {
try {
- val threadStats = getFileSystemThreadStatistics(path, conf)
+ val threadStats = getFileSystemThreadStatistics()
val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesWritten = f()
Some(() => f() - baselineBytesWritten)
} catch {
- case e: NoSuchMethodException => {
+ case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => {
logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
None
}
}
}
- private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
- val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
- val scheme = qualifiedPath.toUri().getScheme()
- val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ private def getFileSystemThreadStatistics(): Seq[AnyRef] = {
+ val stats = FileSystem.getAllStatistics()
stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
}
@@ -183,6 +180,32 @@ class SparkHadoopUtil extends Logging {
Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
statisticsDataClass.getDeclaredMethod(methodName)
}
+
+ /**
+ * Using reflection to get the Configuration from JobContext/TaskAttemptContext. If we directly
+ * call `JobContext/TaskAttemptContext.getConfiguration`, it will generate different byte codes
+ * for Hadoop 1.+ and Hadoop 2.+ because JobContext/TaskAttemptContext is class in Hadoop 1.+
+ * while it's interface in Hadoop 2.+.
+ */
+ def getConfigurationFromJobContext(context: JobContext): Configuration = {
+ val method = context.getClass.getMethod("getConfiguration")
+ method.invoke(context).asInstanceOf[Configuration]
+ }
+
+ /**
+ * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the
+ * given path points to a file, return a single-element collection containing [[FileStatus]] of
+ * that file.
+ */
+ def listLeafStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = {
+ def recurse(path: Path) = {
+ val (directories, leaves) = fs.listStatus(path).partition(_.isDir)
+ leaves ++ directories.flatMap(f => listLeafStatuses(fs, f.getPath))
+ }
+
+ val baseStatus = fs.getFileStatus(basePath)
+ if (baseStatus.isDir) recurse(basePath) else Array(baseStatus)
+ }
}
object SparkHadoopUtil {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 955cbd6dab96d..6d213926f3d7b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,14 +18,36 @@
package org.apache.spark.deploy
import java.io.{File, PrintStream}
-import java.lang.reflect.{Modifier, InvocationTargetException}
+import java.lang.reflect.{InvocationTargetException, Modifier}
import java.net.URL
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
-import org.apache.spark.executor.ExecutorURLClassLoader
+import org.apache.hadoop.fs.Path
+import org.apache.ivy.Ivy
+import org.apache.ivy.core.LogOptions
+import org.apache.ivy.core.module.descriptor._
+import org.apache.ivy.core.module.id.{ArtifactId, ModuleId, ModuleRevisionId}
+import org.apache.ivy.core.report.ResolveReport
+import org.apache.ivy.core.resolve.ResolveOptions
+import org.apache.ivy.core.retrieve.RetrieveOptions
+import org.apache.ivy.core.settings.IvySettings
+import org.apache.ivy.plugins.matcher.GlobPatternMatcher
+import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
+
+import org.apache.spark.deploy.rest._
+import org.apache.spark.executor._
import org.apache.spark.util.Utils
+/**
+ * Whether to submit, kill, or request the status of an application.
+ * The latter two operations are currently supported only for standalone cluster mode.
+ */
+private[spark] object SparkSubmitAction extends Enumeration {
+ type SparkSubmitAction = Value
+ val SUBMIT, KILL, REQUEST_STATUS = Value
+}
+
/**
* Main gateway of launching a Spark application.
*
@@ -71,21 +93,74 @@ object SparkSubmit {
if (appArgs.verbose) {
printStream.println(appArgs)
}
- val (childArgs, classpath, sysProps, mainClass) = createLaunchEnv(appArgs)
- launch(childArgs, classpath, sysProps, mainClass, appArgs.verbose)
+ appArgs.action match {
+ case SparkSubmitAction.SUBMIT => submit(appArgs)
+ case SparkSubmitAction.KILL => kill(appArgs)
+ case SparkSubmitAction.REQUEST_STATUS => requestStatus(appArgs)
+ }
+ }
+
+ /** Kill an existing submission using the REST protocol. Standalone cluster mode only. */
+ private def kill(args: SparkSubmitArguments): Unit = {
+ new StandaloneRestClient()
+ .killSubmission(args.master, args.submissionToKill)
}
/**
- * @return a tuple containing
- * (1) the arguments for the child process,
- * (2) a list of classpath entries for the child,
- * (3) a list of system properties and env vars, and
- * (4) the main class for the child
+ * Request the status of an existing submission using the REST protocol.
+ * Standalone cluster mode only.
*/
- private[spark] def createLaunchEnv(args: SparkSubmitArguments)
- : (ArrayBuffer[String], ArrayBuffer[String], Map[String, String], String) = {
+ private def requestStatus(args: SparkSubmitArguments): Unit = {
+ new StandaloneRestClient()
+ .requestSubmissionStatus(args.master, args.submissionToRequestStatusFor)
+ }
- // Values to return
+ /**
+ * Submit the application using the provided parameters.
+ *
+ * This runs in two steps. First, we prepare the launch environment by setting up
+ * the appropriate classpath, system properties, and application arguments for
+ * running the child main class based on the cluster manager and the deploy mode.
+ * Second, we use this launch environment to invoke the main method of the child
+ * main class.
+ */
+ private[spark] def submit(args: SparkSubmitArguments): Unit = {
+ val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
+ // In standalone cluster mode, there are two submission gateways:
+ // (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
+ // (2) The new REST-based gateway introduced in Spark 1.3
+ // The latter is the default behavior as of Spark 1.3, but Spark submit will fail over
+ // to use the legacy gateway if the master endpoint turns out to be not a REST server.
+ if (args.isStandaloneCluster && args.useRest) {
+ try {
+ printStream.println("Running Spark using the REST application submission protocol.")
+ runMain(childArgs, childClasspath, sysProps, childMainClass)
+ } catch {
+ // Fail over to use the legacy submission gateway
+ case e: SubmitRestConnectionException =>
+ printWarning(s"Master endpoint ${args.master} was not a REST server. " +
+ "Falling back to legacy submission gateway instead.")
+ args.useRest = false
+ submit(args)
+ }
+ // In all other modes, just run the main class as prepared
+ } else {
+ runMain(childArgs, childClasspath, sysProps, childMainClass)
+ }
+ }
+
+ /**
+ * Prepare the environment for submitting an application.
+ * This returns a 4-tuple:
+ * (1) the arguments for the child process,
+ * (2) a list of classpath entries for the child,
+ * (3) a map of system properties, and
+ * (4) the main class for the child
+ * Exposed for testing.
+ */
+ private[spark] def prepareSubmitEnvironment(args: SparkSubmitArguments)
+ : (Seq[String], Seq[String], Map[String, String], String) = {
+ // Return values
val childArgs = new ArrayBuffer[String]()
val childClasspath = new ArrayBuffer[String]()
val sysProps = new HashMap[String, String]()
@@ -134,21 +209,38 @@ object SparkSubmit {
}
}
+ val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
+
+ // Require all python files to be local, so we can add them to the PYTHONPATH
+ // In YARN cluster mode, python files are distributed as regular files, which can be non-local
+ if (args.isPython && !isYarnCluster) {
+ if (Utils.nonLocalPaths(args.primaryResource).nonEmpty) {
+ printErrorAndExit(s"Only local python files are supported: $args.primaryResource")
+ }
+ val nonLocalPyFiles = Utils.nonLocalPaths(args.pyFiles).mkString(",")
+ if (nonLocalPyFiles.nonEmpty) {
+ printErrorAndExit(s"Only local additional python files are supported: $nonLocalPyFiles")
+ }
+ }
+
// The following modes are not supported or applicable
(clusterManager, deployMode) match {
case (MESOS, CLUSTER) =>
printErrorAndExit("Cluster deploy mode is currently not supported for Mesos clusters.")
- case (_, CLUSTER) if args.isPython =>
- printErrorAndExit("Cluster deploy mode is currently not supported for python applications.")
+ case (STANDALONE, CLUSTER) if args.isPython =>
+ printErrorAndExit("Cluster deploy mode is currently not supported for python " +
+ "applications on standalone clusters.")
case (_, CLUSTER) if isShell(args.primaryResource) =>
printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.")
case (_, CLUSTER) if isSqlShell(args.mainClass) =>
printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.")
+ case (_, CLUSTER) if isThriftServer(args.mainClass) =>
+ printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.")
case _ =>
}
// If we're running a python app, set the main class to our specific python runner
- if (args.isPython) {
+ if (args.isPython && deployMode == CLIENT) {
if (args.primaryResource == PYSPARK_SHELL) {
args.mainClass = "py4j.GatewayServer"
args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
@@ -165,9 +257,28 @@ object SparkSubmit {
}
}
+ // In yarn-cluster mode for a python app, add primary resource and pyFiles to files
+ // that can be distributed with the job
+ if (args.isPython && isYarnCluster) {
+ args.files = mergeFileLists(args.files, args.primaryResource)
+ args.files = mergeFileLists(args.files, args.pyFiles)
+ }
+
// Special flag to avoid deprecation warnings at the client
sysProps("SPARK_SUBMIT") = "true"
+ // Resolve maven dependencies if there are any and add classpath to jars
+ val resolvedMavenCoordinates =
+ SparkSubmitUtils.resolveMavenCoordinates(
+ args.packages, Option(args.repositories), Option(args.ivyRepoPath))
+ if (!resolvedMavenCoordinates.trim.isEmpty) {
+ if (args.jars == null || args.jars.trim.isEmpty) {
+ args.jars = resolvedMavenCoordinates
+ } else {
+ args.jars += s",$resolvedMavenCoordinates"
+ }
+ }
+
// A list of rules to map each argument to system properties or command-line options in
// each deploy mode; we iterate through these below
val options = List[OptionAssigner](
@@ -176,6 +287,7 @@ object SparkSubmit {
OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"),
OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"),
OptionAssigner(args.jars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars"),
+ OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"),
OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT,
sysProp = "spark.driver.memory"),
OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES,
@@ -186,9 +298,13 @@ object SparkSubmit {
sysProp = "spark.driver.extraLibraryPath"),
// Standalone cluster only
+ // Do not set CL arguments here because there are multiple possibilities for the main class
OptionAssigner(args.jars, STANDALONE, CLUSTER, sysProp = "spark.jars"),
- OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, clOption = "--memory"),
- OptionAssigner(args.driverCores, STANDALONE, CLUSTER, clOption = "--cores"),
+ OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"),
+ OptionAssigner(args.driverMemory, STANDALONE, CLUSTER, sysProp = "spark.driver.memory"),
+ OptionAssigner(args.driverCores, STANDALONE, CLUSTER, sysProp = "spark.driver.cores"),
+ OptionAssigner(args.supervise.toString, STANDALONE, CLUSTER,
+ sysProp = "spark.driver.supervise"),
// Yarn client only
OptionAssigner(args.queue, YARN, CLIENT, sysProp = "spark.yarn.queue"),
@@ -200,6 +316,7 @@ object SparkSubmit {
// Yarn cluster only
OptionAssigner(args.name, YARN, CLUSTER, clOption = "--name"),
OptionAssigner(args.driverMemory, YARN, CLUSTER, clOption = "--driver-memory"),
+ OptionAssigner(args.driverCores, YARN, CLUSTER, clOption = "--driver-cores"),
OptionAssigner(args.queue, YARN, CLUSTER, clOption = "--queue"),
OptionAssigner(args.numExecutors, YARN, CLUSTER, clOption = "--num-executors"),
OptionAssigner(args.executorMemory, YARN, CLUSTER, clOption = "--executor-memory"),
@@ -228,7 +345,6 @@ object SparkSubmit {
if (args.childArgs != null) { childArgs ++= args.childArgs }
}
-
// Map all arguments to command-line options or system properties for our chosen mode
for (opt <- options) {
if (opt.value != null &&
@@ -242,7 +358,6 @@ object SparkSubmit {
// Add the application jar automatically so the user doesn't have to call sc.addJar
// For YARN cluster mode, the jar is already distributed on each node as "app.jar"
// For python files, the primary resource is already distributed as a regular file
- val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
if (!isYarnCluster && !args.isPython) {
var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty)
if (isUserJar(args.primaryResource)) {
@@ -251,14 +366,21 @@ object SparkSubmit {
sysProps.put("spark.jars", jars.mkString(","))
}
- // In standalone-cluster mode, use Client as a wrapper around the user class
- if (clusterManager == STANDALONE && deployMode == CLUSTER) {
- childMainClass = "org.apache.spark.deploy.Client"
- if (args.supervise) {
- childArgs += "--supervise"
+ // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+).
+ // All Spark parameters are expected to be passed to the client through system properties.
+ if (args.isStandaloneCluster) {
+ if (args.useRest) {
+ childMainClass = "org.apache.spark.deploy.rest.StandaloneRestClient"
+ childArgs += (args.primaryResource, args.mainClass)
+ } else {
+ // In legacy standalone cluster mode, use Client as a wrapper around the user class
+ childMainClass = "org.apache.spark.deploy.Client"
+ if (args.supervise) { childArgs += "--supervise" }
+ Option(args.driverMemory).foreach { m => childArgs += ("--memory", m) }
+ Option(args.driverCores).foreach { c => childArgs += ("--cores", c) }
+ childArgs += "launch"
+ childArgs += (args.master, args.primaryResource, args.mainClass)
}
- childArgs += "launch"
- childArgs += (args.master, args.primaryResource, args.mainClass)
if (args.childArgs != null) {
childArgs ++= args.childArgs
}
@@ -267,10 +389,22 @@ object SparkSubmit {
// In yarn-cluster mode, use yarn.Client as a wrapper around the user class
if (isYarnCluster) {
childMainClass = "org.apache.spark.deploy.yarn.Client"
- if (args.primaryResource != SPARK_INTERNAL) {
- childArgs += ("--jar", args.primaryResource)
+ if (args.isPython) {
+ val mainPyFile = new Path(args.primaryResource).getName
+ childArgs += ("--primary-py-file", mainPyFile)
+ if (args.pyFiles != null) {
+ // These files will be distributed to each machine's working directory, so strip the
+ // path prefix
+ val pyFilesNames = args.pyFiles.split(",").map(p => (new Path(p)).getName).mkString(",")
+ childArgs += ("--py-files", pyFilesNames)
+ }
+ childArgs += ("--class", "org.apache.spark.deploy.PythonRunner")
+ } else {
+ if (args.primaryResource != SPARK_INTERNAL) {
+ childArgs += ("--jar", args.primaryResource)
+ }
+ childArgs += ("--class", args.mainClass)
}
- childArgs += ("--class", args.mainClass)
if (args.childArgs != null) {
args.childArgs.foreach { arg => childArgs += ("--arg", arg) }
}
@@ -283,7 +417,7 @@ object SparkSubmit {
// Ignore invalid spark.driver.host in cluster modes.
if (deployMode == CLUSTER) {
- sysProps -= ("spark.driver.host")
+ sysProps -= "spark.driver.host"
}
// Resolve paths in certain spark properties
@@ -312,9 +446,15 @@ object SparkSubmit {
(childArgs, childClasspath, sysProps, childMainClass)
}
- private def launch(
- childArgs: ArrayBuffer[String],
- childClasspath: ArrayBuffer[String],
+ /**
+ * Run the main method of the child class using the provided launch environment.
+ *
+ * Note that this main class will not be the one provided by the user if we're
+ * running cluster deploy mode or python applications.
+ */
+ private def runMain(
+ childArgs: Seq[String],
+ childClasspath: Seq[String],
sysProps: Map[String, String],
childMainClass: String,
verbose: Boolean = false) {
@@ -326,8 +466,14 @@ object SparkSubmit {
printStream.println("\n")
}
- val loader = new ExecutorURLClassLoader(new Array[URL](0),
- Thread.currentThread.getContextClassLoader)
+ val loader =
+ if (sysProps.getOrElse("spark.files.userClassPathFirst", "false").toBoolean) {
+ new ChildExecutorURLClassLoader(new Array[URL](0),
+ Thread.currentThread.getContextClassLoader)
+ } else {
+ new ExecutorURLClassLoader(new Array[URL](0),
+ Thread.currentThread.getContextClassLoader)
+ }
Thread.currentThread.setContextClassLoader(loader)
for (jar <- childClasspath) {
@@ -346,8 +492,8 @@ object SparkSubmit {
case e: ClassNotFoundException =>
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
- println(s"Failed to load main class $childMainClass.")
- println("You need to build Spark with -Phive and -Phive-thriftserver.")
+ printStream.println(s"Failed to load main class $childMainClass.")
+ printStream.println("You need to build Spark with -Phive and -Phive-thriftserver.")
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
@@ -371,7 +517,7 @@ object SparkSubmit {
}
}
- private def addJarToClasspath(localJar: String, loader: ExecutorURLClassLoader) {
+ private def addJarToClasspath(localJar: String, loader: MutableURLClassLoader) {
val uri = Utils.resolveURI(localJar)
uri.getScheme match {
case "file" | "local" =>
@@ -407,6 +553,13 @@ object SparkSubmit {
mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver"
}
+ /**
+ * Return whether the given main class represents a thrift server.
+ */
+ private[spark] def isThriftServer(mainClass: String): Boolean = {
+ mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2"
+ }
+
/**
* Return whether the given primary resource requires running python.
*/
@@ -430,11 +583,199 @@ object SparkSubmit {
}
}
+/** Provides utility functions to be used inside SparkSubmit. */
+private[spark] object SparkSubmitUtils {
+
+ // Exposed for testing
+ private[spark] var printStream = SparkSubmit.printStream
+
+ /**
+ * Represents a Maven Coordinate
+ * @param groupId the groupId of the coordinate
+ * @param artifactId the artifactId of the coordinate
+ * @param version the version of the coordinate
+ */
+ private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String)
+
+/**
+ * Extracts maven coordinates from a comma-delimited string
+ * @param coordinates Comma-delimited string of maven coordinates
+ * @return Sequence of Maven coordinates
+ */
+ private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
+ coordinates.split(",").map { p =>
+ val splits = p.split(":")
+ require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
+ s"'groupId:artifactId:version'. The coordinate provided is: $p")
+ require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
+ s"be whitespace. The groupId provided is: ${splits(0)}")
+ require(splits(1) != null && splits(1).trim.nonEmpty, s"The artifactId cannot be null or " +
+ s"be whitespace. The artifactId provided is: ${splits(1)}")
+ require(splits(2) != null && splits(2).trim.nonEmpty, s"The version cannot be null or " +
+ s"be whitespace. The version provided is: ${splits(2)}")
+ new MavenCoordinate(splits(0), splits(1), splits(2))
+ }
+ }
+
+ /**
+ * Extracts maven coordinates from a comma-delimited string
+ * @param remoteRepos Comma-delimited string of remote repositories
+ * @return A ChainResolver used by Ivy to search for and resolve dependencies.
+ */
+ private[spark] def createRepoResolvers(remoteRepos: Option[String]): ChainResolver = {
+ // We need a chain resolver if we want to check multiple repositories
+ val cr = new ChainResolver
+ cr.setName("list")
+
+ // the biblio resolver resolves POM declared dependencies
+ val br: IBiblioResolver = new IBiblioResolver
+ br.setM2compatible(true)
+ br.setUsepoms(true)
+ br.setName("central")
+ cr.add(br)
+
+ val repositoryList = remoteRepos.getOrElse("")
+ // add any other remote repositories other than maven central
+ if (repositoryList.trim.nonEmpty) {
+ repositoryList.split(",").zipWithIndex.foreach { case (repo, i) =>
+ val brr: IBiblioResolver = new IBiblioResolver
+ brr.setM2compatible(true)
+ brr.setUsepoms(true)
+ brr.setRoot(repo)
+ brr.setName(s"repo-${i + 1}")
+ cr.add(brr)
+ printStream.println(s"$repo added as a remote repository with the name: ${brr.getName}")
+ }
+ }
+ cr
+ }
+
+ /**
+ * Output a comma-delimited list of paths for the downloaded jars to be added to the classpath
+ * (will append to jars in SparkSubmit). The name of the jar is given
+ * after a '!' by Ivy. It also sometimes contains '(bundle)' after '.jar'. Remove that as well.
+ * @param artifacts Sequence of dependencies that were resolved and retrieved
+ * @param cacheDirectory directory where jars are cached
+ * @return a comma-delimited list of paths for the dependencies
+ */
+ private[spark] def resolveDependencyPaths(
+ artifacts: Array[AnyRef],
+ cacheDirectory: File): String = {
+ artifacts.map { artifactInfo =>
+ val artifactString = artifactInfo.toString
+ val jarName = artifactString.drop(artifactString.lastIndexOf("!") + 1)
+ cacheDirectory.getAbsolutePath + File.separator +
+ jarName.substring(0, jarName.lastIndexOf(".jar") + 4)
+ }.mkString(",")
+ }
+
+ /** Adds the given maven coordinates to Ivy's module descriptor. */
+ private[spark] def addDependenciesToIvy(
+ md: DefaultModuleDescriptor,
+ artifacts: Seq[MavenCoordinate],
+ ivyConfName: String): Unit = {
+ artifacts.foreach { mvn =>
+ val ri = ModuleRevisionId.newInstance(mvn.groupId, mvn.artifactId, mvn.version)
+ val dd = new DefaultDependencyDescriptor(ri, false, false)
+ dd.addDependencyConfiguration(ivyConfName, ivyConfName)
+ printStream.println(s"${dd.getDependencyId} added as a dependency")
+ md.addDependency(dd)
+ }
+ }
+
+ /** A nice function to use in tests as well. Values are dummy strings. */
+ private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
+ ModuleRevisionId.newInstance("org.apache.spark", "spark-submit-parent", "1.0"))
+
+ /**
+ * Resolves any dependencies that were supplied through maven coordinates
+ * @param coordinates Comma-delimited string of maven coordinates
+ * @param remoteRepos Comma-delimited string of remote repositories other than maven central
+ * @param ivyPath The path to the local ivy repository
+ * @return The comma-delimited path to the jars of the given maven artifacts including their
+ * transitive dependencies
+ */
+ private[spark] def resolveMavenCoordinates(
+ coordinates: String,
+ remoteRepos: Option[String],
+ ivyPath: Option[String],
+ isTest: Boolean = false): String = {
+ if (coordinates == null || coordinates.trim.isEmpty) {
+ ""
+ } else {
+ val artifacts = extractMavenCoordinates(coordinates)
+ // Default configuration name for ivy
+ val ivyConfName = "default"
+ // set ivy settings for location of cache
+ val ivySettings: IvySettings = new IvySettings
+ // Directories for caching downloads through ivy and storing the jars when maven coordinates
+ // are supplied to spark-submit
+ val alternateIvyCache = ivyPath.getOrElse("")
+ val packagesDirectory: File =
+ if (alternateIvyCache.trim.isEmpty) {
+ new File(ivySettings.getDefaultIvyUserDir, "jars")
+ } else {
+ ivySettings.setDefaultCache(new File(alternateIvyCache, "cache"))
+ new File(alternateIvyCache, "jars")
+ }
+ printStream.println(
+ s"Ivy Default Cache set to: ${ivySettings.getDefaultCache.getAbsolutePath}")
+ printStream.println(s"The jars for the packages stored in: $packagesDirectory")
+ // create a pattern matcher
+ ivySettings.addMatcher(new GlobPatternMatcher)
+ // create the dependency resolvers
+ val repoResolver = createRepoResolvers(remoteRepos)
+ ivySettings.addResolver(repoResolver)
+ ivySettings.setDefaultResolver(repoResolver.getName)
+
+ val ivy = Ivy.newInstance(ivySettings)
+ // Set resolve options to download transitive dependencies as well
+ val resolveOptions = new ResolveOptions
+ resolveOptions.setTransitive(true)
+ val retrieveOptions = new RetrieveOptions
+ // Turn downloading and logging off for testing
+ if (isTest) {
+ resolveOptions.setDownload(false)
+ resolveOptions.setLog(LogOptions.LOG_QUIET)
+ retrieveOptions.setLog(LogOptions.LOG_QUIET)
+ } else {
+ resolveOptions.setDownload(true)
+ }
+
+ // A Module descriptor must be specified. Entries are dummy strings
+ val md = getModuleDescriptor
+ md.setDefaultConf(ivyConfName)
+
+ // Add an exclusion rule for Spark
+ val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
+ val sparkDependencyExcludeRule =
+ new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
+ sparkDependencyExcludeRule.addConfiguration(ivyConfName)
+
+ // Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies
+ md.addExcludeRule(sparkDependencyExcludeRule)
+ addDependenciesToIvy(md, artifacts, ivyConfName)
+
+ // resolve dependencies
+ val rr: ResolveReport = ivy.resolve(md, resolveOptions)
+ if (rr.hasError) {
+ throw new RuntimeException(rr.getAllProblemMessages.toString)
+ }
+ // retrieve all resolved dependencies
+ ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId,
+ packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]",
+ retrieveOptions.setConfs(Array(ivyConfName)))
+
+ resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
+ }
+ }
+}
+
/**
* Provides an indirection layer for passing arguments as system properties or flags to
* the user's driver program or to downstream launcher tools.
*/
-private[spark] case class OptionAssigner(
+private case class OptionAssigner(
value: String,
clusterManager: Int,
deployMode: Int,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index f0e9ee67f6a67..bd0ae26fd8210 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,10 +17,12 @@
package org.apache.spark.deploy
+import java.net.URI
import java.util.jar.JarFile
import scala.collection.mutable.{ArrayBuffer, HashMap}
+import org.apache.spark.deploy.SparkSubmitAction._
import org.apache.spark.util.Utils
/**
@@ -38,8 +40,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var driverExtraClassPath: String = null
var driverExtraLibraryPath: String = null
var driverExtraJavaOptions: String = null
- var driverCores: String = null
- var supervise: Boolean = false
var queue: String = null
var numExecutors: String = null
var files: String = null
@@ -49,11 +49,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var name: String = null
var childArgs: ArrayBuffer[String] = new ArrayBuffer[String]()
var jars: String = null
+ var packages: String = null
+ var repositories: String = null
+ var ivyRepoPath: String = null
var verbose: Boolean = false
var isPython: Boolean = false
var pyFiles: String = null
+ var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
+ // Standalone cluster mode only
+ var supervise: Boolean = false
+ var driverCores: String = null
+ var submissionToKill: String = null
+ var submissionToRequestStatusFor: String = null
+ var useRest: Boolean = true // used internally
+
/** Default properties present in the currently defined defaults file. */
lazy val defaultSparkProperties: HashMap[String, String] = {
val defaultProperties = new HashMap[String, String]()
@@ -78,7 +89,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
// Use `sparkProperties` map along with env vars to fill in any missing parameters
loadEnvironmentArguments()
- checkRequiredArguments()
+ validateArguments()
/**
* Merge values from the default properties file with those specified through --conf.
@@ -103,10 +114,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
.orElse(sparkProperties.get("spark.master"))
.orElse(env.get("MASTER"))
.orNull
+ driverExtraClassPath = Option(driverExtraClassPath)
+ .orElse(sparkProperties.get("spark.driver.extraClassPath"))
+ .orNull
+ driverExtraJavaOptions = Option(driverExtraJavaOptions)
+ .orElse(sparkProperties.get("spark.driver.extraJavaOptions"))
+ .orNull
+ driverExtraLibraryPath = Option(driverExtraLibraryPath)
+ .orElse(sparkProperties.get("spark.driver.extraLibraryPath"))
+ .orNull
driverMemory = Option(driverMemory)
.orElse(sparkProperties.get("spark.driver.memory"))
.orElse(env.get("SPARK_DRIVER_MEMORY"))
.orNull
+ driverCores = Option(driverCores)
+ .orElse(sparkProperties.get("spark.driver.cores"))
+ .orNull
executorMemory = Option(executorMemory)
.orElse(sparkProperties.get("spark.executor.memory"))
.orElse(env.get("SPARK_EXECUTOR_MEMORY"))
@@ -119,33 +142,61 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
.orNull
name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
+ ivyRepoPath = sparkProperties.get("spark.jars.ivy").orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
+ numExecutors = Option(numExecutors)
+ .getOrElse(sparkProperties.get("spark.executor.instances").orNull)
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
- try {
- val jar = new JarFile(primaryResource)
- // Note that this might still return null if no main-class is set; we catch that later
- mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
- } catch {
- case e: Exception =>
- SparkSubmit.printErrorAndExit("Cannot load main class from JAR: " + primaryResource)
- return
+ val uri = new URI(primaryResource)
+ val uriScheme = uri.getScheme()
+
+ uriScheme match {
+ case "file" =>
+ try {
+ val jar = new JarFile(uri.getPath)
+ // Note that this might still return null if no main-class is set; we catch that later
+ mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
+ } catch {
+ case e: Exception =>
+ SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource")
+ }
+ case _ =>
+ SparkSubmit.printErrorAndExit(
+ s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " +
+ "Please specify a class through --class.")
}
}
// Global defaults. These should be keep to minimum to avoid confusing behavior.
master = Option(master).getOrElse("local[*]")
+ // In YARN mode, app name can be set via SPARK_YARN_APP_NAME (see SPARK-5222)
+ if (master.startsWith("yarn")) {
+ name = Option(name).orElse(env.get("SPARK_YARN_APP_NAME")).orNull
+ }
+
// Set name from main class if not given
name = Option(name).orElse(Option(mainClass)).orNull
if (name == null && primaryResource != null) {
name = Utils.stripDirectory(primaryResource)
}
+
+ // Action should be SUBMIT unless otherwise specified
+ action = Option(action).getOrElse(SUBMIT)
}
/** Ensure that required fields exists. Call this only once all defaults are loaded. */
- private def checkRequiredArguments(): Unit = {
+ private def validateArguments(): Unit = {
+ action match {
+ case SUBMIT => validateSubmitArguments()
+ case KILL => validateKillArguments()
+ case REQUEST_STATUS => validateStatusRequestArguments()
+ }
+ }
+
+ private def validateSubmitArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -159,18 +210,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
SparkSubmit.printErrorAndExit("--py-files given but primary resource is not a Python script")
}
- // Require all python files to be local, so we can add them to the PYTHONPATH
- if (isPython) {
- if (Utils.nonLocalPaths(primaryResource).nonEmpty) {
- SparkSubmit.printErrorAndExit(s"Only local python files are supported: $primaryResource")
- }
- val nonLocalPyFiles = Utils.nonLocalPaths(pyFiles).mkString(",")
- if (nonLocalPyFiles.nonEmpty) {
- SparkSubmit.printErrorAndExit(
- s"Only local additional python files are supported: $nonLocalPyFiles")
- }
- }
-
if (master.startsWith("yarn")) {
val hasHadoopEnv = env.contains("HADOOP_CONF_DIR") || env.contains("YARN_CONF_DIR")
if (!hasHadoopEnv && !Utils.isTesting) {
@@ -180,6 +219,29 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
+ private def validateKillArguments(): Unit = {
+ if (!master.startsWith("spark://")) {
+ SparkSubmit.printErrorAndExit("Killing submissions is only supported in standalone mode!")
+ }
+ if (submissionToKill == null) {
+ SparkSubmit.printErrorAndExit("Please specify a submission to kill.")
+ }
+ }
+
+ private def validateStatusRequestArguments(): Unit = {
+ if (!master.startsWith("spark://")) {
+ SparkSubmit.printErrorAndExit(
+ "Requesting submission statuses is only supported in standalone mode!")
+ }
+ if (submissionToRequestStatusFor == null) {
+ SparkSubmit.printErrorAndExit("Please specify a submission to request status for.")
+ }
+ }
+
+ def isStandaloneCluster: Boolean = {
+ master.startsWith("spark://") && deployMode == "cluster"
+ }
+
override def toString = {
s"""Parsed arguments:
| master $master
@@ -204,6 +266,8 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| name $name
| childArgs [${childArgs.mkString(" ")}]
| jars $jars
+ | packages $packages
+ | repositories $repositories
| verbose $verbose
|
|Spark properties used, including those specified through
@@ -212,7 +276,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
""".stripMargin
}
- /** Fill in values by parsing user options. */
+ /**
+ * Fill in values by parsing user options.
+ * NOTE: Any changes here must be reflected in YarnClientSchedulerBackend.
+ */
private def parseOpts(opts: Seq[String]): Unit = {
val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r
@@ -283,6 +350,22 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
propertiesFile = value
parse(tail)
+ case ("--kill") :: value :: tail =>
+ submissionToKill = value
+ if (action != null) {
+ SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $KILL.")
+ }
+ action = KILL
+ parse(tail)
+
+ case ("--status") :: value :: tail =>
+ submissionToRequestStatusFor = value
+ if (action != null) {
+ SparkSubmit.printErrorAndExit(s"Action cannot be both $action and $REQUEST_STATUS.")
+ }
+ action = REQUEST_STATUS
+ parse(tail)
+
case ("--supervise") :: tail =>
supervise = true
parse(tail)
@@ -307,6 +390,14 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
jars = Utils.resolveURIs(value)
parse(tail)
+ case ("--packages") :: value :: tail =>
+ packages = value
+ parse(tail)
+
+ case ("--repositories") :: value :: tail =>
+ repositories = value
+ parse(tail)
+
case ("--conf" | "-c") :: value :: tail =>
value.split("=", 2).toSeq match {
case Seq(k, v) => sparkProperties(k) = v
@@ -347,7 +438,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
outStream.println("Unknown/unsupported param " + unknownParam)
}
outStream.println(
- """Usage: spark-submit [options] [app options]
+ """Usage: spark-submit [options] [app arguments]
+ |Usage: spark-submit --kill [submission ID] --master [spark://...]
+ |Usage: spark-submit --status [submission ID] --master [spark://...]
+ |
|Options:
| --master MASTER_URL spark://host:port, mesos://host:port, yarn, or local.
| --deploy-mode DEPLOY_MODE Whether to launch the driver program locally ("client") or
@@ -357,6 +451,13 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| --name NAME A name of your application.
| --jars JARS Comma-separated list of local jars to include on the driver
| and executor classpaths.
+ | --packages Comma-separated list of maven coordinates of jars to include
+ | on the driver and executor classpaths. Will search the local
+ | maven repo, then maven central and any additional remote
+ | repositories given by --repositories. The format for the
+ | coordinates should be groupId:artifactId:version.
+ | --repositories Comma-separated list of additional remote repositories to
+ | search for the maven coordinates given with --packages.
| --py-files PY_FILES Comma-separated list of .zip, .egg, or .py files to place
| on the PYTHONPATH for Python apps.
| --files FILES Comma-separated list of files to be placed in the working
@@ -381,16 +482,21 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| Spark standalone with cluster deploy mode only:
| --driver-cores NUM Cores for driver (Default: 1).
| --supervise If given, restarts the driver on failure.
+ | --kill SUBMISSION_ID If given, kills the driver specified.
+ | --status SUBMISSION_ID If given, requests the status of the driver specified.
|
| Spark standalone and Mesos only:
| --total-executor-cores NUM Total cores for all executors.
|
| YARN-only:
+ | --driver-cores NUM Number of cores used by the driver, only in cluster mode
+ | (Default: 1).
| --executor-cores NUM Number of cores per executor (Default: 1).
| --queue QUEUE_NAME The YARN queue to submit to (Default: "default").
| --num-executors NUM Number of executors to launch (Default: 2).
| --archives ARCHIVES Comma separated list of archives to be extracted into the
- | working directory of each executor.""".stripMargin
+ | working directory of each executor.
+ """.stripMargin
)
SparkSubmit.exitFn()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index d2687faad62b1..2eab9981845e8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -151,7 +151,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
val isWindows = Utils.isWindows
val isSubprocess = sys.env.contains("IS_SUBPROCESS")
if (!isWindows) {
- val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
+ val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin",
+ propagateEof = true)
stdinThread.start()
// Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
// broken pipe, signaling that the parent process has exited. This is the case if the
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 4efebcaa350fe..ffe940fbda2fb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -26,7 +26,7 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SparkConf, SparkException}
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
@@ -47,6 +47,8 @@ private[spark] class AppClient(
conf: SparkConf)
extends Logging {
+ val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
+
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
@@ -75,9 +77,9 @@ private[spark] class AppClient(
}
def tryRegisterAllMasters() {
- for (masterUrl <- masterUrls) {
- logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+ for (masterAkkaUrl <- masterAkkaUrls) {
+ logInfo("Connecting to master " + masterAkkaUrl + "...")
+ val actor = context.actorSelection(masterAkkaUrl)
actor ! RegisterApplication(appDescription)
}
}
@@ -103,20 +105,15 @@ private[spark] class AppClient(
}
def changeMaster(url: String) {
+ // activeMasterUrl is a valid Spark url since we receive it from master.
activeMasterUrl = url
- master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
- masterAddress = activeMasterUrl match {
- case Master.sparkUrlRegex(host, port) =>
- Address("akka.tcp", Master.systemName, host, port.toInt)
- case x =>
- throw new SparkException("Invalid spark URL: " + x)
- }
+ master = context.actorSelection(
+ Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(actorSystem)))
+ masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(actorSystem))
}
private def isPossibleMaster(remoteUrl: Address) = {
- masterUrls.map(s => Master.toAkkaUrl(s))
- .map(u => AddressFromURIString(u).hostPort)
- .contains(remoteUrl.hostPort)
+ masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort)
}
override def receiveWithLogging = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
index fbe39b27649f6..553bf3cb945ab 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala
@@ -25,7 +25,8 @@ private[spark] case class ApplicationHistoryInfo(
startTime: Long,
endTime: Long,
lastUpdated: Long,
- sparkUser: String)
+ sparkUser: String,
+ completed: Boolean = false)
private[spark] abstract class ApplicationHistoryProvider {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 82a54dbfb5330..868c63d30a202 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -17,14 +17,16 @@
package org.apache.spark.deploy.history
-import java.io.FileNotFoundException
+import java.io.{BufferedInputStream, FileNotFoundException, InputStream}
import scala.collection.mutable
import org.apache.hadoop.fs.{FileStatus, Path}
+import org.apache.hadoop.fs.permission.AccessControlException
import org.apache.spark.{Logging, SecurityManager, SparkConf}
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.io.CompressionCodec
import org.apache.spark.scheduler._
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.Utils
@@ -64,6 +66,12 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
@volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo]
= new mutable.LinkedHashMap()
+ // Constants used to parse Spark 1.0.0 log directories.
+ private[history] val LOG_PREFIX = "EVENT_LOG_"
+ private[history] val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
+ private[history] val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
+ private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
+
/**
* A background thread that periodically checks for event log updates on disk.
*
@@ -90,7 +98,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
initialize()
- private def initialize() {
+ private def initialize(): Unit = {
// Validate the log directory.
val path = new Path(logDir)
if (!fs.exists(path)) {
@@ -106,8 +114,12 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
checkForLogs()
- logCheckingThread.setDaemon(true)
- logCheckingThread.start()
+
+ // Disable the background thread during tests.
+ if (!conf.contains("spark.testing")) {
+ logCheckingThread.setDaemon(true)
+ logCheckingThread.start()
+ }
}
override def getListing() = applications.values
@@ -115,8 +127,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
override def getAppUI(appId: String): Option[SparkUI] = {
try {
applications.get(appId).map { info =>
- val (replayBus, appListener) = createReplayBus(fs.getFileStatus(
- new Path(logDir, info.logDir)))
+ val replayBus = new ReplayListenerBus()
val ui = {
val conf = this.conf.clone()
val appSecManager = new SecurityManager(conf)
@@ -125,15 +136,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
// Do not call ui.bind() to avoid creating a new server for each application
}
- replayBus.replay()
+ val appListener = new ApplicationEventListener()
+ replayBus.addListener(appListener)
+ val appInfo = replay(fs.getFileStatus(new Path(logDir, info.logPath)), replayBus)
- ui.setAppName(s"${appListener.appName.getOrElse(NOT_STARTED)} ($appId)")
+ ui.setAppName(s"${appInfo.name} ($appId)")
val uiAclsEnabled = conf.getBoolean("spark.history.ui.acls.enable", false)
ui.getSecurityManager.setAcls(uiAclsEnabled)
// make sure to set admin acls before view acls so they are properly picked up
ui.getSecurityManager.setAdminAcls(appListener.adminAcls.getOrElse(""))
- ui.getSecurityManager.setViewAcls(appListener.sparkUser.getOrElse(NOT_STARTED),
+ ui.getSecurityManager.setViewAcls(appInfo.sparkUser,
appListener.viewAcls.getOrElse(""))
ui
}
@@ -149,45 +162,39 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
* Tries to reuse as much of the data already in memory as possible, by not reading
* applications that haven't been updated since last time the logs were checked.
*/
- private def checkForLogs() = {
+ private[history] def checkForLogs(): Unit = {
lastLogCheckTimeMs = getMonotonicTimeMs()
logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
- try {
- val logStatus = fs.listStatus(new Path(logDir))
- val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
- // Load all new logs from the log directory. Only directories that have a modification time
- // later than the last known log directory will be loaded.
+ try {
var newLastModifiedTime = lastModifiedTime
- val logInfos = logDirs
- .filter { dir =>
- if (fs.isFile(new Path(dir.getPath(), EventLoggingListener.APPLICATION_COMPLETE))) {
- val modTime = getModificationTime(dir)
- newLastModifiedTime = math.max(newLastModifiedTime, modTime)
- modTime > lastModifiedTime
- } else {
- false
+ val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq)
+ .getOrElse(Seq[FileStatus]())
+ val logInfos = statusList
+ .filter { entry =>
+ try {
+ getModificationTime(entry).map { time =>
+ newLastModifiedTime = math.max(newLastModifiedTime, time)
+ time >= lastModifiedTime
+ }.getOrElse(false)
+ } catch {
+ case e: AccessControlException =>
+ // Do not use "logInfo" since these messages can get pretty noisy if printed on
+ // every poll.
+ logDebug(s"No permission to read $entry, ignoring.")
+ false
}
}
- .flatMap { dir =>
+ .flatMap { entry =>
try {
- val (replayBus, appListener) = createReplayBus(dir)
- replayBus.replay()
- Some(new FsApplicationHistoryInfo(
- dir.getPath().getName(),
- appListener.appId.getOrElse(dir.getPath().getName()),
- appListener.appName.getOrElse(NOT_STARTED),
- appListener.startTime.getOrElse(-1L),
- appListener.endTime.getOrElse(-1L),
- getModificationTime(dir),
- appListener.sparkUser.getOrElse(NOT_STARTED)))
+ Some(replay(entry, new ReplayListenerBus()))
} catch {
case e: Exception =>
- logInfo(s"Failed to load application log data from $dir.", e)
+ logError(s"Failed to load application log data from $entry.", e)
None
}
}
- .sortBy { info => -info.endTime }
+ .sortWith(compareAppInfo)
lastModifiedTime = newLastModifiedTime
@@ -197,7 +204,9 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
if (!logInfos.isEmpty) {
val newApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
def addIfAbsent(info: FsApplicationHistoryInfo) = {
- if (!newApps.contains(info.id)) {
+ if (!newApps.contains(info.id) ||
+ newApps(info.id).logPath.endsWith(EventLoggingListener.IN_PROGRESS) &&
+ !info.logPath.endsWith(EventLoggingListener.IN_PROGRESS)) {
newApps += (info.id -> info)
}
}
@@ -205,7 +214,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
val newIterator = logInfos.iterator.buffered
val oldIterator = applications.values.iterator.buffered
while (newIterator.hasNext && oldIterator.hasNext) {
- if (newIterator.head.endTime > oldIterator.head.endTime) {
+ if (compareAppInfo(newIterator.head, oldIterator.head)) {
addIfAbsent(newIterator.next)
} else {
addIfAbsent(oldIterator.next)
@@ -217,37 +226,128 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
applications = newApps
}
} catch {
- case t: Throwable => logError("Exception in checking for event log updates", t)
+ case e: Exception => logError("Exception in checking for event log updates", e)
}
}
- private def createReplayBus(logDir: FileStatus): (ReplayListenerBus, ApplicationEventListener) = {
- val path = logDir.getPath()
- val elogInfo = EventLoggingListener.parseLoggingInfo(path, fs)
- val replayBus = new ReplayListenerBus(elogInfo.logPaths, fs, elogInfo.compressionCodec)
- val appListener = new ApplicationEventListener
- replayBus.addListener(appListener)
- (replayBus, appListener)
+ /**
+ * Comparison function that defines the sort order for the application listing.
+ *
+ * @return Whether `i1` should precede `i2`.
+ */
+ private def compareAppInfo(
+ i1: FsApplicationHistoryInfo,
+ i2: FsApplicationHistoryInfo): Boolean = {
+ if (i1.endTime != i2.endTime) i1.endTime >= i2.endTime else i1.startTime >= i2.startTime
}
- /** Return when this directory was last modified. */
- private def getModificationTime(dir: FileStatus): Long = {
- try {
- val logFiles = fs.listStatus(dir.getPath)
- if (logFiles != null && !logFiles.isEmpty) {
- logFiles.map(_.getModificationTime).max
+ /**
+ * Replays the events in the specified log file and returns information about the associated
+ * application.
+ */
+ private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = {
+ val logPath = eventLog.getPath()
+ val (logInput, sparkVersion) =
+ if (isLegacyLogDirectory(eventLog)) {
+ openLegacyEventLog(logPath)
} else {
- dir.getModificationTime
+ EventLoggingListener.openEventLog(logPath, fs)
}
- } catch {
- case t: Throwable =>
- logError("Exception in accessing modification time of %s".format(dir.getPath), t)
- -1L
+ try {
+ val appListener = new ApplicationEventListener
+ bus.addListener(appListener)
+ bus.replay(logInput, sparkVersion)
+ new FsApplicationHistoryInfo(
+ logPath.getName(),
+ appListener.appId.getOrElse(logPath.getName()),
+ appListener.appName.getOrElse(NOT_STARTED),
+ appListener.startTime.getOrElse(-1L),
+ appListener.endTime.getOrElse(-1L),
+ getModificationTime(eventLog).get,
+ appListener.sparkUser.getOrElse(NOT_STARTED),
+ isApplicationCompleted(eventLog))
+ } finally {
+ logInput.close()
+ }
+ }
+
+ /**
+ * Loads a legacy log directory. This assumes that the log directory contains a single event
+ * log file (along with other metadata files), which is the case for directories generated by
+ * the code in previous releases.
+ *
+ * @return 2-tuple of (input stream of the events, version of Spark which wrote the log)
+ */
+ private[history] def openLegacyEventLog(dir: Path): (InputStream, String) = {
+ val children = fs.listStatus(dir)
+ var eventLogPath: Path = null
+ var codecName: Option[String] = None
+ var sparkVersion: String = null
+
+ children.foreach { child =>
+ child.getPath().getName() match {
+ case name if name.startsWith(LOG_PREFIX) =>
+ eventLogPath = child.getPath()
+
+ case codec if codec.startsWith(COMPRESSION_CODEC_PREFIX) =>
+ codecName = Some(codec.substring(COMPRESSION_CODEC_PREFIX.length()))
+
+ case version if version.startsWith(SPARK_VERSION_PREFIX) =>
+ sparkVersion = version.substring(SPARK_VERSION_PREFIX.length())
+
+ case _ =>
+ }
+ }
+
+ if (eventLogPath == null || sparkVersion == null) {
+ throw new IllegalArgumentException(s"$dir is not a Spark application log directory.")
+ }
+
+ val codec = try {
+ codecName.map { c => CompressionCodec.createCodec(conf, c) }
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(s"Unknown compression codec $codecName.")
+ }
+
+ val in = new BufferedInputStream(fs.open(eventLogPath))
+ (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion)
+ }
+
+ /**
+ * Return whether the specified event log path contains a old directory-based event log.
+ * Previously, the event log of an application comprises of multiple files in a directory.
+ * As of Spark 1.3, these files are consolidated into a single one that replaces the directory.
+ * See SPARK-2261 for more detail.
+ */
+ private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir()
+
+ /**
+ * Returns the modification time of the given event log. If the status points at an empty
+ * directory, `None` is returned, indicating that there isn't an event log at that location.
+ */
+ private def getModificationTime(fsEntry: FileStatus): Option[Long] = {
+ if (isLegacyLogDirectory(fsEntry)) {
+ val statusList = fs.listStatus(fsEntry.getPath)
+ if (!statusList.isEmpty) Some(statusList.map(_.getModificationTime()).max) else None
+ } else {
+ Some(fsEntry.getModificationTime())
}
}
/** Returns the system's mononotically increasing time. */
- private def getMonotonicTimeMs() = System.nanoTime() / (1000 * 1000)
+ private def getMonotonicTimeMs(): Long = System.nanoTime() / (1000 * 1000)
+
+ /**
+ * Return true when the application has completed.
+ */
+ private def isApplicationCompleted(entry: FileStatus): Boolean = {
+ if (isLegacyLogDirectory(entry)) {
+ fs.exists(new Path(entry.getPath(), APPLICATION_COMPLETE))
+ } else {
+ !entry.getPath().getName().endsWith(EventLoggingListener.IN_PROGRESS)
+ }
+ }
}
@@ -256,11 +356,12 @@ private object FsHistoryProvider {
}
private class FsApplicationHistoryInfo(
- val logDir: String,
+ val logPath: String,
id: String,
name: String,
startTime: Long,
endTime: Long,
lastUpdated: Long,
- sparkUser: String)
- extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser)
+ sparkUser: String,
+ completed: Boolean = true)
+ extends ApplicationHistoryInfo(id, name, startTime, endTime, lastUpdated, sparkUser, completed)
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 5fdc350cd8512..e4e7bc2216014 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -26,12 +26,15 @@ import org.apache.spark.ui.{WebUIPage, UIUtils}
private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
private val pageSize = 20
+ private val plusOrMinus = 2
def render(request: HttpServletRequest): Seq[Node] = {
val requestedPage = Option(request.getParameter("page")).getOrElse("1").toInt
val requestedFirst = (requestedPage - 1) * pageSize
+ val requestedIncomplete =
+ Option(request.getParameter("showIncomplete")).getOrElse("false").toBoolean
- val allApps = parent.getApplicationList()
+ val allApps = parent.getApplicationList().filter(_.completed != requestedIncomplete)
val actualFirst = if (requestedFirst < allApps.size) requestedFirst else 0
val apps = allApps.slice(actualFirst, Math.min(actualFirst + pageSize, allApps.size))
@@ -39,6 +42,9 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
val last = Math.min(actualFirst + pageSize, allApps.size) - 1
val pageCount = allApps.size / pageSize + (if (allApps.size % pageSize > 0) 1 else 0)
+ val secondPageFromLeft = 2
+ val secondPageFromRight = pageCount - 1
+
val appTable = UIUtils.listingTable(appHeader, appRow, apps)
val providerConfig = parent.getProviderConfig()
val content =
@@ -48,12 +54,38 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
{providerConfig.map { case (k, v) => {k}: {v} }}
{
+ // This displays the indices of pages that are within `plusOrMinus` pages of
+ // the current page. Regardless of where the current page is, this also links
+ // to the first and last page. If the current page +/- `plusOrMinus` is greater
+ // than the 2nd page from the first page or less than the 2nd page from the last
+ // page, `...` will be displayed.
if (allApps.size > 0) {
+ val leftSideIndices =
+ rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _)
+ val rightSideIndices =
+ rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount)
+
Showing {actualFirst + 1}-{last + 1} of {allApps.size}
+ {if (requestedIncomplete) "(Incomplete applications)"}
- {if (actualPage > 1) <}
- {if (actualPage < pageCount) >}
+ {
+ if (actualPage > 1) {
+ <
+ 1
+ }
+ }
+ {if (actualPage - plusOrMinus > secondPageFromLeft) " ... "}
+ {leftSideIndices}
+ {actualPage}
+ {rightSideIndices}
+ {if (actualPage + plusOrMinus < secondPageFromRight) " ... "}
+ {
+ if (actualPage < pageCount) {
+ {pageCount}
+ >
+ }
+ }
++
appTable
@@ -67,6 +99,15 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
}
}
+
+ {
+ if (requestedIncomplete) {
+ "Back to completed applications"
+ } else {
+ "Show incomplete applications"
+ }
+ }
+
UIUtils.basicSparkPage(content, "History Server")
@@ -81,11 +122,16 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
"Spark User",
"Last Updated")
+ private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = {
+ range.filter(condition).map(nextPage => {nextPage} )
+ }
+
private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
val uiAddress = HistoryServer.UI_PATH_PREFIX + s"/${info.id}"
val startTime = UIUtils.formatDate(info.startTime)
- val endTime = UIUtils.formatDate(info.endTime)
- val duration = UIUtils.formatDuration(info.endTime - info.startTime)
+ val endTime = if (info.endTime > 0) UIUtils.formatDate(info.endTime) else "-"
+ val duration =
+ if (info.endTime > 0) UIUtils.formatDuration(info.endTime - info.startTime) else "-"
val lastUpdated = UIUtils.formatDate(info.lastUpdated)
{info.id} |
@@ -97,4 +143,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
{lastUpdated} |
}
+
+ private def makePageLink(linkPage: Int, showIncomplete: Boolean): String = {
+ "/?" + Array(
+ "page=" + linkPage,
+ "showIncomplete=" + showIncomplete
+ ).mkString("&")
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index ce00c0ffd21e0..fa9bfe5426b6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -158,11 +158,12 @@ class HistoryServer(
/**
* The recommended way of starting and stopping a HistoryServer is through the scripts
- * start-history-server.sh and stop-history-server.sh. The path to a base log directory
- * is must be specified, while the requested UI port is optional. For example:
+ * start-history-server.sh and stop-history-server.sh. The path to a base log directory,
+ * as well as any other relevant history server configuration, should be specified via
+ * the $SPARK_HISTORY_OPTS environment variable. For example:
*
- * ./sbin/spark-history-server.sh /tmp/spark-events
- * ./sbin/spark-history-server.sh hdfs://1.2.3.4:9000/spark-events
+ * export SPARK_HISTORY_OPTS="-Dspark.history.fs.logDirectory=/tmp/spark-events"
+ * ./sbin/start-history-server.sh
*
* This launches the HistoryServer as a Spark daemon.
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index ad7d81747c377..ede0a9dbefb8d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -38,8 +38,8 @@ private[spark] class ApplicationInfo(
extends Serializable {
@transient var state: ApplicationState.Value = _
- @transient var executors: mutable.HashMap[Int, ExecutorInfo] = _
- @transient var removedExecutors: ArrayBuffer[ExecutorInfo] = _
+ @transient var executors: mutable.HashMap[Int, ExecutorDesc] = _
+ @transient var removedExecutors: ArrayBuffer[ExecutorDesc] = _
@transient var coresGranted: Int = _
@transient var endTime: Long = _
@transient var appSource: ApplicationSource = _
@@ -55,12 +55,12 @@ private[spark] class ApplicationInfo(
private def init() {
state = ApplicationState.WAITING
- executors = new mutable.HashMap[Int, ExecutorInfo]
+ executors = new mutable.HashMap[Int, ExecutorDesc]
coresGranted = 0
endTime = -1L
appSource = new ApplicationSource(this)
nextExecutorId = 0
- removedExecutors = new ArrayBuffer[ExecutorInfo]
+ removedExecutors = new ArrayBuffer[ExecutorDesc]
}
private def newExecutorId(useID: Option[Int] = None): Int = {
@@ -75,14 +75,14 @@ private[spark] class ApplicationInfo(
}
}
- def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorInfo = {
- val exec = new ExecutorInfo(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
+ def addExecutor(worker: WorkerInfo, cores: Int, useID: Option[Int] = None): ExecutorDesc = {
+ val exec = new ExecutorDesc(newExecutorId(useID), this, worker, cores, desc.memoryPerSlave)
executors(exec.id) = exec
coresGranted += cores
exec
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.id)) {
removedExecutors += executors(exec.id)
executors -= exec.id
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
similarity index 95%
rename from core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
rename to core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
index d417070c51016..5d620dfcabad5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ExecutorInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ExecutorDesc.scala
@@ -19,7 +19,7 @@ package org.apache.spark.deploy.master
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
-private[spark] class ExecutorInfo(
+private[spark] class ExecutorDesc(
val id: Int,
val application: ApplicationInfo,
val worker: WorkerInfo,
@@ -37,7 +37,7 @@ private[spark] class ExecutorInfo(
override def equals(other: Any): Boolean = {
other match {
- case info: ExecutorInfo =>
+ case info: ExecutorDesc =>
fullId == info.fullId &&
worker.id == info.worker.id &&
cores == info.cores &&
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 7b32c505def9b..b8b1a25abff2e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.master
+import java.io.FileNotFoundException
import java.net.URLEncoder
import java.text.SimpleDateFormat
import java.util.Date
@@ -32,6 +33,7 @@ import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import akka.serialization.Serialization
import akka.serialization.SerializationExtension
+import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, DriverDescription,
@@ -41,6 +43,7 @@ import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.deploy.master.DriverState.DriverState
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
+import org.apache.spark.deploy.rest.StandaloneRestServer
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus}
import org.apache.spark.ui.SparkUI
@@ -50,12 +53,13 @@ private[spark] class Master(
host: String,
port: Int,
webUiPort: Int,
- val securityMgr: SecurityManager)
+ val securityMgr: SecurityManager,
+ val conf: SparkConf)
extends Actor with ActorLogReceive with Logging with LeaderElectable {
import context.dispatcher // to use Akka's scheduler.schedule()
- val conf = new SparkConf
+ val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
def createDateFormat = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = conf.getLong("spark.worker.timeout", 60) * 1000
@@ -118,8 +122,20 @@ private[spark] class Master(
throw new SparkException("spark.deploy.defaultCores must be positive")
}
+ // Alternative application submission gateway that is stable across Spark versions
+ private val restServerEnabled = conf.getBoolean("spark.master.rest.enabled", true)
+ private val restServer =
+ if (restServerEnabled) {
+ val port = conf.getInt("spark.master.rest.port", 6066)
+ Some(new StandaloneRestServer(host, port, self, masterUrl, conf))
+ } else {
+ None
+ }
+ private val restServerBoundPort = restServer.map(_.start())
+
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
+ logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
// Listen for remote client disconnection events, since they don't go through Akka's watch()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.bind()
@@ -129,6 +145,10 @@ private[spark] class Master(
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+ // Attach the master and app metrics servlet handler to the web ui after the metrics systems are
+ // started.
+ masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
val (persistenceEngine_, leaderElectionAgent_) = RECOVERY_MODE match {
case "ZOOKEEPER" =>
@@ -166,6 +186,7 @@ private[spark] class Master(
recoveryCompletionTask.cancel()
}
webUi.stop()
+ restServer.foreach(_.stop())
masterMetricsSystem.stop()
applicationMetricsSystem.stop()
persistenceEngine.close()
@@ -413,7 +434,9 @@ private[spark] class Master(
}
case RequestMasterState => {
- sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray,
+ sender ! MasterStateResponse(
+ host, port, restServerBoundPort,
+ workers.toArray, apps.toArray, completedApps.toArray,
drivers.toArray, completedDrivers.toArray, state)
}
@@ -421,8 +444,8 @@ private[spark] class Master(
timeOutDeadWorkers()
}
- case RequestWebUIPort => {
- sender ! WebUIPortResponse(webUi.boundPort)
+ case BoundPortsRequest => {
+ sender ! BoundPortsResponse(port, webUi.boundPort, restServerBoundPort)
}
}
@@ -510,7 +533,7 @@ private[spark] class Master(
val shuffledAliveWorkers = Random.shuffle(workers.toSeq.filter(_.state == WorkerState.ALIVE))
val numWorkersAlive = shuffledAliveWorkers.size
var curPos = 0
-
+
for (driver <- waitingDrivers.toList) { // iterate over a copy of waitingDrivers
// We assign workers to each waiting driver in a round-robin fashion. For each driver, we
// start from the last worker that was assigned a driver, and continue onwards until we have
@@ -573,7 +596,7 @@ private[spark] class Master(
}
}
- def launchExecutor(worker: WorkerInfo, exec: ExecutorInfo) {
+ def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) {
logInfo("Launching executor " + exec.fullId + " on worker " + worker.id)
worker.addExecutor(exec)
worker.actor ! LaunchExecutor(masterUrl,
@@ -697,6 +720,11 @@ private[spark] class Master(
}
persistenceEngine.removeApplication(app)
schedule()
+
+ // Tell all workers that the application has finished, so they can clean up any app state.
+ workers.foreach { w =>
+ w.actor ! ApplicationFinished(app.id)
+ }
}
}
@@ -707,41 +735,51 @@ private[spark] class Master(
def rebuildSparkUI(app: ApplicationInfo): Boolean = {
val appName = app.desc.name
val notFoundBasePath = HistoryServer.UI_PATH_PREFIX + "/not-found"
- val eventLogDir = app.desc.eventLogDir.getOrElse {
- // Event logging is not enabled for this application
- app.desc.appUiUrl = notFoundBasePath
- return false
- }
-
- val appEventLogDir = EventLoggingListener.getLogDirPath(eventLogDir, app.id)
- val fileSystem = Utils.getHadoopFileSystem(appEventLogDir,
- SparkHadoopUtil.get.newConfiguration(conf))
- val eventLogInfo = EventLoggingListener.parseLoggingInfo(appEventLogDir, fileSystem)
- val eventLogPaths = eventLogInfo.logPaths
- val compressionCodec = eventLogInfo.compressionCodec
-
- if (eventLogPaths.isEmpty) {
- // Event logging is enabled for this application, but no event logs are found
- val title = s"Application history not found (${app.id})"
- var msg = s"No event logs found for application $appName in $appEventLogDir."
- logWarning(msg)
- msg += " Did you specify the correct logging directory?"
- msg = URLEncoder.encode(msg, "UTF-8")
- app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
- return false
- }
-
try {
- val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
+ val eventLogFile = app.desc.eventLogDir
+ .map { dir => EventLoggingListener.getLogPath(dir, app.id) }
+ .getOrElse {
+ // Event logging is not enabled for this application
+ app.desc.appUiUrl = notFoundBasePath
+ return false
+ }
+
+ val fs = Utils.getHadoopFileSystem(eventLogFile, hadoopConf)
+
+ if (fs.exists(new Path(eventLogFile + EventLoggingListener.IN_PROGRESS))) {
+ // Event logging is enabled for this application, but the application is still in progress
+ val title = s"Application history not found (${app.id})"
+ var msg = s"Application $appName is still in progress."
+ logWarning(msg)
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
+ return false
+ }
+
+ val (logInput, sparkVersion) = EventLoggingListener.openEventLog(new Path(eventLogFile), fs)
+ val replayBus = new ReplayListenerBus()
val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
- replayBus.replay()
+ try {
+ replayBus.replay(logInput, sparkVersion)
+ } finally {
+ logInput.close()
+ }
appIdToUI(app.id) = ui
webUi.attachSparkUI(ui)
// Application UI is successfully rebuilt, so link the Master UI to it
- app.desc.appUiUrl = ui.getBasePath
+ app.desc.appUiUrl = ui.basePath
true
} catch {
+ case fnf: FileNotFoundException =>
+ // Event logging is enabled for this application, but no event logs are found
+ val title = s"Application history not found (${app.id})"
+ var msg = s"No event logs found for application $appName in ${app.desc.eventLogDir}."
+ logWarning(msg)
+ msg += " Did you specify the correct logging directory?"
+ msg = URLEncoder.encode(msg, "UTF-8")
+ app.desc.appUiUrl = notFoundBasePath + s"?msg=$msg&title=$title"
+ false
case e: Exception =>
// Relay exception message to application UI page
val title = s"Application history load error (${app.id})"
@@ -823,39 +861,55 @@ private[spark] class Master(
private[spark] object Master extends Logging {
val systemName = "sparkMaster"
private val actorName = "Master"
- val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
SignalLogger.register(log)
val conf = new SparkConf
val args = new MasterArguments(argStrings, conf)
- val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
+ val (actorSystem, _, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf)
actorSystem.awaitTermination()
}
- /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
- def toAkkaUrl(sparkUrl: String): String = {
- sparkUrl match {
- case sparkUrlRegex(host, port) =>
- "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
- case _ =>
- throw new SparkException("Invalid master URL: " + sparkUrl)
- }
+ /**
+ * Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:port`.
+ *
+ * @throws SparkException if the url is invalid
+ */
+ def toAkkaUrl(sparkUrl: String, protocol: String): String = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ AkkaUtils.address(protocol, systemName, host, port, actorName)
}
+ /**
+ * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`.
+ *
+ * @throws SparkException if the url is invalid
+ */
+ def toAkkaAddress(sparkUrl: String, protocol: String): Address = {
+ val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl)
+ Address(protocol, systemName, host, port)
+ }
+
+ /**
+ * Start the Master and return a four tuple of:
+ * (1) The Master actor system
+ * (2) The bound port
+ * (3) The web UI bound port
+ * (4) The REST server bound port, if any
+ */
def startSystemAndActor(
host: String,
port: Int,
webUiPort: Int,
- conf: SparkConf): (ActorSystem, Int, Int) = {
+ conf: SparkConf): (ActorSystem, Int, Int, Option[Int]) = {
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf,
securityManager = securityMgr)
- val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort,
- securityMgr), actorName)
+ val actor = actorSystem.actorOf(
+ Props(classOf[Master], host, boundPort, webUiPort, securityMgr, conf), actorName)
val timeout = AkkaUtils.askTimeout(conf)
- val respFuture = actor.ask(RequestWebUIPort)(timeout)
- val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse]
- (actorSystem, boundPort, resp.webUIBoundPort)
+ val portsRequest = actor.ask(BoundPortsRequest)(timeout)
+ val portsResponse = Await.result(portsRequest, timeout).asInstanceOf[BoundPortsResponse]
+ (actorSystem, boundPort, portsResponse.webUIPort, portsResponse.restPort)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
index db72d8ae9bdaf..15c6296888f70 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterMessages.scala
@@ -36,7 +36,7 @@ private[master] object MasterMessages {
case object CompleteRecovery
- case object RequestWebUIPort
+ case object BoundPortsRequest
- case class WebUIPortResponse(webUIBoundPort: Int)
+ case class BoundPortsResponse(actorPort: Int, webUIPort: Int, restPort: Option[Int])
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index 473ddc23ff0f3..e94aae93e4495 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -38,7 +38,7 @@ private[spark] class WorkerInfo(
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
- @transient var executors: mutable.HashMap[String, ExecutorInfo] = _ // executorId => info
+ @transient var executors: mutable.HashMap[String, ExecutorDesc] = _ // executorId => info
@transient var drivers: mutable.HashMap[String, DriverInfo] = _ // driverId => info
@transient var state: WorkerState.Value = _
@transient var coresUsed: Int = _
@@ -70,13 +70,13 @@ private[spark] class WorkerInfo(
host + ":" + port
}
- def addExecutor(exec: ExecutorInfo) {
+ def addExecutor(exec: ExecutorDesc) {
executors(exec.fullId) = exec
coresUsed += exec.cores
memoryUsed += exec.memory
}
- def removeExecutor(exec: ExecutorInfo) {
+ def removeExecutor(exec: ExecutorDesc) {
if (executors.contains(exec.fullId)) {
executors -= exec.fullId
coresUsed -= exec.cores
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 4588c130ef439..3aae2b95d7396 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -27,7 +27,7 @@ import org.json4s.JValue
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.ExecutorInfo
+import org.apache.spark.deploy.master.ExecutorDesc
import org.apache.spark.ui.{UIUtils, WebUIPage}
import org.apache.spark.util.Utils
@@ -109,7 +109,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
UIUtils.basicSparkPage(content, "Application: " + app.desc.name)
}
- private def executorRow(executor: ExecutorInfo): Seq[Node] = {
+ private def executorRow(executor: ExecutorDesc): Seq[Node] = {
{executor.id} |
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 7ca3b08a28728..b47a081053e77 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -46,19 +46,19 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
val state = Await.result(stateFuture, timeout)
- val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory")
+ val workerHeaders = Seq("Worker Id", "Address", "State", "Cores", "Memory")
val workers = state.workers.sortBy(_.id)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val appHeaders = Seq("ID", "Name", "Cores", "Memory per Node", "Submitted Time", "User",
- "State", "Duration")
+ val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
+ "User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
val completedApps = state.completedApps.sortBy(_.endTime).reverse
val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
- val driverHeaders = Seq("ID", "Submitted Time", "Worker", "State", "Cores", "Memory",
- "Main Class")
+ val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores",
+ "Memory", "Main Class")
val activeDrivers = state.activeDrivers.sortBy(_.startTime).reverse
val activeDriversTable = UIUtils.listingTable(driverHeaders, driverRow, activeDrivers)
val completedDrivers = state.completedDrivers.sortBy(_.startTime).reverse
@@ -73,6 +73,14 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
- URL: {state.uri}
+ {
+ state.restUri.map { uri =>
+ -
+ REST URL: {uri}
+ (cluster mode)
+
+ }.getOrElse { Seq.empty }
+ }
- Workers: {state.workers.size}
- Cores: {state.workers.map(_.cores).sum} Total,
{state.workers.map(_.coresUsed).sum} Used
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index d86ec1e03e45c..73400c5affb5d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -41,8 +41,6 @@ class MasterWebUI(val master: Master, requestedPort: Int)
attachPage(new HistoryNotFoundPage(this))
attachPage(new MasterPage(this))
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
- master.masterMetricsSystem.getServletHandlers.foreach(attachHandler)
- master.applicationMetricsSystem.getServletHandlers.foreach(attachHandler)
}
/** Attach a reconstructed UI to this Master UI. Only valid after bind(). */
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
new file mode 100644
index 0000000000000..115aa5278bb62
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestClient.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{DataOutputStream, FileNotFoundException}
+import java.net.{HttpURLConnection, SocketException, URL}
+
+import scala.io.Source
+
+import com.fasterxml.jackson.databind.JsonMappingException
+import com.google.common.base.Charsets
+
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+
+/**
+ * A client that submits applications to the standalone Master using a REST protocol.
+ * This client is intended to communicate with the [[StandaloneRestServer]] and is
+ * currently used for cluster mode only.
+ *
+ * In protocol version v1, the REST URL takes the form http://[host:port]/v1/submissions/[action],
+ * where [action] can be one of create, kill, or status. Each type of request is represented in
+ * an HTTP message sent to the following prefixes:
+ * (1) submit - POST to /submissions/create
+ * (2) kill - POST /submissions/kill/[submissionId]
+ * (3) status - GET /submissions/status/[submissionId]
+ *
+ * In the case of (1), parameters are posted in the HTTP body in the form of JSON fields.
+ * Otherwise, the URL fully specifies the intended action of the client.
+ *
+ * Since the protocol is expected to be stable across Spark versions, existing fields cannot be
+ * added or removed, though new optional fields can be added. In the rare event that forward or
+ * backward compatibility is broken, Spark must introduce a new protocol version (e.g. v2).
+ *
+ * The client and the server must communicate using the same version of the protocol. If there
+ * is a mismatch, the server will respond with the highest protocol version it supports. A future
+ * implementation of this client can use that information to retry using the version specified
+ * by the server.
+ */
+private[spark] class StandaloneRestClient extends Logging {
+ import StandaloneRestClient._
+
+ /**
+ * Submit an application specified by the parameters in the provided request.
+ *
+ * If the submission was successful, poll the status of the submission and report
+ * it to the user. Otherwise, report the error message provided by the server.
+ */
+ def createSubmission(
+ master: String,
+ request: CreateSubmissionRequest): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request to launch an application in $master.")
+ validateMaster(master)
+ val url = getSubmitUrl(master)
+ val response = postJson(url, request.toJson)
+ response match {
+ case s: CreateSubmissionResponse =>
+ reportSubmissionStatus(master, s)
+ handleRestResponse(s)
+ case unexpected =>
+ handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Request that the server kill the specified submission. */
+ def killSubmission(master: String, submissionId: String): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request to kill submission $submissionId in $master.")
+ validateMaster(master)
+ val response = post(getKillUrl(master, submissionId))
+ response match {
+ case k: KillSubmissionResponse => handleRestResponse(k)
+ case unexpected => handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Request the status of a submission from the server. */
+ def requestSubmissionStatus(
+ master: String,
+ submissionId: String,
+ quiet: Boolean = false): SubmitRestProtocolResponse = {
+ logInfo(s"Submitting a request for the status of submission $submissionId in $master.")
+ validateMaster(master)
+ val response = get(getStatusUrl(master, submissionId))
+ response match {
+ case s: SubmissionStatusResponse => if (!quiet) { handleRestResponse(s) }
+ case unexpected => handleUnexpectedRestResponse(unexpected)
+ }
+ response
+ }
+
+ /** Construct a message that captures the specified parameters for submitting an application. */
+ def constructSubmitRequest(
+ appResource: String,
+ mainClass: String,
+ appArgs: Array[String],
+ sparkProperties: Map[String, String],
+ environmentVariables: Map[String, String]): CreateSubmissionRequest = {
+ val message = new CreateSubmissionRequest
+ message.clientSparkVersion = sparkVersion
+ message.appResource = appResource
+ message.mainClass = mainClass
+ message.appArgs = appArgs
+ message.sparkProperties = sparkProperties
+ message.environmentVariables = environmentVariables
+ message.validate()
+ message
+ }
+
+ /** Send a GET request to the specified URL. */
+ private def get(url: URL): SubmitRestProtocolResponse = {
+ logDebug(s"Sending GET request to server at $url.")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("GET")
+ readResponse(conn)
+ }
+
+ /** Send a POST request to the specified URL. */
+ private def post(url: URL): SubmitRestProtocolResponse = {
+ logDebug(s"Sending POST request to server at $url.")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("POST")
+ readResponse(conn)
+ }
+
+ /** Send a POST request with the given JSON as the body to the specified URL. */
+ private def postJson(url: URL, json: String): SubmitRestProtocolResponse = {
+ logDebug(s"Sending POST request to server at $url:\n$json")
+ val conn = url.openConnection().asInstanceOf[HttpURLConnection]
+ conn.setRequestMethod("POST")
+ conn.setRequestProperty("Content-Type", "application/json")
+ conn.setRequestProperty("charset", "utf-8")
+ conn.setDoOutput(true)
+ val out = new DataOutputStream(conn.getOutputStream)
+ out.write(json.getBytes(Charsets.UTF_8))
+ out.close()
+ readResponse(conn)
+ }
+
+ /**
+ * Read the response from the server and return it as a validated [[SubmitRestProtocolResponse]].
+ * If the response represents an error, report the embedded message to the user.
+ */
+ private def readResponse(connection: HttpURLConnection): SubmitRestProtocolResponse = {
+ try {
+ val responseJson = Source.fromInputStream(connection.getInputStream).mkString
+ logDebug(s"Response from the server:\n$responseJson")
+ val response = SubmitRestProtocolMessage.fromJson(responseJson)
+ response.validate()
+ response match {
+ // If the response is an error, log the message
+ case error: ErrorResponse =>
+ logError(s"Server responded with error:\n${error.message}")
+ error
+ // Otherwise, simply return the response
+ case response: SubmitRestProtocolResponse => response
+ case unexpected =>
+ throw new SubmitRestProtocolException(
+ s"Message received from server was not a response:\n${unexpected.toJson}")
+ }
+ } catch {
+ case unreachable @ (_: FileNotFoundException | _: SocketException) =>
+ throw new SubmitRestConnectionException(
+ s"Unable to connect to server ${connection.getURL}", unreachable)
+ case malformed @ (_: SubmitRestProtocolException | _: JsonMappingException) =>
+ throw new SubmitRestProtocolException(
+ "Malformed response received from server", malformed)
+ }
+ }
+
+ /** Return the REST URL for creating a new submission. */
+ private def getSubmitUrl(master: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/create")
+ }
+
+ /** Return the REST URL for killing an existing submission. */
+ private def getKillUrl(master: String, submissionId: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/kill/$submissionId")
+ }
+
+ /** Return the REST URL for requesting the status of an existing submission. */
+ private def getStatusUrl(master: String, submissionId: String): URL = {
+ val baseUrl = getBaseUrl(master)
+ new URL(s"$baseUrl/status/$submissionId")
+ }
+
+ /** Return the base URL for communicating with the server, including the protocol version. */
+ private def getBaseUrl(master: String): String = {
+ val masterUrl = master.stripPrefix("spark://").stripSuffix("/")
+ s"http://$masterUrl/$PROTOCOL_VERSION/submissions"
+ }
+
+ /** Throw an exception if this is not standalone mode. */
+ private def validateMaster(master: String): Unit = {
+ if (!master.startsWith("spark://")) {
+ throw new IllegalArgumentException("This REST client is only supported in standalone mode.")
+ }
+ }
+
+ /** Report the status of a newly created submission. */
+ private def reportSubmissionStatus(
+ master: String,
+ submitResponse: CreateSubmissionResponse): Unit = {
+ if (submitResponse.success) {
+ val submissionId = submitResponse.submissionId
+ if (submissionId != null) {
+ logInfo(s"Submission successfully created as $submissionId. Polling submission state...")
+ pollSubmissionStatus(master, submissionId)
+ } else {
+ // should never happen
+ logError("Application successfully submitted, but submission ID was not provided!")
+ }
+ } else {
+ val failMessage = Option(submitResponse.message).map { ": " + _ }.getOrElse("")
+ logError("Application submission failed" + failMessage)
+ }
+ }
+
+ /**
+ * Poll the status of the specified submission and log it.
+ * This retries up to a fixed number of times before giving up.
+ */
+ private def pollSubmissionStatus(master: String, submissionId: String): Unit = {
+ (1 to REPORT_DRIVER_STATUS_MAX_TRIES).foreach { _ =>
+ val response = requestSubmissionStatus(master, submissionId, quiet = true)
+ val statusResponse = response match {
+ case s: SubmissionStatusResponse => s
+ case _ => return // unexpected type, let upstream caller handle it
+ }
+ if (statusResponse.success) {
+ val driverState = Option(statusResponse.driverState)
+ val workerId = Option(statusResponse.workerId)
+ val workerHostPort = Option(statusResponse.workerHostPort)
+ val exception = Option(statusResponse.message)
+ // Log driver state, if present
+ driverState match {
+ case Some(state) => logInfo(s"State of driver $submissionId is now $state.")
+ case _ => logError(s"State of driver $submissionId was not found!")
+ }
+ // Log worker node, if present
+ (workerId, workerHostPort) match {
+ case (Some(id), Some(hp)) => logInfo(s"Driver is running on worker $id at $hp.")
+ case _ =>
+ }
+ // Log exception stack trace, if present
+ exception.foreach { e => logError(e) }
+ return
+ }
+ Thread.sleep(REPORT_DRIVER_STATUS_INTERVAL)
+ }
+ logError(s"Error: Master did not recognize driver $submissionId.")
+ }
+
+ /** Log the response sent by the server in the REST application submission protocol. */
+ private def handleRestResponse(response: SubmitRestProtocolResponse): Unit = {
+ logInfo(s"Server responded with ${response.messageType}:\n${response.toJson}")
+ }
+
+ /** Log an appropriate error if the response sent by the server is not of the expected type. */
+ private def handleUnexpectedRestResponse(unexpected: SubmitRestProtocolResponse): Unit = {
+ logError(s"Error: Server responded with message of unexpected type ${unexpected.messageType}.")
+ }
+}
+
+private[spark] object StandaloneRestClient {
+ val REPORT_DRIVER_STATUS_INTERVAL = 1000
+ val REPORT_DRIVER_STATUS_MAX_TRIES = 10
+ val PROTOCOL_VERSION = "v1"
+
+ /** Submit an application, assuming Spark parameters are specified through system properties. */
+ def main(args: Array[String]): Unit = {
+ if (args.size < 2) {
+ sys.error("Usage: StandaloneRestClient [app resource] [main class] [app args*]")
+ sys.exit(1)
+ }
+ val appResource = args(0)
+ val mainClass = args(1)
+ val appArgs = args.slice(2, args.size)
+ val conf = new SparkConf
+ val master = conf.getOption("spark.master").getOrElse {
+ throw new IllegalArgumentException("'spark.master' must be set.")
+ }
+ val sparkProperties = conf.getAll.toMap
+ val environmentVariables = sys.env.filter { case (k, _) => k.startsWith("SPARK_") }
+ val client = new StandaloneRestClient
+ val submitRequest = client.constructSubmitRequest(
+ appResource, mainClass, appArgs, sparkProperties, environmentVariables)
+ client.createSubmission(master, submitRequest)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
new file mode 100644
index 0000000000000..2033d67e1f394
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.io.{DataOutputStream, File}
+import java.net.InetSocketAddress
+import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse}
+
+import scala.io.Source
+
+import akka.actor.ActorRef
+import com.fasterxml.jackson.databind.JsonMappingException
+import com.google.common.base.Charsets
+import org.eclipse.jetty.server.Server
+import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler}
+import org.eclipse.jetty.util.thread.QueuedThreadPool
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion}
+import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription}
+import org.apache.spark.deploy.ClientArguments._
+
+/**
+ * A server that responds to requests submitted by the [[StandaloneRestClient]].
+ * This is intended to be embedded in the standalone Master and used in cluster mode only.
+ *
+ * This server responds with different HTTP codes depending on the situation:
+ * 200 OK - Request was processed successfully
+ * 400 BAD REQUEST - Request was malformed, not successfully validated, or of unexpected type
+ * 468 UNKNOWN PROTOCOL VERSION - Request specified a protocol this server does not understand
+ * 500 INTERNAL SERVER ERROR - Server throws an exception internally while processing the request
+ *
+ * The server always includes a JSON representation of the relevant [[SubmitRestProtocolResponse]]
+ * in the HTTP body. If an error occurs, however, the server will include an [[ErrorResponse]]
+ * instead of the one expected by the client. If the construction of this error response itself
+ * fails, the response will consist of an empty body with a response code that indicates internal
+ * server error.
+ *
+ * @param host the address this server should bind to
+ * @param requestedPort the port this server will attempt to bind to
+ * @param masterActor reference to the Master actor to which requests can be sent
+ * @param masterUrl the URL of the Master new drivers will attempt to connect to
+ * @param masterConf the conf used by the Master
+ */
+private[spark] class StandaloneRestServer(
+ host: String,
+ requestedPort: Int,
+ masterActor: ActorRef,
+ masterUrl: String,
+ masterConf: SparkConf)
+ extends Logging {
+
+ import StandaloneRestServer._
+
+ private var _server: Option[Server] = None
+ private val baseContext = s"/$PROTOCOL_VERSION/submissions"
+
+ // A mapping from servlets to the URL prefixes they are responsible for
+ private val servletToContext = Map[StandaloneRestServlet, String](
+ new SubmitRequestServlet(masterActor, masterUrl, masterConf) -> s"$baseContext/create/*",
+ new KillRequestServlet(masterActor, masterConf) -> s"$baseContext/kill/*",
+ new StatusRequestServlet(masterActor, masterConf) -> s"$baseContext/status/*",
+ new ErrorServlet -> "/" // default handler
+ )
+
+ /** Start the server and return the bound port. */
+ def start(): Int = {
+ val (server, boundPort) = Utils.startServiceOnPort[Server](requestedPort, doStart, masterConf)
+ _server = Some(server)
+ logInfo(s"Started REST server for submitting applications on port $boundPort")
+ boundPort
+ }
+
+ /**
+ * Map the servlets to their corresponding contexts and attach them to a server.
+ * Return a 2-tuple of the started server and the bound port.
+ */
+ private def doStart(startPort: Int): (Server, Int) = {
+ val server = new Server(new InetSocketAddress(host, startPort))
+ val threadPool = new QueuedThreadPool
+ threadPool.setDaemon(true)
+ server.setThreadPool(threadPool)
+ val mainHandler = new ServletContextHandler
+ mainHandler.setContextPath("/")
+ servletToContext.foreach { case (servlet, prefix) =>
+ mainHandler.addServlet(new ServletHolder(servlet), prefix)
+ }
+ server.setHandler(mainHandler)
+ server.start()
+ val boundPort = server.getConnectors()(0).getLocalPort
+ (server, boundPort)
+ }
+
+ def stop(): Unit = {
+ _server.foreach(_.stop())
+ }
+}
+
+private object StandaloneRestServer {
+ val PROTOCOL_VERSION = StandaloneRestClient.PROTOCOL_VERSION
+ val SC_UNKNOWN_PROTOCOL_VERSION = 468
+}
+
+/**
+ * An abstract servlet for handling requests passed to the [[StandaloneRestServer]].
+ */
+private abstract class StandaloneRestServlet extends HttpServlet with Logging {
+
+ /** Service a request. If an exception is thrown in the process, indicate server error. */
+ protected override def service(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ try {
+ super.service(request, response)
+ } catch {
+ case e: Exception =>
+ logError("Exception while handling request", e)
+ response.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ }
+ }
+
+ /**
+ * Serialize the given response message to JSON and send it through the response servlet.
+ * This validates the response before sending it to ensure it is properly constructed.
+ */
+ protected def sendResponse(
+ responseMessage: SubmitRestProtocolResponse,
+ responseServlet: HttpServletResponse): Unit = {
+ val message = validateResponse(responseMessage, responseServlet)
+ responseServlet.setContentType("application/json")
+ responseServlet.setCharacterEncoding("utf-8")
+ responseServlet.setStatus(HttpServletResponse.SC_OK)
+ val content = message.toJson.getBytes(Charsets.UTF_8)
+ val out = new DataOutputStream(responseServlet.getOutputStream)
+ out.write(content)
+ out.close()
+ }
+
+ /**
+ * Return any fields in the client request message that the server does not know about.
+ *
+ * The mechanism for this is to reconstruct the JSON on the server side and compare the
+ * diff between this JSON and the one generated on the client side. Any fields that are
+ * only in the client JSON are treated as unexpected.
+ */
+ protected def findUnknownFields(
+ requestJson: String,
+ requestMessage: SubmitRestProtocolMessage): Array[String] = {
+ val clientSideJson = parse(requestJson)
+ val serverSideJson = parse(requestMessage.toJson)
+ val Diff(_, _, unknown) = clientSideJson.diff(serverSideJson)
+ unknown match {
+ case j: JObject => j.obj.map { case (k, _) => k }.toArray
+ case _ => Array.empty[String] // No difference
+ }
+ }
+
+ /** Return a human readable String representation of the exception. */
+ protected def formatException(e: Throwable): String = {
+ val stackTraceString = e.getStackTrace.map { "\t" + _ }.mkString("\n")
+ s"$e\n$stackTraceString"
+ }
+
+ /** Construct an error message to signal the fact that an exception has been thrown. */
+ protected def handleError(message: String): ErrorResponse = {
+ val e = new ErrorResponse
+ e.serverSparkVersion = sparkVersion
+ e.message = message
+ e
+ }
+
+ /**
+ * Validate the response to ensure that it is correctly constructed.
+ *
+ * If it is, simply return the message as is. Otherwise, return an error response instead
+ * to propagate the exception back to the client and set the appropriate error code.
+ */
+ private def validateResponse(
+ responseMessage: SubmitRestProtocolResponse,
+ responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
+ try {
+ responseMessage.validate()
+ responseMessage
+ } catch {
+ case e: Exception =>
+ responseServlet.setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR)
+ handleError("Internal server error: " + formatException(e))
+ }
+ }
+}
+
+/**
+ * A servlet for handling kill requests passed to the [[StandaloneRestServer]].
+ */
+private class KillRequestServlet(masterActor: ActorRef, conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * If a submission ID is specified in the URL, have the Master kill the corresponding
+ * driver and return an appropriate response to the client. Otherwise, return error.
+ */
+ protected override def doPost(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val submissionId = request.getPathInfo.stripPrefix("/")
+ val responseMessage =
+ if (submissionId.nonEmpty) {
+ handleKill(submissionId)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in kill request.")
+ }
+ sendResponse(responseMessage, response)
+ }
+
+ private def handleKill(submissionId: String): KillSubmissionResponse = {
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val response = AkkaUtils.askWithReply[DeployMessages.KillDriverResponse](
+ DeployMessages.RequestKillDriver(submissionId), masterActor, askTimeout)
+ val k = new KillSubmissionResponse
+ k.serverSparkVersion = sparkVersion
+ k.message = response.message
+ k.submissionId = submissionId
+ k.success = response.success
+ k
+ }
+}
+
+/**
+ * A servlet for handling status requests passed to the [[StandaloneRestServer]].
+ */
+private class StatusRequestServlet(masterActor: ActorRef, conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * If a submission ID is specified in the URL, request the status of the corresponding
+ * driver from the Master and include it in the response. Otherwise, return error.
+ */
+ protected override def doGet(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val submissionId = request.getPathInfo.stripPrefix("/")
+ val responseMessage =
+ if (submissionId.nonEmpty) {
+ handleStatus(submissionId)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Submission ID is missing in status request.")
+ }
+ sendResponse(responseMessage, response)
+ }
+
+ private def handleStatus(submissionId: String): SubmissionStatusResponse = {
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val response = AkkaUtils.askWithReply[DeployMessages.DriverStatusResponse](
+ DeployMessages.RequestDriverStatus(submissionId), masterActor, askTimeout)
+ val message = response.exception.map { s"Exception from the cluster:\n" + formatException(_) }
+ val d = new SubmissionStatusResponse
+ d.serverSparkVersion = sparkVersion
+ d.submissionId = submissionId
+ d.success = response.found
+ d.driverState = response.state.map(_.toString).orNull
+ d.workerId = response.workerId.orNull
+ d.workerHostPort = response.workerHostPort.orNull
+ d.message = message.orNull
+ d
+ }
+}
+
+/**
+ * A servlet for handling submit requests passed to the [[StandaloneRestServer]].
+ */
+private class SubmitRequestServlet(
+ masterActor: ActorRef,
+ masterUrl: String,
+ conf: SparkConf)
+ extends StandaloneRestServlet {
+
+ /**
+ * Submit an application to the Master with parameters specified in the request.
+ *
+ * The request is assumed to be a [[SubmitRestProtocolRequest]] in the form of JSON.
+ * If the request is successfully processed, return an appropriate response to the
+ * client indicating so. Otherwise, return error instead.
+ */
+ protected override def doPost(
+ requestServlet: HttpServletRequest,
+ responseServlet: HttpServletResponse): Unit = {
+ val responseMessage =
+ try {
+ val requestMessageJson = Source.fromInputStream(requestServlet.getInputStream).mkString
+ val requestMessage = SubmitRestProtocolMessage.fromJson(requestMessageJson)
+ // The response should have already been validated on the client.
+ // In case this is not true, validate it ourselves to avoid potential NPEs.
+ requestMessage.validate()
+ handleSubmit(requestMessageJson, requestMessage, responseServlet)
+ } catch {
+ // The client failed to provide a valid JSON, so this is not our fault
+ case e @ (_: JsonMappingException | _: SubmitRestProtocolException) =>
+ responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError("Malformed request: " + formatException(e))
+ }
+ sendResponse(responseMessage, responseServlet)
+ }
+
+ /**
+ * Handle the submit request and construct an appropriate response to return to the client.
+ *
+ * This assumes that the request message is already successfully validated.
+ * If the request message is not of the expected type, return error to the client.
+ */
+ private def handleSubmit(
+ requestMessageJson: String,
+ requestMessage: SubmitRestProtocolMessage,
+ responseServlet: HttpServletResponse): SubmitRestProtocolResponse = {
+ requestMessage match {
+ case submitRequest: CreateSubmissionRequest =>
+ val askTimeout = AkkaUtils.askTimeout(conf)
+ val driverDescription = buildDriverDescription(submitRequest)
+ val response = AkkaUtils.askWithReply[DeployMessages.SubmitDriverResponse](
+ DeployMessages.RequestSubmitDriver(driverDescription), masterActor, askTimeout)
+ val submitResponse = new CreateSubmissionResponse
+ submitResponse.serverSparkVersion = sparkVersion
+ submitResponse.message = response.message
+ submitResponse.success = response.success
+ submitResponse.submissionId = response.driverId.orNull
+ val unknownFields = findUnknownFields(requestMessageJson, requestMessage)
+ if (unknownFields.nonEmpty) {
+ // If there are fields that the server does not know about, warn the client
+ submitResponse.unknownFields = unknownFields
+ }
+ submitResponse
+ case unexpected =>
+ responseServlet.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ handleError(s"Received message of unexpected type ${unexpected.messageType}.")
+ }
+ }
+
+ /**
+ * Build a driver description from the fields specified in the submit request.
+ *
+ * This involves constructing a command that takes into account memory, java options,
+ * classpath and other settings to launch the driver. This does not currently consider
+ * fields used by python applications since python is not supported in standalone
+ * cluster mode yet.
+ */
+ private def buildDriverDescription(request: CreateSubmissionRequest): DriverDescription = {
+ // Required fields, including the main class because python is not yet supported
+ val appResource = Option(request.appResource).getOrElse {
+ throw new SubmitRestMissingFieldException("Application jar is missing.")
+ }
+ val mainClass = Option(request.mainClass).getOrElse {
+ throw new SubmitRestMissingFieldException("Main class is missing.")
+ }
+
+ // Optional fields
+ val sparkProperties = request.sparkProperties
+ val driverMemory = sparkProperties.get("spark.driver.memory")
+ val driverCores = sparkProperties.get("spark.driver.cores")
+ val driverExtraJavaOptions = sparkProperties.get("spark.driver.extraJavaOptions")
+ val driverExtraClassPath = sparkProperties.get("spark.driver.extraClassPath")
+ val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath")
+ val superviseDriver = sparkProperties.get("spark.driver.supervise")
+ val appArgs = request.appArgs
+ val environmentVariables = request.environmentVariables
+
+ // Construct driver description
+ val conf = new SparkConf(false)
+ .setAll(sparkProperties)
+ .set("spark.master", masterUrl)
+ val extraClassPath = driverExtraClassPath.toSeq.flatMap(_.split(File.pathSeparator))
+ val extraLibraryPath = driverExtraLibraryPath.toSeq.flatMap(_.split(File.pathSeparator))
+ val extraJavaOpts = driverExtraJavaOptions.map(Utils.splitCommandString).getOrElse(Seq.empty)
+ val sparkJavaOpts = Utils.sparkJavaOpts(conf)
+ val javaOpts = sparkJavaOpts ++ extraJavaOpts
+ val command = new Command(
+ "org.apache.spark.deploy.worker.DriverWrapper",
+ Seq("{{WORKER_URL}}", mainClass) ++ appArgs, // args to the DriverWrapper
+ environmentVariables, extraClassPath, extraLibraryPath, javaOpts)
+ val actualDriverMemory = driverMemory.map(Utils.memoryStringToMb).getOrElse(DEFAULT_MEMORY)
+ val actualDriverCores = driverCores.map(_.toInt).getOrElse(DEFAULT_CORES)
+ val actualSuperviseDriver = superviseDriver.map(_.toBoolean).getOrElse(DEFAULT_SUPERVISE)
+ new DriverDescription(
+ appResource, actualDriverMemory, actualDriverCores, actualSuperviseDriver, command)
+ }
+}
+
+/**
+ * A default servlet that handles error cases that are not captured by other servlets.
+ */
+private class ErrorServlet extends StandaloneRestServlet {
+ private val serverVersion = StandaloneRestServer.PROTOCOL_VERSION
+
+ /** Service a faulty request by returning an appropriate error message to the client. */
+ protected override def service(
+ request: HttpServletRequest,
+ response: HttpServletResponse): Unit = {
+ val path = request.getPathInfo
+ val parts = path.stripPrefix("/").split("/").toSeq
+ var versionMismatch = false
+ var msg =
+ parts match {
+ case Nil =>
+ // http://host:port/
+ "Missing protocol version."
+ case `serverVersion` :: Nil =>
+ // http://host:port/correct-version
+ "Missing the /submissions prefix."
+ case `serverVersion` :: "submissions" :: Nil =>
+ // http://host:port/correct-version/submissions
+ "Missing an action: please specify one of /create, /kill, or /status."
+ case unknownVersion :: _ =>
+ // http://host:port/unknown-version/*
+ versionMismatch = true
+ s"Unknown protocol version '$unknownVersion'."
+ case _ =>
+ // never reached
+ s"Malformed path $path."
+ }
+ msg += s" Please submit requests through http://[host]:[port]/$serverVersion/submissions/..."
+ val error = handleError(msg)
+ // If there is a version mismatch, include the highest protocol version that
+ // this server supports in case the client wants to retry with our version
+ if (versionMismatch) {
+ error.highestProtocolVersion = serverVersion
+ response.setStatus(StandaloneRestServer.SC_UNKNOWN_PROTOCOL_VERSION)
+ } else {
+ response.setStatus(HttpServletResponse.SC_BAD_REQUEST)
+ }
+ sendResponse(error, response)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala
new file mode 100644
index 0000000000000..d7a0bdbe10778
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolException.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+/**
+ * An exception thrown in the REST application submission protocol.
+ */
+private[spark] class SubmitRestProtocolException(message: String, cause: Throwable = null)
+ extends Exception(message, cause)
+
+/**
+ * An exception thrown if a field is missing from a [[SubmitRestProtocolMessage]].
+ */
+private[spark] class SubmitRestMissingFieldException(message: String)
+ extends SubmitRestProtocolException(message)
+
+/**
+ * An exception thrown if the REST client cannot reach the REST server.
+ */
+private[spark] class SubmitRestConnectionException(message: String, cause: Throwable)
+ extends SubmitRestProtocolException(message, cause)
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
new file mode 100644
index 0000000000000..b877898231e3e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolMessage.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import scala.util.Try
+
+import com.fasterxml.jackson.annotation._
+import com.fasterxml.jackson.annotation.JsonAutoDetect.Visibility
+import com.fasterxml.jackson.annotation.JsonInclude.Include
+import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper, SerializationFeature}
+import com.fasterxml.jackson.module.scala.DefaultScalaModule
+import org.json4s.JsonAST._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.util.Utils
+
+/**
+ * An abstract message exchanged in the REST application submission protocol.
+ *
+ * This message is intended to be serialized to and deserialized from JSON in the exchange.
+ * Each message can either be a request or a response and consists of three common fields:
+ * (1) the action, which fully specifies the type of the message
+ * (2) the Spark version of the client / server
+ * (3) an optional message
+ */
+@JsonInclude(Include.NON_NULL)
+@JsonAutoDetect(getterVisibility = Visibility.ANY, setterVisibility = Visibility.ANY)
+@JsonPropertyOrder(alphabetic = true)
+private[spark] abstract class SubmitRestProtocolMessage {
+ @JsonIgnore
+ val messageType = Utils.getFormattedClassName(this)
+
+ val action: String = messageType
+ var message: String = null
+
+ // For JSON deserialization
+ private def setAction(a: String): Unit = { }
+
+ /**
+ * Serialize the message to JSON.
+ * This also ensures that the message is valid and its fields are in the expected format.
+ */
+ def toJson: String = {
+ validate()
+ SubmitRestProtocolMessage.mapper.writeValueAsString(this)
+ }
+
+ /**
+ * Assert the validity of the message.
+ * If the validation fails, throw a [[SubmitRestProtocolException]].
+ */
+ final def validate(): Unit = {
+ try {
+ doValidate()
+ } catch {
+ case e: Exception =>
+ throw new SubmitRestProtocolException(s"Validation of message $messageType failed!", e)
+ }
+ }
+
+ /** Assert the validity of the message */
+ protected def doValidate(): Unit = {
+ if (action == null) {
+ throw new SubmitRestMissingFieldException(s"The action field is missing in $messageType")
+ }
+ }
+
+ /** Assert that the specified field is set in this message. */
+ protected def assertFieldIsSet[T](value: T, name: String): Unit = {
+ if (value == null) {
+ throw new SubmitRestMissingFieldException(s"'$name' is missing in message $messageType.")
+ }
+ }
+
+ /**
+ * Assert a condition when validating this message.
+ * If the assertion fails, throw a [[SubmitRestProtocolException]].
+ */
+ protected def assert(condition: Boolean, failMessage: String): Unit = {
+ if (!condition) { throw new SubmitRestProtocolException(failMessage) }
+ }
+}
+
+/**
+ * Helper methods to process serialized [[SubmitRestProtocolMessage]]s.
+ */
+private[spark] object SubmitRestProtocolMessage {
+ private val packagePrefix = this.getClass.getPackage.getName
+ private val mapper = new ObjectMapper()
+ .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
+ .enable(SerializationFeature.INDENT_OUTPUT)
+ .registerModule(DefaultScalaModule)
+
+ /**
+ * Parse the value of the action field from the given JSON.
+ * If the action field is not found, throw a [[SubmitRestMissingFieldException]].
+ */
+ def parseAction(json: String): String = {
+ parse(json).asInstanceOf[JObject].obj
+ .find { case (f, _) => f == "action" }
+ .map { case (_, v) => v.asInstanceOf[JString].s }
+ .getOrElse {
+ throw new SubmitRestMissingFieldException(s"Action field not found in JSON:\n$json")
+ }
+ }
+
+ /**
+ * Construct a [[SubmitRestProtocolMessage]] from its JSON representation.
+ *
+ * This method first parses the action from the JSON and uses it to infer the message type.
+ * Note that the action must represent one of the [[SubmitRestProtocolMessage]]s defined in
+ * this package. Otherwise, a [[ClassNotFoundException]] will be thrown.
+ */
+ def fromJson(json: String): SubmitRestProtocolMessage = {
+ val className = parseAction(json)
+ val clazz = Class.forName(packagePrefix + "." + className)
+ .asSubclass[SubmitRestProtocolMessage](classOf[SubmitRestProtocolMessage])
+ fromJson(json, clazz)
+ }
+
+ /**
+ * Construct a [[SubmitRestProtocolMessage]] from its JSON representation.
+ *
+ * This method determines the type of the message from the class provided instead of
+ * inferring it from the action field. This is useful for deserializing JSON that
+ * represents custom user-defined messages.
+ */
+ def fromJson[T <: SubmitRestProtocolMessage](json: String, clazz: Class[T]): T = {
+ mapper.readValue(json, clazz)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
new file mode 100644
index 0000000000000..9e1fd8c40cabd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolRequest.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import scala.util.Try
+
+import org.apache.spark.util.Utils
+
+/**
+ * An abstract request sent from the client in the REST application submission protocol.
+ */
+private[spark] abstract class SubmitRestProtocolRequest extends SubmitRestProtocolMessage {
+ var clientSparkVersion: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(clientSparkVersion, "clientSparkVersion")
+ }
+}
+
+/**
+ * A request to launch a new application in the REST application submission protocol.
+ */
+private[spark] class CreateSubmissionRequest extends SubmitRestProtocolRequest {
+ var appResource: String = null
+ var mainClass: String = null
+ var appArgs: Array[String] = null
+ var sparkProperties: Map[String, String] = null
+ var environmentVariables: Map[String, String] = null
+
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assert(sparkProperties != null, "No Spark properties set!")
+ assertFieldIsSet(appResource, "appResource")
+ assertPropertyIsSet("spark.app.name")
+ assertPropertyIsBoolean("spark.driver.supervise")
+ assertPropertyIsNumeric("spark.driver.cores")
+ assertPropertyIsNumeric("spark.cores.max")
+ assertPropertyIsMemory("spark.driver.memory")
+ assertPropertyIsMemory("spark.executor.memory")
+ }
+
+ private def assertPropertyIsSet(key: String): Unit =
+ assertFieldIsSet(sparkProperties.getOrElse(key, null), key)
+
+ private def assertPropertyIsBoolean(key: String): Unit =
+ assertProperty[Boolean](key, "boolean", _.toBoolean)
+
+ private def assertPropertyIsNumeric(key: String): Unit =
+ assertProperty[Int](key, "numeric", _.toInt)
+
+ private def assertPropertyIsMemory(key: String): Unit =
+ assertProperty[Int](key, "memory", Utils.memoryStringToMb)
+
+ /** Assert that a Spark property can be converted to a certain type. */
+ private def assertProperty[T](key: String, valueType: String, convert: (String => T)): Unit = {
+ sparkProperties.get(key).foreach { value =>
+ Try(convert(value)).getOrElse {
+ throw new SubmitRestProtocolException(
+ s"Property '$key' expected $valueType value: actual was '$value'.")
+ }
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala
new file mode 100644
index 0000000000000..16dfe041d4bea
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/rest/SubmitRestProtocolResponse.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.rest
+
+import java.lang.Boolean
+
+/**
+ * An abstract response sent from the server in the REST application submission protocol.
+ */
+private[spark] abstract class SubmitRestProtocolResponse extends SubmitRestProtocolMessage {
+ var serverSparkVersion: String = null
+ var success: Boolean = null
+ var unknownFields: Array[String] = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(serverSparkVersion, "serverSparkVersion")
+ }
+}
+
+/**
+ * A response to a [[CreateSubmissionRequest]] in the REST application submission protocol.
+ */
+private[spark] class CreateSubmissionResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * A response to a kill request in the REST application submission protocol.
+ */
+private[spark] class KillSubmissionResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(submissionId, "submissionId")
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * A response to a status request in the REST application submission protocol.
+ */
+private[spark] class SubmissionStatusResponse extends SubmitRestProtocolResponse {
+ var submissionId: String = null
+ var driverState: String = null
+ var workerId: String = null
+ var workerHostPort: String = null
+
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(submissionId, "submissionId")
+ assertFieldIsSet(success, "success")
+ }
+}
+
+/**
+ * An error response message used in the REST application submission protocol.
+ */
+private[spark] class ErrorResponse extends SubmitRestProtocolResponse {
+ // The highest protocol version that the server knows about
+ // This is set when the client specifies an unknown version
+ var highestProtocolVersion: String = null
+ protected override def doValidate(): Unit = {
+ super.doValidate()
+ assertFieldIsSet(message, "message")
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 28e9662db5da9..3e013c32096c5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -115,9 +115,19 @@ object CommandUtils extends Logging {
val userClassPath = command.classPathEntries ++ Seq(classPath)
val javaVersion = System.getProperty("java.version")
- val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None
+
+ val javaOpts = workerLocalOpts ++ command.javaOpts
+
+ val permGenOpt =
+ if (!javaVersion.startsWith("1.8") && !javaOpts.exists(_.startsWith("-XX:MaxPermSize="))) {
+ // do not specify -XX:MaxPermSize if it was already specified by user
+ Some("-XX:MaxPermSize=128m")
+ } else {
+ None
+ }
+
Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
- permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
+ permGenOpt ++ javaOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index f4fedc6327ab9..0add3064da452 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{SparkConf, Logging}
-import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
+import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
import org.apache.spark.util.logging.FileAppender
@@ -43,10 +43,12 @@ private[spark] class ExecutorRunner(
val worker: ActorRef,
val workerId: String,
val host: String,
+ val webUiPort: Int,
val sparkHome: File,
val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
+ val appLocalDirs: Seq[String],
var state: ExecutorState.Value)
extends Logging {
@@ -77,7 +79,7 @@ private[spark] class ExecutorRunner(
/**
* Kill executor process, wait for exit and notify worker to update resource status.
*
- * @param message the exception message which caused the executor's death
+ * @param message the exception message which caused the executor's death
*/
private def killProcess(message: Option[String]) {
var exitCode: Option[Int] = None
@@ -129,9 +131,16 @@ private[spark] class ExecutorRunner(
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
builder.directory(executorDir)
+ builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(","))
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
+
+ // Add webUI log urls
+ val baseUrl = s"http://$host:$webUiPort/logPage/?appId=$appId&executorId=$execId&logType="
+ builder.environment.put("SPARK_LOG_URL_STDERR", s"${baseUrl}stderr")
+ builder.environment.put("SPARK_LOG_URL_STDOUT", s"${baseUrl}stdout")
+
process = builder.start()
val header = "Spark Executor Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index eb11163538b20..10929eb516041 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -23,7 +23,7 @@ import java.text.SimpleDateFormat
import java.util.{UUID, Date}
import scala.collection.JavaConversions._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, HashSet}
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
@@ -31,8 +31,8 @@ import scala.util.Random
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
-import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
+import org.apache.spark.{Logging, SecurityManager, SparkConf}
+import org.apache.spark.deploy.{Command, ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.{DriverState, Master}
import org.apache.spark.deploy.worker.ui.WorkerWebUI
@@ -40,7 +40,7 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils}
/**
- * @param masterUrls Each url should look like spark://host:port.
+ * @param masterAkkaUrls Each url should be a valid akka url.
*/
private[spark] class Worker(
host: String,
@@ -48,7 +48,7 @@ private[spark] class Worker(
webUiPort: Int,
cores: Int,
memory: Int,
- masterUrls: Array[String],
+ masterAkkaUrls: Array[String],
actorSystemName: String,
actorName: String,
workDirPath: String = null,
@@ -93,7 +93,12 @@ private[spark] class Worker(
var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
- val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName)
+ val akkaUrl = AkkaUtils.address(
+ AkkaUtils.protocol(context.system),
+ actorSystemName,
+ host,
+ port,
+ actorName)
@volatile var registered = false
@volatile var connected = false
val workerId = generateWorkerId()
@@ -109,6 +114,8 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ val appDirectories = new HashMap[String, Seq[String]]
+ val finishedApps = new HashSet[String]
// The shuffle service is not actually started unless configured.
val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
@@ -153,6 +160,7 @@ private[spark] class Worker(
assert(!registered)
logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format(
host, port, cores, Utils.megabytesToString(memory)))
+ logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}")
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -163,18 +171,17 @@ private[spark] class Worker(
metricsSystem.registerSource(workerSource)
metricsSystem.start()
+ // Attach the worker metrics servlet handler to the web ui after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
def changeMaster(url: String, uiUrl: String) {
+ // activeMasterUrl it's a valid Spark url since we receive it from master.
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
- master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
- masterAddress = activeMasterUrl match {
- case Master.sparkUrlRegex(_host, _port) =>
- Address("akka.tcp", Master.systemName, _host, _port.toInt)
- case x =>
- throw new SparkException("Invalid spark URL: " + x)
- }
+ master = context.actorSelection(
+ Master.toAkkaUrl(activeMasterUrl, AkkaUtils.protocol(context.system)))
+ masterAddress = Master.toAkkaAddress(activeMasterUrl, AkkaUtils.protocol(context.system))
connected = true
// Cancel any outstanding re-registration attempts because we found a new master
registrationRetryTimer.foreach(_.cancel())
@@ -182,9 +189,9 @@ private[spark] class Worker(
}
private def tryRegisterAllMasters() {
- for (masterUrl <- masterUrls) {
- logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
+ for (masterAkkaUrl <- masterAkkaUrls) {
+ logInfo("Connecting to master " + masterAkkaUrl + "...")
+ val actor = context.actorSelection(masterAkkaUrl)
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
}
}
@@ -292,7 +299,7 @@ private[spark] class Worker(
val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
dir.isDirectory && !isAppStillRunning &&
!Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
- }.foreach { dir =>
+ }.foreach { dir =>
logInfo(s"Removing directory: ${dir.getPath}")
Utils.deleteRecursively(dir)
}
@@ -337,8 +344,30 @@ private[spark] class Worker(
throw new IOException("Failed to create directory " + executorDir)
}
- val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
+ // Create local dirs for the executor. These are passed to the executor via the
+ // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the
+ // application finishes.
+ val appLocalDirs = appDirectories.get(appId).getOrElse {
+ Utils.getOrCreateLocalRootDirs(conf).map { dir =>
+ Utils.createDirectory(dir).getAbsolutePath()
+ }.toSeq
+ }
+ appDirectories(appId) = appLocalDirs
+ val manager = new ExecutorRunner(
+ appId,
+ execId,
+ appDesc.copy(command = Worker.maybeUpdateSSLSettings(appDesc.command, conf)),
+ cores_,
+ memory_,
+ self,
+ workerId,
+ host,
+ webUiPort,
+ sparkHome,
+ executorDir,
+ akkaUrl,
+ conf,
+ appLocalDirs, ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -375,6 +404,7 @@ private[spark] class Worker(
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
}
+ maybeCleanupApplication(appId)
}
case KillExecutor(masterUrl, appId, execId) =>
@@ -393,7 +423,14 @@ private[spark] class Worker(
case LaunchDriver(driverId, driverDesc) => {
logInfo(s"Asked to launch driver $driverId")
- val driver = new DriverRunner(conf, driverId, workDir, sparkHome, driverDesc, self, akkaUrl)
+ val driver = new DriverRunner(
+ conf,
+ driverId,
+ workDir,
+ sparkHome,
+ driverDesc.copy(command = Worker.maybeUpdateSSLSettings(driverDesc.command, conf)),
+ self,
+ akkaUrl)
drivers(driverId) = driver
driver.start()
@@ -444,6 +481,9 @@ private[spark] class Worker(
case ReregisterWithMaster =>
reregisterWithMaster()
+ case ApplicationFinished(id) =>
+ finishedApps += id
+ maybeCleanupApplication(id)
}
private def masterDisconnected() {
@@ -452,6 +492,19 @@ private[spark] class Worker(
registerWithMaster()
}
+ private def maybeCleanupApplication(id: String): Unit = {
+ val shouldCleanup = finishedApps.contains(id) && !executors.values.exists(_.appId == id)
+ if (shouldCleanup) {
+ finishedApps -= id
+ appDirectories.remove(id).foreach { dirList =>
+ logInfo(s"Cleaning up local directories for application $id")
+ dirList.foreach { dir =>
+ Utils.deleteRecursively(new File(dir))
+ }
+ }
+ }
+ }
+
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
@@ -494,9 +547,32 @@ private[spark] object Worker extends Logging {
val securityMgr = new SecurityManager(conf)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port,
conf = conf, securityManager = securityMgr)
+ val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl(_, AkkaUtils.protocol(actorSystem)))
actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
- masterUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
+ masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName)
(actorSystem, boundPort)
}
+ private[spark] def isUseLocalNodeSSLConfig(cmd: Command): Boolean = {
+ val pattern = """\-Dspark\.ssl\.useNodeLocalConf\=(.+)""".r
+ val result = cmd.javaOpts.collectFirst {
+ case pattern(_result) => _result.toBoolean
+ }
+ result.getOrElse(false)
+ }
+
+ private[spark] def maybeUpdateSSLSettings(cmd: Command, conf: SparkConf): Command = {
+ val prefix = "spark.ssl."
+ val useNLC = "spark.ssl.useNodeLocalConf"
+ if (isUseLocalNodeSSLConfig(cmd)) {
+ val newJavaOpts = cmd.javaOpts
+ .filter(opt => !opt.startsWith(s"-D$prefix")) ++
+ conf.getAll.collect { case (key, value) if key.startsWith(prefix) => s"-D$key=$value" } :+
+ s"-D$useNLC=true"
+ cmd.copy(javaOpts = newJavaOpts)
+ } else {
+ cmd
+ }
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b07942a9ca729..7ac81a2d87efd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -50,7 +50,6 @@ class WorkerWebUI(
attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"))
attachHandler(createServletHandler("/log",
(request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr))
- worker.metricsSystem.getServletHandlers.foreach(attachHandler)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 5f46f3b1f085e..3a42f8b157977 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import scala.concurrent.Await
-import akka.actor.{Actor, ActorSelection, ActorSystem, Props}
+import akka.actor.{Actor, ActorSelection, Props}
import akka.pattern.Patterns
import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
@@ -38,8 +38,7 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
- sparkProperties: Seq[(String, String)],
- actorSystem: ActorSystem)
+ env: SparkEnv)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
@@ -50,16 +49,21 @@ private[spark] class CoarseGrainedExecutorBackend(
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
driver = context.actorSelection(driverUrl)
- driver ! RegisterExecutor(executorId, hostPort, cores)
+ driver ! RegisterExecutor(executorId, hostPort, cores, extractLogUrls)
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
+ def extractLogUrls: Map[String, String] = {
+ val prefix = "SPARK_LOG_URL_"
+ sys.env.filterKeys(_.startsWith(prefix))
+ .map(e => (e._1.substring(prefix.length).toLowerCase, e._2))
+ }
+
override def receiveWithLogging = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
val (hostname, _) = Utils.parseHostPort(hostPort)
- executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false,
- actorSystem)
+ executor = new Executor(executorId, hostname, env, isLocal = false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -70,10 +74,11 @@ private[spark] class CoarseGrainedExecutorBackend(
logError("Received LaunchTask command but executor was null")
System.exit(1)
} else {
- val ser = SparkEnv.get.closureSerializer.newInstance()
+ val ser = env.closureSerializer.newInstance()
val taskDesc = ser.deserialize[TaskDescription](data.value)
logInfo("Got assigned task " + taskDesc.taskId)
- executor.launchTask(this, taskDesc.taskId, taskDesc.name, taskDesc.serializedTask)
+ executor.launchTask(this, taskId = taskDesc.taskId, attemptNumber = taskDesc.attemptNumber,
+ taskDesc.name, taskDesc.serializedTask)
}
case KillTask(taskId, _, interruptThread) =>
@@ -85,8 +90,12 @@ private[spark] class CoarseGrainedExecutorBackend(
}
case x: DisassociatedEvent =>
- logError(s"Driver $x disassociated! Shutting down.")
- System.exit(1)
+ if (x.remoteAddress == driver.anchorPath.address) {
+ logError(s"Driver $x disassociated! Shutting down.")
+ System.exit(1)
+ } else {
+ logWarning(s"Received irrelevant DisassociatedEvent $x")
+ }
case StopExecutor =>
logInfo("Driver commanded a shutdown")
@@ -120,7 +129,11 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val executorConf = new SparkConf
val port = executorConf.getInt("spark.executor.port", 0)
val (fetcher, _) = AkkaUtils.createActorSystem(
- "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf))
+ "driverPropsFetcher",
+ hostname,
+ port,
+ executorConf,
+ new SecurityManager(executorConf))
val driver = fetcher.actorSelection(driverUrl)
val timeout = AkkaUtils.askTimeout(executorConf)
val fut = Patterns.ask(driver, RetrieveSparkProps, timeout)
@@ -128,21 +141,33 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
Seq[(String, String)](("spark.app.id", appId))
fetcher.shutdown()
- // Create a new ActorSystem using driver's Spark properties to run the backend.
- val driverConf = new SparkConf().setAll(props)
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- SparkEnv.executorActorSystemName,
- hostname, port, driverConf, new SecurityManager(driverConf))
- // set it
+ // Create SparkEnv using properties we fetched from the driver.
+ val driverConf = new SparkConf()
+ for ((key, value) <- props) {
+ // this is required for SSL in standalone mode
+ if (SparkConf.isExecutorStartupConf(key)) {
+ driverConf.setIfMissing(key, value)
+ } else {
+ driverConf.set(key, value)
+ }
+ }
+ val env = SparkEnv.createExecutorEnv(
+ driverConf, executorId, hostname, port, cores, isLocal = false)
+
+ // SparkEnv sets spark.driver.port so it shouldn't be 0 anymore.
+ val boundPort = env.conf.getInt("spark.executor.port", 0)
+ assert(boundPort != 0)
+
+ // Start the CoarseGrainedExecutorBackend actor.
val sparkHostPort = hostname + ":" + boundPort
- actorSystem.actorOf(
+ env.actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, props, actorSystem),
+ driverUrl, executorId, sparkHostPort, cores, env),
name = "Executor")
workerUrl.foreach { url =>
- actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
+ env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
}
- actorSystem.awaitTermination()
+ env.actorSystem.awaitTermination()
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 835157fc520aa..5141483d1e745 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.{Props, ActorSystem}
+import akka.actor.Props
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -41,13 +41,14 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
- properties: Seq[(String, String)],
- numCores: Int,
- isLocal: Boolean = false,
- actorSystem: ActorSystem = null)
+ executorHostname: String,
+ env: SparkEnv,
+ isLocal: Boolean = false)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -55,19 +56,17 @@ private[spark] class Executor(
private val EMPTY_BYTE_BUFFER = ByteBuffer.wrap(new Array[Byte](0))
+ private val conf = env.conf
+
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
-
- // Set spark.* properties from executor arg
- val conf = new SparkConf(true)
- conf.setAll(properties)
+ Utils.setCustomHostname(executorHostname)
if (!isLocal) {
// Setup an uncaught exception handler for non-local mode.
@@ -76,22 +75,14 @@ private[spark] class Executor(
Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
}
+ // Start worker thread pool
+ val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
+
val executorSource = new ExecutorSource(this, executorId)
- // Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", executorId)
- private val env = {
- if (!isLocal) {
- val port = conf.getInt("spark.executor.port", 0)
- val _env = SparkEnv.createExecutorEnv(
- conf, executorId, slaveHostname, port, numCores, isLocal, actorSystem)
- SparkEnv.set(_env)
- _env.metricsSystem.registerSource(executorSource)
- _env.blockManager.initialize(conf.getAppId)
- _env
- } else {
- SparkEnv.get
- }
+ if (!isLocal) {
+ env.metricsSystem.registerSource(executorSource)
+ env.blockManager.initialize(conf.getAppId)
}
// Create an actor for receiving RPCs from the driver
@@ -113,17 +104,19 @@ private[spark] class Executor(
// Limit of bytes for total size of results (default is 1GB)
private val maxResultSize = Utils.getMaxResultSize(conf)
- // Start worker thread pool
- val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
-
// Maintains the list of running tasks.
private val runningTasks = new ConcurrentHashMap[Long, TaskRunner]
startDriverHeartbeater()
def launchTask(
- context: ExecutorBackend, taskId: Long, taskName: String, serializedTask: ByteBuffer) {
- val tr = new TaskRunner(context, taskId, taskName, serializedTask)
+ context: ExecutorBackend,
+ taskId: Long,
+ attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer) {
+ val tr = new TaskRunner(context, taskId = taskId, attemptNumber = attemptNumber, taskName,
+ serializedTask)
runningTasks.put(taskId, tr)
threadPool.execute(tr)
}
@@ -145,13 +138,20 @@ private[spark] class Executor(
}
}
+ private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+
class TaskRunner(
- execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
+ execBackend: ExecutorBackend,
+ val taskId: Long,
+ val attemptNumber: Int,
+ taskName: String,
+ serializedTask: ByteBuffer)
extends Runnable {
@volatile private var killed = false
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None
+ @volatile var startGCTime: Long = _
def kill(interruptThread: Boolean) {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
@@ -164,15 +164,13 @@ private[spark] class Executor(
override def run() {
val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
- val ser = SparkEnv.get.closureSerializer.newInstance()
+ val ser = env.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
- def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
- val startGCTime = gcTime
+ startGCTime = gcTime
try {
- Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -193,7 +191,7 @@ private[spark] class Executor(
// Run the actual task and measure its runtime.
taskStart = System.currentTimeMillis()
- val value = task.run(taskId.toInt)
+ val value = task.run(taskAttemptId = taskId, attemptNumber = attemptNumber)
val taskFinish = System.currentTimeMillis()
// If the task has been killed, let's fail it.
@@ -201,16 +199,16 @@ private[spark] class Executor(
throw new TaskKilledException
}
- val resultSer = SparkEnv.get.serializer.newInstance()
+ val resultSer = env.serializer.newInstance()
val beforeSerialization = System.currentTimeMillis()
val valueBytes = resultSer.serialize(value)
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - deserializeStartTime
- m.executorRunTime = taskFinish - taskStart
- m.jvmGCTime = gcTime - startGCTime
- m.resultSerializationTime = afterSerialization - beforeSerialization
+ m.setExecutorDeserializeTime(taskStart - deserializeStartTime)
+ m.setExecutorRunTime(taskFinish - taskStart)
+ m.setJvmGCTime(gcTime - startGCTime)
+ m.setResultSerializationTime(afterSerialization - beforeSerialization)
}
val accumUpdates = Accumulators.values
@@ -261,8 +259,8 @@ private[spark] class Executor(
val serviceTime = System.currentTimeMillis() - taskStart
val metrics = attemptedTask.flatMap(t => t.metrics)
for (m <- metrics) {
- m.executorRunTime = serviceTime
- m.jvmGCTime = gcTime - startGCTime
+ m.setExecutorRunTime(serviceTime)
+ m.setJvmGCTime(gcTime - startGCTime)
}
val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
@@ -278,6 +276,8 @@ private[spark] class Executor(
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
+ // Release memory used by this thread for accumulators
+ Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -375,10 +375,15 @@ private[spark] class Executor(
while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ val curGCTime = gcTime
+
for (taskRunner <- runningTasks.values()) {
- if (!taskRunner.attemptedTask.isEmpty) {
+ if (taskRunner.attemptedTask.nonEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
- metrics.updateShuffleReadMetrics
+ metrics.updateShuffleReadMetrics()
+ metrics.updateInputMetrics()
+ metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime)
+
if (isLocal) {
// JobProgressListener will hold an reference of it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
index 218ed7b5d2d39..8011e75944aac 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
@@ -39,7 +39,17 @@ private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: Class
super.addURL(url)
}
override def findClass(name: String): Class[_] = {
- super.findClass(name)
+ val loaded = super.findLoadedClass(name)
+ if (loaded != null) {
+ return loaded
+ }
+ try {
+ super.findClass(name)
+ } catch {
+ case e: ClassNotFoundException => {
+ parentClassLoader.loadClass(name)
+ }
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index f15e6bc33fb41..cfd672e1d8a97 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -22,12 +22,13 @@ import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import org.apache.mesos.protobuf.ByteString
-import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary}
+import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver}
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
-import org.apache.spark.{Logging, TaskState}
+import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.scheduler.cluster.mesos.{MesosTaskLaunchData}
import org.apache.spark.util.{SignalLogger, Utils}
private[spark] class MesosExecutorBackend
@@ -64,19 +65,27 @@ private[spark] class MesosExecutorBackend
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
+ val conf = new SparkConf(loadDefaults = true).setAll(properties)
+ val port = conf.getInt("spark.executor.port", 0)
+ val env = SparkEnv.createExecutorEnv(
+ conf, executorId, slaveInfo.getHostname, port, cpusPerTask, isLocal = false)
+
executor = new Executor(
executorId,
slaveInfo.getHostname,
- properties,
- cpusPerTask)
+ env)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
val taskId = taskInfo.getTaskId.getValue.toLong
+ val taskData = MesosTaskLaunchData.fromByteString(taskInfo.getData)
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
- executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ SparkHadoopUtil.get.runAsSparkUser { () =>
+ executor.launchTask(this, taskId = taskId, attemptNumber = taskData.attemptNumber,
+ taskInfo.getName, taskData.serializedTask)
+ }
}
}
@@ -108,11 +117,8 @@ private[spark] class MesosExecutorBackend
private[spark] object MesosExecutorBackend extends Logging {
def main(args: Array[String]) {
SignalLogger.register(log)
- SparkHadoopUtil.get.runAsSparkUser { () =>
- MesosNativeLibrary.load()
- // Create a new Executor and start it running
- val runner = new MesosExecutorBackend()
- new MesosExecutorDriver(runner).run()
- }
+ // Create a new Executor and start it running
+ val runner = new MesosExecutorBackend()
+ new MesosExecutorDriver(runner).run()
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 51b5328cb4c8f..d05659193b334 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,10 @@
package org.apache.spark.executor
+import java.util.concurrent.atomic.AtomicLong
+
+import org.apache.spark.executor.DataReadMethod.DataReadMethod
+
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.annotation.DeveloperApi
@@ -39,48 +43,78 @@ class TaskMetrics extends Serializable {
/**
* Host's name the task runs on
*/
- var hostname: String = _
-
+ private var _hostname: String = _
+ def hostname = _hostname
+ private[spark] def setHostname(value: String) = _hostname = value
+
/**
* Time taken on the executor to deserialize this task
*/
- var executorDeserializeTime: Long = _
-
+ private var _executorDeserializeTime: Long = _
+ def executorDeserializeTime = _executorDeserializeTime
+ private[spark] def setExecutorDeserializeTime(value: Long) = _executorDeserializeTime = value
+
+
/**
* Time the executor spends actually running the task (including fetching shuffle data)
*/
- var executorRunTime: Long = _
-
+ private var _executorRunTime: Long = _
+ def executorRunTime = _executorRunTime
+ private[spark] def setExecutorRunTime(value: Long) = _executorRunTime = value
+
/**
* The number of bytes this task transmitted back to the driver as the TaskResult
*/
- var resultSize: Long = _
+ private var _resultSize: Long = _
+ def resultSize = _resultSize
+ private[spark] def setResultSize(value: Long) = _resultSize = value
+
/**
* Amount of time the JVM spent in garbage collection while executing this task
*/
- var jvmGCTime: Long = _
+ private var _jvmGCTime: Long = _
+ def jvmGCTime = _jvmGCTime
+ private[spark] def setJvmGCTime(value: Long) = _jvmGCTime = value
/**
* Amount of time spent serializing the task result
*/
- var resultSerializationTime: Long = _
+ private var _resultSerializationTime: Long = _
+ def resultSerializationTime = _resultSerializationTime
+ private[spark] def setResultSerializationTime(value: Long) = _resultSerializationTime = value
/**
* The number of in-memory bytes spilled by this task
*/
- var memoryBytesSpilled: Long = _
+ private var _memoryBytesSpilled: Long = _
+ def memoryBytesSpilled = _memoryBytesSpilled
+ private[spark] def incMemoryBytesSpilled(value: Long) = _memoryBytesSpilled += value
+ private[spark] def decMemoryBytesSpilled(value: Long) = _memoryBytesSpilled -= value
/**
* The number of on-disk bytes spilled by this task
*/
- var diskBytesSpilled: Long = _
+ private var _diskBytesSpilled: Long = _
+ def diskBytesSpilled = _diskBytesSpilled
+ def incDiskBytesSpilled(value: Long) = _diskBytesSpilled += value
+ def decDiskBytesSpilled(value: Long) = _diskBytesSpilled -= value
/**
* If this task reads from a HadoopRDD or from persisted data, metrics on how much data was read
* are stored here.
*/
- var inputMetrics: Option[InputMetrics] = None
+ private var _inputMetrics: Option[InputMetrics] = None
+
+ def inputMetrics = _inputMetrics
+
+ /**
+ * This should only be used when recreating TaskMetrics, not when updating input metrics in
+ * executors
+ */
+ private[spark] def setInputMetrics(inputMetrics: Option[InputMetrics]) {
+ _inputMetrics = inputMetrics
+ }
/**
* If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
@@ -133,19 +167,48 @@ class TaskMetrics extends Serializable {
readMetrics
}
+ /**
+ * Returns the input metrics object that the task should use. Currently, if
+ * there exists an input metric with the same readMethod, we return that one
+ * so the caller can accumulate bytes read. If the readMethod is different
+ * than previously seen by this task, we return a new InputMetric but don't
+ * record it.
+ *
+ * Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
+ * we can store all the different inputMetrics (one per readMethod).
+ */
+ private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod):
+ InputMetrics =synchronized {
+ _inputMetrics match {
+ case None =>
+ val metrics = new InputMetrics(readMethod)
+ _inputMetrics = Some(metrics)
+ metrics
+ case Some(metrics @ InputMetrics(method)) if method == readMethod =>
+ metrics
+ case Some(InputMetrics(method)) =>
+ new InputMetrics(readMethod)
+ }
+ }
+
/**
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
- private[spark] def updateShuffleReadMetrics() = synchronized {
+ private[spark] def updateShuffleReadMetrics(): Unit = synchronized {
val merged = new ShuffleReadMetrics()
for (depMetrics <- depsShuffleReadMetrics) {
- merged.fetchWaitTime += depMetrics.fetchWaitTime
- merged.localBlocksFetched += depMetrics.localBlocksFetched
- merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
- merged.remoteBytesRead += depMetrics.remoteBytesRead
+ merged.incFetchWaitTime(depMetrics.fetchWaitTime)
+ merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
+ merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
+ merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
+ merged.incRecordsRead(depMetrics.recordsRead)
}
_shuffleReadMetrics = Some(merged)
}
+
+ private[spark] def updateInputMetrics(): Unit = synchronized {
+ inputMetrics.foreach(_.updateBytesRead())
+ }
}
private[spark] object TaskMetrics {
@@ -179,10 +242,42 @@ object DataWriteMethod extends Enumeration with Serializable {
*/
@DeveloperApi
case class InputMetrics(readMethod: DataReadMethod.Value) {
+
+ /**
+ * This is volatile so that it is visible to the updater thread.
+ */
+ @volatile @transient var bytesReadCallback: Option[() => Long] = None
+
/**
* Total bytes read.
*/
- var bytesRead: Long = 0L
+ private var _bytesRead: Long = _
+ def bytesRead: Long = _bytesRead
+ def incBytesRead(bytes: Long) = _bytesRead += bytes
+
+ /**
+ * Total records read.
+ */
+ private var _recordsRead: Long = _
+ def recordsRead: Long = _recordsRead
+ def incRecordsRead(records: Long) = _recordsRead += records
+
+ /**
+ * Invoke the bytesReadCallback and mutate bytesRead.
+ */
+ def updateBytesRead() {
+ bytesReadCallback.foreach { c =>
+ _bytesRead = c()
+ }
+ }
+
+ /**
+ * Register a function that can be called to get up-to-date information on how many bytes the task
+ * has read from an input source.
+ */
+ def setBytesReadCallback(f: Option[() => Long]) {
+ bytesReadCallback = f
+ }
}
/**
@@ -194,7 +289,16 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
/**
* Total bytes written
*/
- var bytesWritten: Long = 0L
+ private var _bytesWritten: Long = _
+ def bytesWritten = _bytesWritten
+ private[spark] def setBytesWritten(value : Long) = _bytesWritten = value
+
+ /**
+ * Total records written
+ */
+ private var _recordsWritten: Long = 0L
+ def recordsWritten = _recordsWritten
+ private[spark] def setRecordsWritten(value: Long) = _recordsWritten = value
}
/**
@@ -203,32 +307,52 @@ case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
*/
@DeveloperApi
class ShuffleReadMetrics extends Serializable {
- /**
- * Number of blocks fetched in this shuffle by this task (remote or local)
- */
- def totalBlocksFetched: Int = remoteBlocksFetched + localBlocksFetched
-
/**
* Number of remote blocks fetched in this shuffle by this task
*/
- var remoteBlocksFetched: Int = _
-
+ private var _remoteBlocksFetched: Int = _
+ def remoteBlocksFetched = _remoteBlocksFetched
+ private[spark] def incRemoteBlocksFetched(value: Int) = _remoteBlocksFetched += value
+ private[spark] def decRemoteBlocksFetched(value: Int) = _remoteBlocksFetched -= value
+
/**
* Number of local blocks fetched in this shuffle by this task
*/
- var localBlocksFetched: Int = _
+ private var _localBlocksFetched: Int = _
+ def localBlocksFetched = _localBlocksFetched
+ private[spark] def incLocalBlocksFetched(value: Int) = _localBlocksFetched += value
+ private[spark] def decLocalBlocksFetched(value: Int) = _localBlocksFetched -= value
/**
* Time the task spent waiting for remote shuffle blocks. This only includes the time
* blocking on shuffle input data. For instance if block B is being fetched while the task is
* still not finished processing block A, it is not considered to be blocking on block B.
*/
- var fetchWaitTime: Long = _
-
+ private var _fetchWaitTime: Long = _
+ def fetchWaitTime = _fetchWaitTime
+ private[spark] def incFetchWaitTime(value: Long) = _fetchWaitTime += value
+ private[spark] def decFetchWaitTime(value: Long) = _fetchWaitTime -= value
+
/**
* Total number of remote bytes read from the shuffle by this task
*/
- var remoteBytesRead: Long = _
+ private var _remoteBytesRead: Long = _
+ def remoteBytesRead = _remoteBytesRead
+ private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
+ private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
+
+ /**
+ * Number of blocks fetched in this shuffle by this task (remote or local)
+ */
+ def totalBlocksFetched = _remoteBlocksFetched + _localBlocksFetched
+
+ /**
+ * Total number of records read from the shuffle by this task
+ */
+ private var _recordsRead: Long = _
+ def recordsRead = _recordsRead
+ private[spark] def incRecordsRead(value: Long) = _recordsRead += value
+ private[spark] def decRecordsRead(value: Long) = _recordsRead -= value
}
/**
@@ -240,10 +364,25 @@ class ShuffleWriteMetrics extends Serializable {
/**
* Number of bytes written for the shuffle by this task
*/
- @volatile var shuffleBytesWritten: Long = _
-
+ @volatile private var _shuffleBytesWritten: Long = _
+ def shuffleBytesWritten = _shuffleBytesWritten
+ private[spark] def incShuffleBytesWritten(value: Long) = _shuffleBytesWritten += value
+ private[spark] def decShuffleBytesWritten(value: Long) = _shuffleBytesWritten -= value
+
/**
* Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
- @volatile var shuffleWriteTime: Long = _
+ @volatile private var _shuffleWriteTime: Long = _
+ def shuffleWriteTime= _shuffleWriteTime
+ private[spark] def incShuffleWriteTime(value: Long) = _shuffleWriteTime += value
+ private[spark] def decShuffleWriteTime(value: Long) = _shuffleWriteTime -= value
+
+ /**
+ * Total number of records written to the shuffle by this task
+ */
+ @volatile private var _shuffleRecordsWritten: Long = _
+ def shuffleRecordsWritten = _shuffleRecordsWritten
+ private[spark] def incShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten += value
+ private[spark] def decShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten -= value
+ private[spark] def setShuffleRecordsWritten(value: Long) = _shuffleRecordsWritten = value
}
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index 89b29af2000c8..c219d21fbefa9 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* Custom Input Format for reading and splitting flat binary files that contain records,
@@ -33,7 +34,7 @@ private[spark] object FixedLengthBinaryInputFormat {
/** Retrieves the record length property from a Hadoop configuration */
def getRecordLength(context: JobContext): Int = {
- context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt
}
}
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index 36a1e5d475f46..67a96925da019 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
@@ -82,7 +83,7 @@ private[spark] class FixedLengthBinaryRecordReader
// the actual file we will be reading from
val file = fileSplit.getPath
// job configuration
- val job = context.getConfiguration
+ val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
// check compression
val codec = new CompressionCodecFactory(job).getCodec(file)
if (codec != null) {
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 457472547fcbb..593a62b3e3b32 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAt
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* A general format for reading whole files in as streams, byte arrays,
@@ -145,7 +146,8 @@ class PortableDataStream(
private val confBytes = {
val baos = new ByteArrayOutputStream()
- context.getConfiguration.write(new DataOutputStream(baos))
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).
+ write(new DataOutputStream(baos))
baos.toByteArray
}
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index d3601cca832b2..aaef7c74eea33 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -19,7 +19,6 @@ package org.apache.spark.input
import scala.collection.JavaConversions._
-import org.apache.hadoop.conf.{Configuration, Configurable}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
@@ -38,18 +37,13 @@ private[spark] class WholeTextFileInputFormat
override protected def isSplitable(context: JobContext, file: Path): Boolean = false
- private var conf: Configuration = _
- def setConf(c: Configuration) {
- conf = c
- }
- def getConf: Configuration = conf
-
override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[String, String] = {
- val reader = new WholeCombineFileRecordReader(split, context)
- reader.setConf(conf)
+ val reader =
+ new ConfigurableCombineFileRecordReader(split, context, classOf[WholeTextFileRecordReader])
+ reader.setConf(getConf)
reader
}
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 6d59b24eb0596..31bde8a78f3c6 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -17,7 +17,7 @@
package org.apache.spark.input
-import org.apache.hadoop.conf.{Configuration, Configurable}
+import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable}
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.io.Text
@@ -26,6 +26,19 @@ import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader}
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.spark.deploy.SparkHadoopUtil
+
+
+/**
+ * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface.
+ */
+private[spark] trait Configurable extends HConfigurable {
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+}
/**
* A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
@@ -38,14 +51,9 @@ private[spark] class WholeTextFileRecordReader(
index: Integer)
extends RecordReader[String, String] with Configurable {
- private var conf: Configuration = _
- def setConf(c: Configuration) {
- conf = c
- }
- def getConf: Configuration = conf
-
private[this] val path = split.getPath(index)
- private[this] val fs = path.getFileSystem(context.getConfiguration)
+ private[this] val fs = path.getFileSystem(
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context))
// True means the current file has been processed, then skip it.
private[this] var processed = false
@@ -87,29 +95,24 @@ private[spark] class WholeTextFileRecordReader(
/**
- * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
- * out in a key-value pair, where the key is the file path and the value is the entire content of
- * the file.
+ * A [[org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader CombineFileRecordReader]]
+ * that can pass Hadoop Configuration to [[org.apache.hadoop.conf.Configurable Configurable]]
+ * RecordReaders.
*/
-private[spark] class WholeCombineFileRecordReader(
+private[spark] class ConfigurableCombineFileRecordReader[K, V](
split: InputSplit,
- context: TaskAttemptContext)
- extends CombineFileRecordReader[String, String](
+ context: TaskAttemptContext,
+ recordReaderClass: Class[_ <: RecordReader[K, V] with HConfigurable])
+ extends CombineFileRecordReader[K, V](
split.asInstanceOf[CombineFileSplit],
context,
- classOf[WholeTextFileRecordReader]
+ recordReaderClass
) with Configurable {
- private var conf: Configuration = _
- def setConf(c: Configuration) {
- conf = c
- }
- def getConf: Configuration = conf
-
override def initNextRecordReader(): Boolean = {
val r = super.initNextRecordReader()
if (r) {
- this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf)
+ this.curReader.asInstanceOf[HConfigurable].setConf(getConf)
}
r
}
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 1ac7f4e448eb1..f856890d279f4 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -21,11 +21,12 @@ import java.io.{InputStream, OutputStream}
import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream}
import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
-import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream}
+import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream}
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
+import org.apache.spark.Logging
/**
* :: DeveloperApi ::
@@ -44,25 +45,33 @@ trait CompressionCodec {
def compressedInputStream(s: InputStream): InputStream
}
-
private[spark] object CompressionCodec {
+ private val configKey = "spark.io.compression.codec"
private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
"snappy" -> classOf[SnappyCompressionCodec].getName)
def createCodec(conf: SparkConf): CompressionCodec = {
- createCodec(conf, conf.get("spark.io.compression.codec", DEFAULT_COMPRESSION_CODEC))
+ createCodec(conf, conf.get(configKey, DEFAULT_COMPRESSION_CODEC))
}
def createCodec(conf: SparkConf, codecName: String): CompressionCodec = {
val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase, codecName)
- val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
- .getConstructor(classOf[SparkConf])
- ctor.newInstance(conf).asInstanceOf[CompressionCodec]
+ val codec = try {
+ val ctor = Class.forName(codecClass, true, Utils.getContextOrSparkClassLoader)
+ .getConstructor(classOf[SparkConf])
+ Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
+ } catch {
+ case e: ClassNotFoundException => None
+ case e: IllegalArgumentException => None
+ }
+ codec.getOrElse(throw new IllegalArgumentException(s"Codec [$codecName] is not available. " +
+ s"Consider setting $configKey=$FALLBACK_COMPRESSION_CODEC"))
}
+ val FALLBACK_COMPRESSION_CODEC = "lzf"
val DEFAULT_COMPRESSION_CODEC = "snappy"
val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq
}
@@ -120,6 +129,12 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec {
@DeveloperApi
class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec {
+ try {
+ Snappy.getNativeLibraryVersion
+ } catch {
+ case e: Error => throw new IllegalArgumentException
+ }
+
override def compressedOutputStream(s: OutputStream): OutputStream = {
val blockSize = conf.getInt("spark.io.compression.snappy.block.size", 32768)
new SnappyOutputStream(s, blockSize)
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 5dd67b0cbf683..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -76,22 +76,36 @@ private[spark] class MetricsSystem private (
private val sources = new mutable.ArrayBuffer[Source]
private val registry = new MetricRegistry()
+ private var running: Boolean = false
+
// Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
private var metricsServlet: Option[MetricsServlet] = None
- /** Get any UI handlers used by this metrics system. */
- def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
+ /**
+ * Get any UI handlers used by this metrics system; can only be called after start().
+ */
+ def getServletHandlers = {
+ require(running, "Can only call getServletHandlers on a running MetricsSystem")
+ metricsServlet.map(_.getHandlers).getOrElse(Array())
+ }
metricsConfig.initialize()
def start() {
+ require(!running, "Attempting to start a MetricsSystem that is already running")
+ running = true
registerSources()
registerSinks()
sinks.foreach(_.start)
}
def stop() {
- sinks.foreach(_.stop)
+ if (running) {
+ sinks.foreach(_.stop)
+ } else {
+ logWarning("Stopping a MetricsSystem that is not running")
+ }
+ running = false
}
def report() {
@@ -107,7 +121,7 @@ private[spark] class MetricsSystem private (
* @return An unique metric name for each combination of
* application, executor/driver and metric source.
*/
- def buildRegistryName(source: Source): String = {
+ private[spark] def buildRegistryName(source: Source): String = {
val appId = conf.getOption("spark.app.id")
val executorId = conf.getOption("spark.executor.id")
val defaultName = MetricRegistry.name(source.sourceName)
@@ -116,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
@@ -144,7 +158,7 @@ private[spark] class MetricsSystem private (
})
}
- def registerSources() {
+ private def registerSources() {
val instConfig = metricsConfig.getInstance(instance)
val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
@@ -160,7 +174,7 @@ private[spark] class MetricsSystem private (
}
}
- def registerSinks() {
+ private def registerSinks() {
val instConfig = metricsConfig.getInstance(instance)
val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
index d7b5f5c40efae..2d25ebd66159f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -22,7 +22,7 @@ import java.util.Properties
import java.util.concurrent.TimeUnit
import com.codahale.metrics.MetricRegistry
-import com.codahale.metrics.graphite.{Graphite, GraphiteReporter}
+import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter}
import org.apache.spark.SecurityManager
import org.apache.spark.metrics.MetricsSystem
@@ -38,6 +38,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
val GRAPHITE_KEY_PERIOD = "period"
val GRAPHITE_KEY_UNIT = "unit"
val GRAPHITE_KEY_PREFIX = "prefix"
+ val GRAPHITE_KEY_PROTOCOL = "protocol"
def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop))
@@ -66,7 +67,11 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric
MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
- val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+ val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match {
+ case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port))
+ case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port))
+ case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p")
+ }
val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
.convertDurationsTo(TimeUnit.MILLISECONDS)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 0027cbb0ff1fb..3f0950dae1f24 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -60,7 +60,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
}
transportContext = new TransportContext(transportConf, rpcHandler)
clientFactory = transportContext.createClientFactory(bootstrap.toList)
- server = transportContext.createServer()
+ server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
appId = conf.getAppId
logInfo("Server created on " + server.getPort)
}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index df4b085d2251e..ee22c6656e69e 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -81,11 +81,24 @@ private[nio] class ConnectionManager(
private val ackTimeoutMonitor =
new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
- private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
+ private val ackTimeout =
+ conf.getInt("spark.core.connection.ack.wait.timeout", conf.getInt("spark.network.timeout", 120))
+
+ // Get the thread counts from the Spark Configuration.
+ //
+ // Even though the ThreadPoolExecutor constructor takes both a minimum and maximum value,
+ // we only query for the minimum value because we are using LinkedBlockingDeque.
+ //
+ // The JavaDoc for ThreadPoolExecutor points out that when using a LinkedBlockingDeque (which is
+ // an unbounded queue) no more than corePoolSize threads will ever be created, so only the "min"
+ // parameter is necessary.
+ private val handlerThreadCount = conf.getInt("spark.core.connection.handler.threads.min", 20)
+ private val ioThreadCount = conf.getInt("spark.core.connection.io.threads.min", 4)
+ private val connectThreadCount = conf.getInt("spark.core.connection.connect.threads.min", 1)
private val handleMessageExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.handler.threads.min", 20),
- conf.getInt("spark.core.connection.handler.threads.max", 60),
+ handlerThreadCount,
+ handlerThreadCount,
conf.getInt("spark.core.connection.handler.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-message-executor")) {
@@ -96,12 +109,11 @@ private[nio] class ConnectionManager(
logError("Error in handleMessageExecutor is not handled properly", t)
}
}
-
}
private val handleReadWriteExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.io.threads.min", 4),
- conf.getInt("spark.core.connection.io.threads.max", 32),
+ ioThreadCount,
+ ioThreadCount,
conf.getInt("spark.core.connection.io.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-read-write-executor")) {
@@ -112,14 +124,13 @@ private[nio] class ConnectionManager(
logError("Error in handleReadWriteExecutor is not handled properly", t)
}
}
-
}
// Use a different, yet smaller, thread pool - infrequently used with very short lived tasks :
// which should be executed asap
private val handleConnectExecutor = new ThreadPoolExecutor(
- conf.getInt("spark.core.connection.connect.threads.min", 1),
- conf.getInt("spark.core.connection.connect.threads.max", 8),
+ connectThreadCount,
+ connectThreadCount,
conf.getInt("spark.core.connection.connect.threads.keepalive", 60), TimeUnit.SECONDS,
new LinkedBlockingDeque[Runnable](),
Utils.namedThreadFactory("handle-connect-executor")) {
@@ -130,7 +141,6 @@ private[nio] class ConnectionManager(
logError("Error in handleConnectExecutor is not handled properly", t)
}
}
-
}
private val serverChannel = ServerSocketChannel.open()
@@ -164,7 +174,7 @@ private[nio] class ConnectionManager(
serverChannel.socket.bind(new InetSocketAddress(port))
(serverChannel, serverChannel.socket.getLocalPort)
}
- Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name)
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
@@ -174,14 +184,16 @@ private[nio] class ConnectionManager(
// to be able to track asynchronous messages
private val idCount: AtomicInteger = new AtomicInteger(1)
+ private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+ private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
+
private val selectorThread = new Thread("connection-manager-thread") {
override def run() = ConnectionManager.this.run()
}
selectorThread.setDaemon(true)
+ // start this thread last, since it invokes run(), which accesses members above
selectorThread.start()
- private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
-
private def triggerWrite(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
if (conn == null) return
@@ -222,7 +234,6 @@ private[nio] class ConnectionManager(
} )
}
- private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]()
private def triggerRead(key: SelectionKey) {
val conn = connectionsByKey.getOrElse(key, null)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index 5ad73c3d27f47..b6249b492150a 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -27,8 +27,7 @@ package org.apache
* contains operations available only on RDDs of Doubles; and
* [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that can
* be saved as SequenceFiles. These operations are automatically available on any RDD of the right
- * type (e.g. RDD[(Int, Int)] through implicit conversions except `saveAsSequenceFile`. You need to
- * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work.
+ * type (e.g. RDD[(Int, Int)] through implicit conversions.
*
* Java programmers should reference the [[org.apache.spark.api.java]] package
* for Spark programming APIs in Java.
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 7ba1182f0ed27..1c13e2c372845 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -95,7 +95,8 @@ private[spark] object CheckpointRDD extends Logging {
val finalOutputName = splitIdToFile(ctx.partitionId)
val finalOutputPath = new Path(outputDir, finalOutputName)
- val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptId)
+ val tempOutputPath =
+ new Path(outputDir, "." + finalOutputName + "-attempt-" + ctx.attemptNumber)
if (fs.exists(tempOutputPath)) {
throw new IOException("Checkpoint failed: temporary path " +
@@ -119,7 +120,7 @@ private[spark] object CheckpointRDD extends Logging {
logInfo("Deleting tempOutputPath " + tempOutputPath)
fs.delete(tempOutputPath, false)
throw new IOException("Checkpoint failed: failed to save output of task: "
- + ctx.attemptId + " and final output path does not exist")
+ + ctx.attemptNumber + " and final output path does not exist")
} else {
// Some other copy of this task must've finished before us and renamed it
logInfo("Final output path " + finalOutputPath + " already exists; not overwriting it")
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index ffc0a8a6d67eb..07398a6fa62f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -60,7 +60,7 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
- * Note: This is an internal API. We recommend users use RDD.coGroup(...) instead of
+ * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of
* instantiating this directly.
* @param rdds parent RDDs.
@@ -70,8 +70,8 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {
- // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
- // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Array as a CoGroupCombiner.
// CoGroupValue is the intermediate state of each value before being merged in compute.
private type CoGroup = CompactBuffer[Any]
private type CoGroupValue = (Any, Int) // Int is dependency number
@@ -159,8 +159,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part:
for ((it, depNum) <- rddIterators) {
map.insertAll(it.map(pair => (pair._1, new CoGroupValue(pair._2, depNum))))
}
- context.taskMetrics.memoryBytesSpilled += map.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += map.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(map.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(map.diskBytesSpilled)
new InterruptibleIterator(context,
map.iterator.asInstanceOf[Iterator[(K, Array[Iterable[_]])]])
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 9fab1d78abb04..b073eba8a1574 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -35,11 +35,10 @@ import org.apache.spark.util.Utils
* @param preferredLocation the preferred location for this partition
*/
private[spark] case class CoalescedRDDPartition(
- index: Int,
- @transient rdd: RDD[_],
- parentsIndices: Array[Int],
- @transient preferredLocation: String = ""
- ) extends Partition {
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int],
+ @transient preferredLocation: Option[String] = None) extends Partition {
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
@@ -55,9 +54,10 @@ private[spark] case class CoalescedRDDPartition(
* @return locality of this coalesced partition between 0 and 1
*/
def localFraction: Double = {
- val loc = parents.count(p =>
- rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation))
-
+ val loc = parents.count { p =>
+ val parentPreferredLocations = rdd.context.getPreferredLocs(rdd, p.index).map(_.host)
+ preferredLocation.exists(parentPreferredLocations.contains)
+ }
if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble)
}
}
@@ -73,9 +73,9 @@ private[spark] case class CoalescedRDDPartition(
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
- @transient var prev: RDD[T],
- maxPartitions: Int,
- balanceSlack: Double = 0.10)
+ @transient var prev: RDD[T],
+ maxPartitions: Int,
+ balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
override def getPartitions: Array[Partition] = {
@@ -113,7 +113,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
* @return the machine most preferred by split
*/
override def getPreferredLocations(partition: Partition): Seq[String] = {
- List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
}
}
@@ -147,7 +147,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
*
*/
-private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
@@ -341,8 +341,14 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
}
}
-private[spark] case class PartitionGroup(prefLoc: String = "") {
+private case class PartitionGroup(prefLoc: Option[String] = None) {
var arr = mutable.ArrayBuffer[Partition]()
-
def size = arr.size
}
+
+private object PartitionGroup {
+ def apply(prefLoc: String): PartitionGroup = {
+ require(prefLoc != "", "Preferred location must not be empty")
+ PartitionGroup(Some(prefLoc))
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index a157e36e2286e..486e86ce1bb19 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -35,16 +35,18 @@ import org.apache.hadoop.mapred.Reporter
import org.apache.hadoop.mapred.JobID
import org.apache.hadoop.mapred.TaskAttemptID
import org.apache.hadoop.mapred.TaskID
+import org.apache.hadoop.mapred.lib.CombineFileSplit
import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
+import org.apache.spark.storage.StorageLevel
/**
* A Spark split class that wraps around a Hadoop InputSplit.
@@ -131,7 +133,7 @@ class HadoopRDD[K, V](
// used to build JobTracker ID
private val createTime = new Date()
- private val shouldCloneJobConf = sc.conf.get("spark.hadoop.cloneConf", "false").toBoolean
+ private val shouldCloneJobConf = sc.conf.getBoolean("spark.hadoop.cloneConf", false)
// Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads.
protected def getJobConf(): JobConf = {
@@ -213,23 +215,24 @@ class HadoopRDD[K, V](
logInfo("Input split: " + split.inputSplit)
val jobConf = getJobConf()
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
- split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
- } else {
- None
- }
- if (bytesReadCallback.isDefined) {
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.inputSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
}
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ context.stageId, theSplit.index, context.attemptNumber, jobConf)
reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
// Register an on-task-completion callback to close the input stream.
@@ -237,8 +240,6 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
- var recordsSinceMetricsUpdate = 0
-
override def getNext() = {
try {
finished = !reader.next(key, value)
@@ -246,15 +247,8 @@ class HadoopRDD[K, V](
case eof: EOFException =>
finished = true
}
-
- // Update bytes read metric every few records
- if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
- && bytesReadCallback.isDefined) {
- recordsSinceMetricsUpdate = 0
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else {
- recordsSinceMetricsUpdate += 1
+ if (!finished) {
+ inputMetrics.incRecordsRead(1)
}
(key, value)
}
@@ -263,14 +257,13 @@ class HadoopRDD[K, V](
try {
reader.close()
if (bytesReadCallback.isDefined) {
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ inputMetrics.updateBytesRead()
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit] ||
+ split.inputSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.bytesRead = split.inputSplit.value.getLength
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ inputMetrics.incBytesRead(split.inputSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
@@ -318,6 +311,15 @@ class HadoopRDD[K, V](
// Do nothing. Hadoop RDD should not be checkpointed.
}
+ override def persist(storageLevel: StorageLevel): this.type = {
+ if (storageLevel.deserialized) {
+ logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
+ " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
+ " Use a map transformation to make copies of the records.")
+ }
+ super.persist(storageLevel)
+ }
+
def getConf: Configuration = getJobConf()
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index e55d03d391e03..7fb94840df99c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,20 +25,17 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
-import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
-import org.apache.spark.InterruptibleIterator
-import org.apache.spark.Logging
-import org.apache.spark.Partition
-import org.apache.spark.SerializableWritable
-import org.apache.spark.{SparkContext, TaskContext}
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark._
+import org.apache.spark.executor.DataReadMethod
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.storage.StorageLevel
private[spark] class NewHadoopPartition(
rddId: Int,
@@ -109,18 +106,19 @@ class NewHadoopRDD[K, V](
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ val inputMetrics = context.taskMetrics
+ .getInputMetricsForReadMethod(DataReadMethod.Hadoop)
+
// Find a function that will return the FileSystem bytes read by this thread. Do this before
// creating RecordReader, because RecordReader's constructor might read some bytes
- val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
- SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
- split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
- } else {
- None
- }
- if (bytesReadCallback.isDefined) {
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ val bytesReadCallback = inputMetrics.bytesReadCallback.orElse {
+ split.serializableHadoopSplit.value match {
+ case _: FileSplit | _: CombineFileSplit =>
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback()
+ case _ => None
+ }
}
+ inputMetrics.setBytesReadCallback(bytesReadCallback)
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
@@ -153,34 +151,23 @@ class NewHadoopRDD[K, V](
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
-
- // Update bytes read metric every few records
- if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
- && bytesReadCallback.isDefined) {
- recordsSinceMetricsUpdate = 0
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else {
- recordsSinceMetricsUpdate += 1
+ if (!finished) {
+ inputMetrics.incRecordsRead(1)
}
-
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
-
- // Update metrics with final amount
if (bytesReadCallback.isDefined) {
- val bytesReadFn = bytesReadCallback.get
- inputMetrics.bytesRead = bytesReadFn()
- } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ inputMetrics.updateBytesRead()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] ||
+ split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) {
// If we can't get the bytes read from the FS stats, fall back to the split size,
// which may be inaccurate.
try {
- inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ inputMetrics.incBytesRead(split.serializableHadoopSplit.value.getLength)
} catch {
case e: java.io.IOException =>
logWarning("Unable to get input size to set InputMetrics for task", e)
@@ -223,6 +210,16 @@ class NewHadoopRDD[K, V](
locs.getOrElse(split.getLocations.filter(_ != "localhost"))
}
+ override def persist(storageLevel: StorageLevel): this.type = {
+ if (storageLevel.deserialized) {
+ logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" +
+ " behavior because Hadoop's RecordReader reuses the same Writable object for all records." +
+ " Use a map transformation to make copies of the records.")
+ }
+ super.persist(storageLevel)
+ }
+
+
def getConf: Configuration = confBroadcast.value.value
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index c43e1f2fe135e..955b42c3baaa1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,6 +25,7 @@ import scala.collection.{Map, mutable}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
+import scala.util.DynamicVariable
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.conf.{Configurable, Configuration}
@@ -33,7 +34,7 @@ import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
-RecordWriter => NewRecordWriter}
+ RecordWriter => NewRecordWriter}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
@@ -84,7 +85,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](
+ self.context.clean(createCombiner),
+ self.context.clean(mergeValue),
+ self.context.clean(mergeCombiners))
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(iter => {
val context = TaskContext.get()
@@ -120,11 +124,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def aggregateByKey[U: ClassTag](zeroValue: U, partitioner: Partitioner)(seqOp: (U, V) => U,
combOp: (U, U) => U): RDD[(K, U)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
- val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
- lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[U](ByteBuffer.wrap(zeroArray))
combineByKey[U]((v: V) => seqOp(createZero(), v), seqOp, combOp, partitioner)
@@ -165,12 +169,12 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def foldByKey(zeroValue: V, partitioner: Partitioner)(func: (V, V) => V): RDD[(K, V)] = {
// Serialize the zero value to a byte array so that we can get a new clone of it on each key
- val zeroBuffer = SparkEnv.get.closureSerializer.newInstance().serialize(zeroValue)
+ val zeroBuffer = SparkEnv.get.serializer.newInstance().serialize(zeroValue)
val zeroArray = new Array[Byte](zeroBuffer.limit)
zeroBuffer.get(zeroArray)
// When deserializing, use a lazy val to create just one instance of the serializer per task
- lazy val cachedSerializer = SparkEnv.get.closureSerializer.newInstance()
+ lazy val cachedSerializer = SparkEnv.get.serializer.newInstance()
val createZero = () => cachedSerializer.deserialize[V](ByteBuffer.wrap(zeroArray))
combineByKey[V]((v: V) => func(createZero(), v), func, func, partitioner)
@@ -433,6 +437,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
* or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ *
+ * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
*/
def groupByKey(partitioner: Partitioner): RDD[(K, Iterable[V])] = {
// groupByKey shouldn't use map side combine because map side combine does not
@@ -454,6 +461,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
* Note: This operation may be very expensive. If you are grouping in order to perform an
* aggregation (such as a sum or average) over each key, using [[PairRDDFunctions.aggregateByKey]]
* or [[PairRDDFunctions.reduceByKey]] will provide much better performance.
+ *
+ * Note: As currently implemented, groupByKey must be able to hold all the key-value pairs for any
+ * key in memory. If a key has too many values, it can result in an [[OutOfMemoryError]].
*/
def groupByKey(numPartitions: Int): RDD[(K, Iterable[V])] = {
groupByKey(new HashPartitioner(numPartitions))
@@ -480,7 +490,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues( pair =>
- for (v <- pair._1; w <- pair._2) yield (v, w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
)
}
@@ -493,9 +503,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._2.isEmpty) {
- pair._1.map(v => (v, None))
+ pair._1.iterator.map(v => (v, None))
} else {
- for (v <- pair._1; w <- pair._2) yield (v, Some(w))
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, Some(w))
}
}
}
@@ -510,9 +520,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._1.isEmpty) {
- pair._2.map(w => (None, w))
+ pair._2.iterator.map(w => (None, w))
} else {
- for (v <- pair._1; w <- pair._2) yield (Some(v), w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (Some(v), w)
}
}
}
@@ -528,9 +538,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues {
- case (vs, Seq()) => vs.map(v => (Some(v), None))
- case (Seq(), ws) => ws.map(w => (None, Some(w)))
- case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w))
+ case (vs, Seq()) => vs.iterator.map(v => (Some(v), None))
+ case (Seq(), ws) => ws.iterator.map(w => (None, Some(w)))
+ case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), Some(w))
}
}
@@ -961,19 +971,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val outfmt = job.getOutputFormatClass
val jobFormat = outfmt.newInstance
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
jobFormat.checkOutputSpecs(job)
}
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
val config = wrappedConf.value
- // Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
- // around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
- attemptNumber)
+ context.attemptNumber)
val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
@@ -983,11 +990,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ var recordsWritten = 0L
try {
- var recordsWritten = 0L
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
@@ -1000,7 +1007,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
+ outputMetrics.setRecordsWritten(recordsWritten)
1
} : Int
@@ -1039,7 +1047,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
val ignoredFs = FileSystem.get(hadoopConf)
hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
@@ -1052,14 +1060,14 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
- val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt
- val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
- writer.setup(context.stageId, context.partitionId, attemptNumber)
+ writer.setup(context.stageId, context.partitionId, taskAttemptId)
writer.open()
+ var recordsWritten = 0L
try {
- var recordsWritten = 0L
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
@@ -1072,18 +1080,16 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.close()
}
writer.commit()
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
+ outputMetrics.setRecordsWritten(recordsWritten)
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
- private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
- : (OutputMetrics, Option[() => Long]) = {
- val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
- .map(new Path(_))
- .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback()
val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
if (bytesWrittenCallback.isDefined) {
context.taskMetrics.outputMetrics = Some(outputMetrics)
@@ -1093,9 +1099,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
- if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
- && bytesWrittenCallback.isDefined) {
- bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0) {
+ bytesWrittenCallback.foreach { fn => outputMetrics.setBytesWritten(fn()) }
+ outputMetrics.setRecordsWritten(recordsWritten)
}
}
@@ -1114,8 +1120,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def valueClass: Class[_] = vt.runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
+
+ // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
+ // setting can take effect:
+ private def isOutputSpecValidationEnabled: Boolean = {
+ val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value
+ val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
+ enabledInConf && !validationDisabled
+ }
}
private[spark] object PairRDDFunctions {
val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+
+ /**
+ * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
+ * basis; see SPARK-4835 for more details.
+ */
+ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 87b22de6ae697..f12d0cffaba34 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -111,7 +111,8 @@ private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
- * it efficient to run Spark over RDDs representing large sets of numbers.
+ * it efficient to run Spark over RDDs representing large sets of numbers. And if the collection
+ * is an inclusive Range, we use inclusive range for the last slice.
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
@@ -127,19 +128,15 @@ private object ParallelCollectionRDD {
})
}
seq match {
- case r: Range.Inclusive => {
- val sign = if (r.step < 0) {
- -1
- } else {
- 1
- }
- slice(new Range(
- r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
- }
case r: Range => {
- positions(r.length, numSlices).map({
- case (start, end) =>
+ positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
+ // If the range is inclusive, use inclusive range for the last slice
+ if (r.isInclusive && index == numSlices - 1) {
+ new Range.Inclusive(r.start + start * r.step, r.end, r.step)
+ }
+ else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+ }
}).toSeq.asInstanceOf[Seq[Seq[T]]]
}
case nr: NumericRange[_] => {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 214f22bc5b603..fe55a5124f3b6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -25,11 +25,8 @@ import scala.language.implicitConversions
import scala.reflect.{classTag, ClassTag}
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
-import org.apache.hadoop.io.BytesWritable
+import org.apache.hadoop.io.{Writable, BytesWritable, NullWritable, Text}
import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.io.NullWritable
-import org.apache.hadoop.io.Text
-import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred.TextOutputFormat
import org.apache.spark._
@@ -57,8 +54,7 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* [[org.apache.spark.rdd.SequenceFileRDDFunctions]] contains operations available on RDDs that
* can be saved as SequenceFiles.
* All operations are automatically available on any RDD of the right type (e.g. RDD[(Int, Int)]
- * through implicit conversions except `saveAsSequenceFile`. You need to
- * `import org.apache.spark.SparkContext._` to make `saveAsSequenceFile` work.
+ * through implicit.
*
* Internally, each RDD is characterized by five main properties:
*
@@ -76,10 +72,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -587,8 +600,8 @@ abstract class RDD[T: ClassTag](
* print line function (like out.println()) as the 2nd parameter.
* An example of pipe the RDD data of groupBy() in a streaming way,
* instead of constructing a huge String to concat all the elements:
- * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
- * for (e <- record._2){f(e)}
+ * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) =
+ * for (e <- record._2){f(e)}
* @param separateWorkingDir Use separate working directories for each task.
* @return the result RDD
*/
@@ -824,7 +837,7 @@ abstract class RDD[T: ClassTag](
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be <= us.
*/
def subtract(other: RDD[T]): RDD[T] =
subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size)))
@@ -883,6 +896,38 @@ abstract class RDD[T: ClassTag](
jobResult.getOrElse(throw new UnsupportedOperationException("empty collection"))
}
+ /**
+ * Reduces the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#reduce]]
+ */
+ def treeReduce(f: (T, T) => T, depth: Int = 2): T = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ val cleanF = context.clean(f)
+ val reducePartition: Iterator[T] => Option[T] = iter => {
+ if (iter.hasNext) {
+ Some(iter.reduceLeft(cleanF))
+ } else {
+ None
+ }
+ }
+ val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it)))
+ val op: (Option[T], Option[T]) => Option[T] = (c, x) => {
+ if (c.isDefined && x.isDefined) {
+ Some(cleanF(c.get, x.get))
+ } else if (c.isDefined) {
+ c
+ } else if (x.isDefined) {
+ x
+ } else {
+ None
+ }
+ }
+ partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth)
+ .getOrElse(throw new UnsupportedOperationException("empty collection"))
+ }
+
/**
* Aggregate the elements of each partition, and then the results for all the partitions, using a
* given associative function and a neutral "zero value". The function op(t1, t2) is allowed to
@@ -918,6 +963,37 @@ abstract class RDD[T: ClassTag](
jobResult
}
+ /**
+ * Aggregates the elements of this RDD in a multi-level tree pattern.
+ *
+ * @param depth suggested depth of the tree (default: 2)
+ * @see [[org.apache.spark.rdd.RDD#aggregate]]
+ */
+ def treeAggregate[U: ClassTag](zeroValue: U)(
+ seqOp: (U, T) => U,
+ combOp: (U, U) => U,
+ depth: Int = 2): U = {
+ require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.")
+ if (partitions.size == 0) {
+ return Utils.clone(zeroValue, context.env.closureSerializer.newInstance())
+ }
+ val cleanSeqOp = context.clean(seqOp)
+ val cleanCombOp = context.clean(combOp)
+ val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)
+ var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it)))
+ var numPartitions = partiallyAggregated.partitions.size
+ val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2)
+ // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation.
+ while (numPartitions > scale + numPartitions / scale) {
+ numPartitions /= scale
+ val curNumPartitions = numPartitions
+ partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) =>
+ iter.map((i % curNumPartitions, _))
+ }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values
+ }
+ partiallyAggregated.reduce(cleanCombOp)
+ }
+
/**
* Return the number of elements in the RDD.
*/
@@ -947,7 +1023,7 @@ abstract class RDD[T: ClassTag](
*
* Note that this method should only be used if the resulting map is expected to be small, as
* the whole thing is loaded into the driver's memory.
- * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+ * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
* returns an RDD[T, Long] instead of a map.
*/
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
@@ -985,7 +1061,7 @@ abstract class RDD[T: ClassTag](
* Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available
* here.
*
- * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
+ * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p`
* would trigger sparse representation of registers, which may reduce the memory consumption
* and increase accuracy when the cardinality is small.
*
@@ -1146,15 +1222,20 @@ abstract class RDD[T: ClassTag](
if (num == 0) {
Array.empty
} else {
- mapPartitions { items =>
+ val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
Iterator.single(queue)
- }.reduce { (queue1, queue2) =>
- queue1 ++= queue2
- queue1
- }.toArray.sorted(ord)
+ }
+ if (mapRDDs.partitions.size == 0) {
+ Array.empty
+ } else {
+ mapRDDs.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray.sorted(ord)
+ }
}
}
@@ -1170,11 +1251,36 @@ abstract class RDD[T: ClassTag](
* */
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
+ /**
+ * @return true if and only if the RDD contains no elements at all. Note that an RDD
+ * may be empty even when it has at least 1 partition.
+ */
+ def isEmpty(): Boolean = partitions.length == 0 || take(1).length == 0
+
/**
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ //
+ // NullWritable is a `Comparable` in Hadoop 1.+, so the compiler cannot find an implicit
+ // Ordering for it and will use the default `null`. However, it's a `Comparable[NullWritable]`
+ // in Hadoop 2.+, so the compiler will call the implicit `Ordering.ordered` method to create an
+ // Ordering for `NullWritable`. That's why the compiler will generate different anonymous
+ // classes for `saveAsTextFile` in Hadoop 1.+ and Hadoop 2.+.
+ //
+ // Therefore, here we provide an explicit Ordering `null` to make sure the compiler generate
+ // same bytecodes for `saveAsTextFile`.
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.mapPartitions { iter =>
+ val text = new Text()
+ iter.map { x =>
+ text.set(x.toString)
+ (NullWritable.get(), text)
+ }
+ }
+ RDD.rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
@@ -1182,7 +1288,17 @@ abstract class RDD[T: ClassTag](
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.mapPartitions { iter =>
+ val text = new Text()
+ iter.map { x =>
+ text.set(x.toString)
+ (NullWritable.get(), text)
+ }
+ }
+ RDD.rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}
@@ -1263,7 +1379,7 @@ abstract class RDD[T: ClassTag](
/**
* Private API for changing an RDD's ClassTag.
- * Used for internal Java <-> Scala API compatibility.
+ * Used for internal Java-Scala API compatibility.
*/
private[spark] def retag(cls: Class[T]): RDD[T] = {
val classTag: ClassTag[T] = ClassTag.apply(cls)
@@ -1272,7 +1388,7 @@ abstract class RDD[T: ClassTag](
/**
* Private API for changing an RDD's ClassTag.
- * Used for internal Java <-> Scala API compatibility.
+ * Used for internal Java-Scala API compatibility.
*/
private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = {
this.mapPartitions(identity, preservesPartitioning = true)(classTag)
@@ -1407,7 +1523,7 @@ abstract class RDD[T: ClassTag](
*/
object RDD {
- // The following implicit functions were in SparkContext before 1.2 and users had to
+ // The following implicit functions were in SparkContext before 1.3 and users had to
// `import SparkContext._` to enable them. Now we move them here to make the compiler find
// them automatically. However, we still keep the old functions in SparkContext for backward
// compatibility and forward to the following functions directly.
@@ -1421,9 +1537,15 @@ object RDD {
new AsyncRDDActions(rdd)
}
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
- rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = {
- new SequenceFileRDDFunctions(rdd)
+ implicit def rddToSequenceFileRDDFunctions[K, V](rdd: RDD[(K, V)])
+ (implicit kt: ClassTag[K], vt: ClassTag[V],
+ keyWritableFactory: WritableFactory[K],
+ valueWritableFactory: WritableFactory[V])
+ : SequenceFileRDDFunctions[K, V] = {
+ implicit val keyConverter = keyWritableFactory.convert
+ implicit val valueConverter = valueWritableFactory.convert
+ new SequenceFileRDDFunctions(rdd,
+ keyWritableFactory.writableClass(kt), valueWritableFactory.writableClass(vt))
}
implicit def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag](rdd: RDD[(K, V)])
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 2b48916951430..059f8963691f0 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -30,13 +30,35 @@ import org.apache.spark.Logging
* through an implicit conversion. Note that this can't be part of PairRDDFunctions because
* we need more implicit parameters to convert our keys and values to Writable.
*
- * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
*/
class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
- self: RDD[(K, V)])
+ self: RDD[(K, V)],
+ _keyWritableClass: Class[_ <: Writable],
+ _valueWritableClass: Class[_ <: Writable])
extends Logging
with Serializable {
+ @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0")
+ def this(self: RDD[(K, V)]) {
+ this(self, null, null)
+ }
+
+ private val keyWritableClass =
+ if (_keyWritableClass == null) {
+ // pre 1.3.0, we need to use Reflection to get the Writable class
+ getWritableClass[K]()
+ } else {
+ _keyWritableClass
+ }
+
+ private val valueWritableClass =
+ if (_valueWritableClass == null) {
+ // pre 1.3.0, we need to use Reflection to get the Writable class
+ getWritableClass[V]()
+ } else {
+ _valueWritableClass
+ }
+
private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = {
val c = {
if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) {
@@ -55,6 +77,7 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
c.asInstanceOf[Class[_ <: Writable]]
}
+
/**
* Output the RDD as a Hadoop SequenceFile using the Writable types we infer from the RDD's key
* and value types. If the key or value are Writable, then we use their classes directly;
@@ -65,26 +88,28 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag
def saveAsSequenceFile(path: String, codec: Option[Class[_ <: CompressionCodec]] = None) {
def anyToWritable[U <% Writable](u: U): Writable = u
- val keyClass = getWritableClass[K]
- val valueClass = getWritableClass[V]
- val convertKey = !classOf[Writable].isAssignableFrom(self.keyClass)
- val convertValue = !classOf[Writable].isAssignableFrom(self.valueClass)
+ // TODO We cannot force the return type of `anyToWritable` be same as keyWritableClass and
+ // valueWritableClass at the compile time. To implement that, we need to add type parameters to
+ // SequenceFileRDDFunctions. however, SequenceFileRDDFunctions is a public class so it will be a
+ // breaking change.
+ val convertKey = self.keyClass != keyWritableClass
+ val convertValue = self.valueClass != valueWritableClass
- logInfo("Saving as sequence file of type (" + keyClass.getSimpleName + "," +
- valueClass.getSimpleName + ")" )
+ logInfo("Saving as sequence file of type (" + keyWritableClass.getSimpleName + "," +
+ valueWritableClass.getSimpleName + ")" )
val format = classOf[SequenceFileOutputFormat[Writable, Writable]]
val jobConf = new JobConf(self.context.hadoopConfiguration)
if (!convertKey && !convertValue) {
- self.saveAsHadoopFile(path, keyClass, valueClass, format, jobConf, codec)
+ self.saveAsHadoopFile(path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (!convertKey && convertValue) {
self.map(x => (x._1,anyToWritable(x._2))).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (convertKey && !convertValue) {
self.map(x => (anyToWritable(x._1),x._2)).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
} else if (convertKey && convertValue) {
self.map(x => (anyToWritable(x._1),anyToWritable(x._2))).saveAsHadoopFile(
- path, keyClass, valueClass, format, jobConf, codec)
+ path, keyWritableClass, valueWritableClass, format, jobConf, codec)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index cb8ccfbdbdcbb..1cfe98673773a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.io.NotSerializableException
import java.util.Properties
+import java.util.concurrent.{TimeUnit, Executors}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map, Stack}
@@ -28,8 +29,6 @@ import scala.language.postfixOps
import scala.reflect.ClassTag
import scala.util.control.NonFatal
-import akka.actor._
-import akka.actor.SupervisorStrategy.Stop
import akka.pattern.ask
import akka.util.Timeout
@@ -39,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
+import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -67,8 +66,6 @@ class DAGScheduler(
clock: Clock = SystemClock)
extends Logging {
- import DAGScheduler._
-
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
this(
sc,
@@ -112,14 +109,10 @@ class DAGScheduler(
// stray messages to detect.
private val failedEpoch = new HashMap[String, Long]
- private val dagSchedulerActorSupervisor =
- env.actorSystem.actorOf(Props(new DAGSchedulerActorSupervisor(this)))
-
// A closure serializer that we reuse.
// This is only safe because DAGScheduler runs in a single thread.
private val closureSerializer = SparkEnv.get.closureSerializer.newInstance()
- private[scheduler] var eventProcessActor: ActorRef = _
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
@@ -127,26 +120,20 @@ class DAGScheduler(
/** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
- private def initializeEventProcessActor() {
- // blocking the thread until supervisor is started, which ensures eventProcessActor is
- // not null before any job is submitted
- implicit val timeout = Timeout(30 seconds)
- val initEventActorReply =
- dagSchedulerActorSupervisor ? Props(new DAGSchedulerEventProcessActor(this))
- eventProcessActor = Await.result(initEventActorReply, timeout.duration).
- asInstanceOf[ActorRef]
- }
+ private val messageScheduler =
+ Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("dag-scheduler-message"))
- initializeEventProcessActor()
+ private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
+ taskScheduler.setDAGScheduler(this)
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
- eventProcessActor ! BeginEvent(task, taskInfo)
+ eventProcessLoop.post(BeginEvent(task, taskInfo))
}
// Called to report that a task has completed and results are being fetched remotely.
def taskGettingResult(taskInfo: TaskInfo) {
- eventProcessActor ! GettingResultEvent(taskInfo)
+ eventProcessLoop.post(GettingResultEvent(taskInfo))
}
// Called by TaskScheduler to report task completions or failures.
@@ -157,7 +144,8 @@ class DAGScheduler(
accumUpdates: Map[Long, Any],
taskInfo: TaskInfo,
taskMetrics: TaskMetrics) {
- eventProcessActor ! CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics)
+ eventProcessLoop.post(
+ CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
}
/**
@@ -179,18 +167,18 @@ class DAGScheduler(
// Called by TaskScheduler when an executor fails.
def executorLost(execId: String) {
- eventProcessActor ! ExecutorLost(execId)
+ eventProcessLoop.post(ExecutorLost(execId))
}
// Called by TaskScheduler when a host is added
def executorAdded(execId: String, host: String) {
- eventProcessActor ! ExecutorAdded(execId, host)
+ eventProcessLoop.post(ExecutorAdded(execId, host))
}
// Called by TaskScheduler to cancel an entire TaskSet due to either repeated failures or
// cancellation of the job itself.
def taskSetFailed(taskSet: TaskSet, reason: String) {
- eventProcessActor ! TaskSetFailed(taskSet, reason)
+ eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}
private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
@@ -495,8 +483,8 @@ class DAGScheduler(
assert(partitions.size > 0)
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val waiter = new JobWaiter(this, jobId, partitions.size, resultHandler)
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions.toArray, allowLocal, callSite, waiter, properties))
waiter
}
@@ -536,8 +524,8 @@ class DAGScheduler(
val func2 = func.asInstanceOf[(TaskContext, Iterator[_]) => _]
val partitions = (0 until rdd.partitions.size).toArray
val jobId = nextJobId.getAndIncrement()
- eventProcessActor ! JobSubmitted(
- jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties)
+ eventProcessLoop.post(JobSubmitted(
+ jobId, rdd, func2, partitions, allowLocal = false, callSite, listener, properties))
listener.awaitResult() // Will throw an exception if the job fails
}
@@ -546,19 +534,19 @@ class DAGScheduler(
*/
def cancelJob(jobId: Int) {
logInfo("Asked to cancel job " + jobId)
- eventProcessActor ! JobCancelled(jobId)
+ eventProcessLoop.post(JobCancelled(jobId))
}
def cancelJobGroup(groupId: String) {
logInfo("Asked to cancel job group " + groupId)
- eventProcessActor ! JobGroupCancelled(groupId)
+ eventProcessLoop.post(JobGroupCancelled(groupId))
}
/**
* Cancel all jobs that are running or waiting in the queue.
*/
def cancelAllJobs() {
- eventProcessActor ! AllJobsCancelled
+ eventProcessLoop.post(AllJobsCancelled)
}
private[scheduler] def doCancelAllJobs() {
@@ -574,7 +562,7 @@ class DAGScheduler(
* Cancel all jobs associated with a running or scheduled stage.
*/
def cancelStage(stageId: Int) {
- eventProcessActor ! StageCancelled(stageId)
+ eventProcessLoop.post(StageCancelled(stageId))
}
/**
@@ -634,8 +622,8 @@ class DAGScheduler(
try {
val rdd = job.finalStage.rdd
val split = rdd.partitions(job.partitions(0))
- val taskContext =
- new TaskContextImpl(job.finalStage.id, job.partitions(0), 0, true)
+ val taskContext = new TaskContextImpl(job.finalStage.id, job.partitions(0), taskAttemptId = 0,
+ attemptNumber = 0, runningLocally = true)
TaskContextHelper.setTaskContext(taskContext)
try {
val result = job.func(taskContext, rdd.iterator(split, taskContext))
@@ -660,7 +648,7 @@ class DAGScheduler(
// completion events or stage abort
stageIdToStage -= s.id
jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, jobResult))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult))
}
}
@@ -709,7 +697,7 @@ class DAGScheduler(
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -748,9 +736,11 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
val shouldRunLocally =
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
+ val jobSubmissionTime = clock.getTime()
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, Seq.empty, properties))
runLocally(job)
} else {
jobIdToActiveJob(jobId) = job
@@ -758,7 +748,8 @@ class DAGScheduler(
finalStage.resultOfJob = Some(job)
val stageIds = jobIdToStageIds(jobId).toArray
val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
- listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
+ listenerBus.post(
+ SparkListenerJobStart(job.jobId, jobSubmissionTime, stageInfos, properties))
submitStage(finalStage)
}
}
@@ -865,26 +856,6 @@ class DAGScheduler(
}
if (tasks.size > 0) {
- // Preemptively serialize a task to make sure it can be serialized. We are catching this
- // exception here because it would be fairly hard to catch the non-serializable exception
- // down the road, where we have several different implementations for local scheduler and
- // cluster schedulers.
- //
- // We've already serialized RDDs and closures in taskBinary, but here we check for all other
- // objects such as Partition.
- try {
- closureSerializer.serialize(tasks.head)
- } catch {
- case e: NotSerializableException =>
- abortStage(stage, "Task not serializable: " + e.toString)
- runningStages -= stage
- return
- case NonFatal(e) => // Other exceptions, such as IllegalArgumentException from Kryo.
- abortStage(stage, s"Task serialization failed: $e\n${e.getStackTraceString}")
- runningStages -= stage
- return
- }
-
logInfo("Submitting " + tasks.size + " missing tasks from " + stage + " (" + stage.rdd + ")")
stage.pendingTasks ++= tasks
logDebug("New pending tasks: " + stage.pendingTasks)
@@ -984,7 +955,8 @@ class DAGScheduler(
if (job.numFinished == job.numPartitions) {
markStageAsFinished(stage)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
+ listenerBus.post(
+ SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded))
}
// taskSucceeded runs some user code that might throw an exception. Make sure
@@ -1078,16 +1050,15 @@ class DAGScheduler(
if (disallowStageRetryForTest) {
abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
- } else if (failedStages.isEmpty && eventProcessActor != null) {
+ } else if (failedStages.isEmpty) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
- // in that case the event will already have been scheduled. eventProcessActor may be
- // null during unit tests.
+ // in that case the event will already have been scheduled.
// TODO: Cancel running tasks in the stage
- import env.actorSystem.dispatcher
logInfo(s"Resubmitting $mapStage (${mapStage.name}) and " +
s"$failedStage (${failedStage.name}) due to fetch failure")
- env.actorSystem.scheduler.scheduleOnce(
- RESUBMIT_TIMEOUT, eventProcessActor, ResubmitFailedStages)
+ messageScheduler.schedule(new Runnable {
+ override def run(): Unit = eventProcessLoop.post(ResubmitFailedStages)
+ }, DAGScheduler.RESUBMIT_TIMEOUT, TimeUnit.MILLISECONDS)
}
failedStages += failedStage
failedStages += mapStage
@@ -1253,7 +1224,7 @@ class DAGScheduler(
if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
}
}
@@ -1345,46 +1316,21 @@ class DAGScheduler(
def stop() {
logInfo("Stopping DAGScheduler")
- dagSchedulerActorSupervisor ! PoisonPill
+ eventProcessLoop.stop()
taskScheduler.stop()
}
-}
-
-private[scheduler] class DAGSchedulerActorSupervisor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
-
- override val supervisorStrategy =
- OneForOneStrategy() {
- case x: Exception =>
- logError("eventProcesserActor failed; shutting down SparkContext", x)
- try {
- dagScheduler.doCancelAllJobs()
- } catch {
- case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
- }
- dagScheduler.sc.stop()
- Stop
- }
- def receive = {
- case p: Props => sender ! context.actorOf(p)
- case _ => logWarning("received unknown message in DAGSchedulerActorSupervisor")
- }
+ // Start the event thread at the end of the constructor
+ eventProcessLoop.start()
}
-private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGScheduler)
- extends Actor with Logging {
-
- override def preStart() {
- // set DAGScheduler for taskScheduler to ensure eventProcessActor is always
- // valid when the messages arrive
- dagScheduler.taskScheduler.setDAGScheduler(dagScheduler)
- }
+private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler)
+ extends EventLoop[DAGSchedulerEvent]("dag-scheduler-event-loop") with Logging {
/**
* The main event loop of the DAG scheduler.
*/
- def receive = {
+ override def onReceive(event: DAGSchedulerEvent): Unit = event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
dagScheduler.handleJobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite,
listener, properties)
@@ -1423,7 +1369,17 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.resubmitFailedStages()
}
- override def postStop() {
+ override def onError(e: Throwable): Unit = {
+ logError("DAGSchedulerEventProcessLoop failed; shutting down SparkContext", e)
+ try {
+ dagScheduler.doCancelAllJobs()
+ } catch {
+ case t: Throwable => logError("DAGScheduler failed to cancel all jobs.", t)
+ }
+ dagScheduler.sc.stop()
+ }
+
+ override def onStop() {
// Cancel any active jobs in postStop hook
dagScheduler.cleanUpAfterSchedulerStop()
}
@@ -1433,9 +1389,5 @@ private[spark] object DAGScheduler {
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 200.milliseconds
-
- // The time, in millis, to wake up between polls of the completion queue in order to potentially
- // resubmit failed stages
- val POLL_TIMEOUT = 10L
+ val RESUBMIT_TIMEOUT = 200
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 597dbc884913c..30075c172bdb1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -17,20 +17,23 @@
package org.apache.spark.scheduler
+import java.io._
+import java.net.URI
+
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
import org.apache.hadoop.fs.permission.FsPermission
import org.json4s.JsonAST.JValue
import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{Logging, SparkConf, SparkContext}
+import org.apache.spark.{Logging, SparkConf, SPARK_VERSION}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.io.CompressionCodec
-import org.apache.spark.SPARK_VERSION
-import org.apache.spark.util.{FileLogger, JsonProtocol, Utils}
+import org.apache.spark.util.{JsonProtocol, Utils}
/**
* A SparkListener that logs events to persistent storage.
@@ -58,36 +61,78 @@ private[spark] class EventLoggingListener(
private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false)
private val testing = sparkConf.getBoolean("spark.eventLog.testing", false)
private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024
- val logDir = EventLoggingListener.getLogDirPath(logBaseDir, appId)
- val logDirName: String = logDir.split("/").last
- protected val logger = new FileLogger(logDir, sparkConf, hadoopConf, outputBufferSize,
- shouldCompress, shouldOverwrite, Some(LOG_FILE_PERMISSIONS))
+ private val fileSystem = Utils.getHadoopFileSystem(new URI(logBaseDir), hadoopConf)
+
+ // Only defined if the file system scheme is not local
+ private var hadoopDataStream: Option[FSDataOutputStream] = None
+
+ // The Hadoop APIs have changed over time, so we use reflection to figure out
+ // the correct method to use to flush a hadoop data stream. See SPARK-1518
+ // for details.
+ private val hadoopFlushMethod = {
+ val cls = classOf[FSDataOutputStream]
+ scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync"))
+ }
+
+ private var writer: Option[PrintWriter] = None
// For testing. Keep track of all JSON serialized events that have been logged.
private[scheduler] val loggedEvents = new ArrayBuffer[JValue]
+ // Visible for tests only.
+ private[scheduler] val logPath = getLogPath(logBaseDir, appId)
+
/**
- * Begin logging events.
- * If compression is used, log a file that indicates which compression library is used.
+ * Creates the log file in the configured log directory.
*/
def start() {
- logger.start()
- logInfo("Logging events to %s".format(logDir))
- if (shouldCompress) {
- val codec =
- sparkConf.get("spark.io.compression.codec", CompressionCodec.DEFAULT_COMPRESSION_CODEC)
- logger.newFile(COMPRESSION_CODEC_PREFIX + codec)
+ if (!fileSystem.isDirectory(new Path(logBaseDir))) {
+ throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.")
}
- logger.newFile(SPARK_VERSION_PREFIX + SPARK_VERSION)
- logger.newFile(LOG_PREFIX + logger.fileIndex)
+
+ val workingPath = logPath + IN_PROGRESS
+ val uri = new URI(workingPath)
+ val path = new Path(workingPath)
+ val defaultFs = FileSystem.getDefaultUri(hadoopConf).getScheme
+ val isDefaultLocal = defaultFs == null || defaultFs == "file"
+
+ if (shouldOverwrite && fileSystem.exists(path)) {
+ logWarning(s"Event log $path already exists. Overwriting...")
+ fileSystem.delete(path, true)
+ }
+
+ /* The Hadoop LocalFileSystem (r1.0.4) has known issues with syncing (HADOOP-7844).
+ * Therefore, for local files, use FileOutputStream instead. */
+ val dstream =
+ if ((isDefaultLocal && uri.getScheme == null) || uri.getScheme == "file") {
+ new FileOutputStream(uri.getPath)
+ } else {
+ hadoopDataStream = Some(fileSystem.create(path))
+ hadoopDataStream.get
+ }
+
+ val compressionCodec =
+ if (shouldCompress) {
+ Some(CompressionCodec.createCodec(sparkConf))
+ } else {
+ None
+ }
+
+ fileSystem.setPermission(path, LOG_FILE_PERMISSIONS)
+ val logStream = initEventLog(new BufferedOutputStream(dstream, outputBufferSize),
+ compressionCodec)
+ writer = Some(new PrintWriter(logStream))
+
+ logInfo("Logging events to %s".format(logPath))
}
/** Log the event as JSON. */
private def logEvent(event: SparkListenerEvent, flushLogger: Boolean = false) {
val eventJson = JsonProtocol.sparkEventToJson(event)
- logger.logLine(compact(render(eventJson)))
+ writer.foreach(_.println(compact(render(eventJson))))
if (flushLogger) {
- logger.flush()
+ writer.foreach(_.flush())
+ hadoopDataStream.foreach(hadoopFlushMethod.invoke(_))
}
if (testing) {
loggedEvents += eventJson
@@ -123,130 +168,168 @@ private[spark] class EventLoggingListener(
logEvent(event, flushLogger = true)
override def onApplicationEnd(event: SparkListenerApplicationEnd) =
logEvent(event, flushLogger = true)
+ override def onExecutorAdded(event: SparkListenerExecutorAdded) =
+ logEvent(event, flushLogger = true)
+ override def onExecutorRemoved(event: SparkListenerExecutorRemoved) =
+ logEvent(event, flushLogger = true)
+
// No-op because logging every update would be overkill
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate) { }
/**
- * Stop logging events.
- * In addition, create an empty special file to indicate application completion.
+ * Stop logging events. The event log file will be renamed so that it loses the
+ * ".inprogress" suffix.
*/
def stop() = {
- logger.newFile(APPLICATION_COMPLETE)
- logger.stop()
+ writer.foreach(_.close())
+
+ val target = new Path(logPath)
+ if (fileSystem.exists(target)) {
+ if (shouldOverwrite) {
+ logWarning(s"Event log $target already exists. Overwriting...")
+ fileSystem.delete(target, true)
+ } else {
+ throw new IOException("Target log file already exists (%s)".format(logPath))
+ }
+ }
+ fileSystem.rename(new Path(logPath + IN_PROGRESS), target)
}
+
}
private[spark] object EventLoggingListener extends Logging {
+ // Suffix applied to the names of files still being written by applications.
+ val IN_PROGRESS = ".inprogress"
val DEFAULT_LOG_DIR = "/tmp/spark-events"
- val LOG_PREFIX = "EVENT_LOG_"
- val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
- val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
- val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
- val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
- // A cache for compression codecs to avoid creating the same codec many times
- private val codecMap = new mutable.HashMap[String, CompressionCodec]
+ private val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
- def isEventLogFile(fileName: String): Boolean = {
- fileName.startsWith(LOG_PREFIX)
- }
+ // Marker for the end of header data in a log file. After this marker, log data, potentially
+ // compressed, will be found.
+ private val HEADER_END_MARKER = "=== LOG_HEADER_END ==="
- def isSparkVersionFile(fileName: String): Boolean = {
- fileName.startsWith(SPARK_VERSION_PREFIX)
- }
+ // To avoid corrupted files causing the heap to fill up. Value is arbitrary.
+ private val MAX_HEADER_LINE_LENGTH = 4096
- def isCompressionCodecFile(fileName: String): Boolean = {
- fileName.startsWith(COMPRESSION_CODEC_PREFIX)
- }
+ // A cache for compression codecs to avoid creating the same codec many times
+ private val codecMap = new mutable.HashMap[String, CompressionCodec]
- def isApplicationCompleteFile(fileName: String): Boolean = {
- fileName == APPLICATION_COMPLETE
- }
+ /**
+ * Write metadata about the event log to the given stream.
+ *
+ * The header is a serialized version of a map, except it does not use Java serialization to
+ * avoid incompatibilities between different JDKs. It writes one map entry per line, in
+ * "key=value" format.
+ *
+ * The very last entry in the header is the `HEADER_END_MARKER` marker, so that the parsing code
+ * can know when to stop.
+ *
+ * The format needs to be kept in sync with the openEventLog() method below. Also, it cannot
+ * change in new Spark versions without some other way of detecting the change (like some
+ * metadata encoded in the file name).
+ *
+ * @param logStream Raw output stream to the even log file.
+ * @param compressionCodec Optional compression codec to use.
+ * @return A stream where to write event log data. This may be a wrapper around the original
+ * stream (for example, when compression is enabled).
+ */
+ def initEventLog(
+ logStream: OutputStream,
+ compressionCodec: Option[CompressionCodec]): OutputStream = {
+ val meta = mutable.HashMap(("version" -> SPARK_VERSION))
+ compressionCodec.foreach { codec =>
+ meta += ("compressionCodec" -> codec.getClass().getName())
+ }
- def parseSparkVersion(fileName: String): String = {
- if (isSparkVersionFile(fileName)) {
- fileName.replaceAll(SPARK_VERSION_PREFIX, "")
- } else ""
- }
+ def write(entry: String) = {
+ val bytes = entry.getBytes(Charsets.UTF_8)
+ if (bytes.length > MAX_HEADER_LINE_LENGTH) {
+ throw new IOException(s"Header entry too long: ${entry}")
+ }
+ logStream.write(bytes, 0, bytes.length)
+ }
- def parseCompressionCodec(fileName: String): String = {
- if (isCompressionCodecFile(fileName)) {
- fileName.replaceAll(COMPRESSION_CODEC_PREFIX, "")
- } else ""
+ meta.foreach { case (k, v) => write(s"$k=$v\n") }
+ write(s"$HEADER_END_MARKER\n")
+ compressionCodec.map(_.compressedOutputStream(logStream)).getOrElse(logStream)
}
/**
- * Return a file-system-safe path to the log directory for the given application.
+ * Return a file-system-safe path to the log file for the given application.
*
- * @param logBaseDir A base directory for the path to the log directory for given application.
+ * @param logBaseDir Directory where the log file will be written.
* @param appId A unique app ID.
* @return A path which consists of file-system-safe characters.
*/
- def getLogDirPath(logBaseDir: String, appId: String): String = {
+ def getLogPath(logBaseDir: String, appId: String): String = {
val name = appId.replaceAll("[ :/]", "-").replaceAll("[${}'\"]", "_").toLowerCase
Utils.resolveURI(logBaseDir) + "/" + name.stripSuffix("/")
}
/**
- * Parse the event logging information associated with the logs in the given directory.
+ * Opens an event log file and returns an input stream to the event data.
*
- * Specifically, this looks for event log files, the Spark version file, the compression
- * codec file (if event logs are compressed), and the application completion file (if the
- * application has run to completion).
+ * @return 2-tuple (event input stream, Spark version of event data)
*/
- def parseLoggingInfo(logDir: Path, fileSystem: FileSystem): EventLoggingInfo = {
+ def openEventLog(log: Path, fs: FileSystem): (InputStream, String) = {
+ // It's not clear whether FileSystem.open() throws FileNotFoundException or just plain
+ // IOException when a file does not exist, so try our best to throw a proper exception.
+ if (!fs.exists(log)) {
+ throw new FileNotFoundException(s"File $log does not exist.")
+ }
+
+ val in = new BufferedInputStream(fs.open(log))
+ // Read a single line from the input stream without buffering.
+ // We cannot use BufferedReader because we must avoid reading
+ // beyond the end of the header, after which the content of the
+ // file may be compressed.
+ def readLine(): String = {
+ val bytes = new ByteArrayOutputStream()
+ var next = in.read()
+ var count = 0
+ while (next != '\n') {
+ if (next == -1) {
+ throw new IOException("Unexpected end of file.")
+ }
+ bytes.write(next)
+ count = count + 1
+ if (count > MAX_HEADER_LINE_LENGTH) {
+ throw new IOException("Maximum header line length exceeded.")
+ }
+ next = in.read()
+ }
+ new String(bytes.toByteArray(), Charsets.UTF_8)
+ }
+
+ // Parse the header metadata in the form of k=v pairs
+ // This assumes that every line before the header end marker follows this format
try {
- val fileStatuses = fileSystem.listStatus(logDir)
- val filePaths =
- if (fileStatuses != null) {
- fileStatuses.filter(!_.isDir).map(_.getPath).toSeq
- } else {
- Seq[Path]()
+ val meta = new mutable.HashMap[String, String]()
+ var foundEndMarker = false
+ while (!foundEndMarker) {
+ readLine() match {
+ case HEADER_END_MARKER =>
+ foundEndMarker = true
+ case entry =>
+ val prop = entry.split("=", 2)
+ if (prop.length != 2) {
+ throw new IllegalArgumentException("Invalid metadata in log file.")
+ }
+ meta += (prop(0) -> prop(1))
}
- if (filePaths.isEmpty) {
- logWarning("No files found in logging directory %s".format(logDir))
}
- EventLoggingInfo(
- logPaths = filePaths.filter { path => isEventLogFile(path.getName) },
- sparkVersion = filePaths
- .find { path => isSparkVersionFile(path.getName) }
- .map { path => parseSparkVersion(path.getName) }
- .getOrElse(""),
- compressionCodec = filePaths
- .find { path => isCompressionCodecFile(path.getName) }
- .map { path =>
- val codec = EventLoggingListener.parseCompressionCodec(path.getName)
- val conf = new SparkConf
- conf.set("spark.io.compression.codec", codec)
- codecMap.getOrElseUpdate(codec, CompressionCodec.createCodec(conf))
- },
- applicationComplete = filePaths.exists { path => isApplicationCompleteFile(path.getName) }
- )
+
+ val sparkVersion = meta.get("version").getOrElse(
+ throw new IllegalArgumentException("Missing Spark version in log metadata."))
+ val codec = meta.get("compressionCodec").map { codecName =>
+ codecMap.getOrElseUpdate(codecName, CompressionCodec.createCodec(new SparkConf, codecName))
+ }
+ (codec.map(_.compressedInputStream(in)).getOrElse(in), sparkVersion)
} catch {
case e: Exception =>
- logError("Exception in parsing logging info from directory %s".format(logDir), e)
- EventLoggingInfo.empty
+ in.close()
+ throw e
}
}
- /**
- * Parse the event logging information associated with the logs in the given directory.
- */
- def parseLoggingInfo(logDir: String, fileSystem: FileSystem): EventLoggingInfo = {
- parseLoggingInfo(new Path(logDir), fileSystem)
- }
-}
-
-
-/**
- * Information needed to process the event logs associated with an application.
- */
-private[spark] case class EventLoggingInfo(
- logPaths: Seq[Path],
- sparkVersion: String,
- compressionCodec: Option[CompressionCodec],
- applicationComplete: Boolean = false)
-
-private[spark] object EventLoggingInfo {
- def empty = EventLoggingInfo(Seq[Path](), "", None, applicationComplete = false)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
index 36a6e6338faa6..be23056e7d423 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala
@@ -17,10 +17,9 @@
package org.apache.spark.scheduler
-import java.util.concurrent.{LinkedBlockingQueue, Semaphore}
+import java.util.concurrent.atomic.AtomicBoolean
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.AsynchronousListenerBus
/**
* Asynchronously passes SparkListenerEvents to registered SparkListeners.
@@ -29,113 +28,19 @@ import org.apache.spark.util.Utils
* has started will events be actually propagated to all attached listeners. This listener bus
* is stopped when it receives a SparkListenerShutdown event, which is posted using stop().
*/
-private[spark] class LiveListenerBus extends SparkListenerBus with Logging {
-
- /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than
- * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */
- private val EVENT_QUEUE_CAPACITY = 10000
- private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY)
- private var queueFullErrorMessageLogged = false
- private var started = false
-
- // A counter that represents the number of events produced and consumed in the queue
- private val eventLock = new Semaphore(0)
-
- private val listenerThread = new Thread("SparkListenerBus") {
- setDaemon(true)
- override def run(): Unit = Utils.logUncaughtExceptions {
- while (true) {
- eventLock.acquire()
- // Atomically remove and process this event
- LiveListenerBus.this.synchronized {
- val event = eventQueue.poll
- if (event == SparkListenerShutdown) {
- // Get out of the while loop and shutdown the daemon thread
- return
- }
- Option(event).foreach(postToAll)
- }
- }
- }
- }
-
- /**
- * Start sending events to attached listeners.
- *
- * This first sends out all buffered events posted before this listener bus has started, then
- * listens for any additional events asynchronously while the listener bus is still running.
- * This should only be called once.
- */
- def start() {
- if (started) {
- throw new IllegalStateException("Listener bus already started!")
+private[spark] class LiveListenerBus
+ extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus")
+ with SparkListenerBus {
+
+ private val logDroppedEvent = new AtomicBoolean(false)
+
+ override def onDropEvent(event: SparkListenerEvent): Unit = {
+ if (logDroppedEvent.compareAndSet(false, true)) {
+ // Only log the following message once to avoid duplicated annoying logs.
+ logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
+ "This likely means one of the SparkListeners is too slow and cannot keep up with " +
+ "the rate at which tasks are being started by the scheduler.")
}
- listenerThread.start()
- started = true
}
- def post(event: SparkListenerEvent) {
- val eventAdded = eventQueue.offer(event)
- if (eventAdded) {
- eventLock.release()
- } else {
- logQueueFullErrorMessage()
- }
- }
-
- /**
- * For testing only. Wait until there are no more events in the queue, or until the specified
- * time has elapsed. Return true if the queue has emptied and false is the specified time
- * elapsed before the queue emptied.
- */
- def waitUntilEmpty(timeoutMillis: Int): Boolean = {
- val finishTime = System.currentTimeMillis + timeoutMillis
- while (!queueIsEmpty) {
- if (System.currentTimeMillis > finishTime) {
- return false
- }
- /* Sleep rather than using wait/notify, because this is used only for testing and
- * wait/notify add overhead in the general case. */
- Thread.sleep(10)
- }
- true
- }
-
- /**
- * For testing only. Return whether the listener daemon thread is still alive.
- */
- def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive }
-
- /**
- * Return whether the event queue is empty.
- *
- * The use of synchronized here guarantees that all events that once belonged to this queue
- * have already been processed by all attached listeners, if this returns true.
- */
- def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty }
-
- /**
- * Log an error message to indicate that the event queue is full. Do this only once.
- */
- private def logQueueFullErrorMessage(): Unit = {
- if (!queueFullErrorMessageLogged) {
- if (listenerThread.isAlive) {
- logError("Dropping SparkListenerEvent because no remaining room in event queue. " +
- "This likely means one of the SparkListeners is too slow and cannot keep up with" +
- "the rate at which tasks are being started by the scheduler.")
- } else {
- logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not" +
- "been (and will no longer be) propagated to listeners for some time.")
- }
- queueFullErrorMessageLogged = true
- }
- }
-
- def stop() {
- if (!started) {
- throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!")
- }
- post(SparkListenerShutdown)
- listenerThread.join()
- }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 01d5943d777f3..1efce124c0a6b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -122,7 +122,7 @@ private[spark] class CompressedMapStatus(
/**
* A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
- * plus a bitmap for tracking which blocks are non-empty. During serialization, this bitmap
+ * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap
* is compressed.
*
* @param loc location where the task is being executed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index f89724d4ea196..584f4e7789d1a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -17,74 +17,45 @@
package org.apache.spark.scheduler
-import java.io.{BufferedInputStream, InputStream}
+import java.io.{InputStream, IOException}
import scala.io.Source
-import org.apache.hadoop.fs.{Path, FileSystem}
import org.json4s.jackson.JsonMethods._
import org.apache.spark.Logging
-import org.apache.spark.io.CompressionCodec
import org.apache.spark.util.JsonProtocol
/**
- * A SparkListenerBus that replays logged events from persisted storage.
- *
- * This assumes the given paths are valid log files, where each line can be deserialized into
- * exactly one SparkListenerEvent.
+ * A SparkListenerBus that can be used to replay events from serialized event data.
*/
-private[spark] class ReplayListenerBus(
- logPaths: Seq[Path],
- fileSystem: FileSystem,
- compressionCodec: Option[CompressionCodec])
- extends SparkListenerBus with Logging {
-
- private var replayed = false
-
- if (logPaths.length == 0) {
- logWarning("Log path provided contains no log files.")
- }
+private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
/**
- * Replay each event in the order maintained in the given logs.
- * This should only be called exactly once.
+ * Replay each event in the order maintained in the given stream. The stream is expected to
+ * contain one JSON-encoded SparkListenerEvent per line.
+ *
+ * This method can be called multiple times, but the listener behavior is undefined after any
+ * error is thrown by this method.
+ *
+ * @param logData Stream containing event log data.
+ * @param version Spark version that generated the events.
*/
- def replay() {
- assert(!replayed, "ReplayListenerBus cannot replay events more than once")
- logPaths.foreach { path =>
- // Keep track of input streams at all levels to close them later
- // This is necessary because an exception can occur in between stream initializations
- var fileStream: Option[InputStream] = None
- var bufferedStream: Option[InputStream] = None
- var compressStream: Option[InputStream] = None
- var currentLine = ""
- try {
- fileStream = Some(fileSystem.open(path))
- bufferedStream = Some(new BufferedInputStream(fileStream.get))
- compressStream = Some(wrapForCompression(bufferedStream.get))
-
- // Parse each line as an event and post the event to all attached listeners
- val lines = Source.fromInputStream(compressStream.get).getLines()
- lines.foreach { line =>
- currentLine = line
- postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
- }
- } catch {
- case e: Exception =>
- logError("Exception in parsing Spark event log %s".format(path), e)
- logError("Malformed line: %s\n".format(currentLine))
- } finally {
- fileStream.foreach(_.close())
- bufferedStream.foreach(_.close())
- compressStream.foreach(_.close())
+ def replay(logData: InputStream, version: String) {
+ var currentLine: String = null
+ try {
+ val lines = Source.fromInputStream(logData).getLines()
+ lines.foreach { line =>
+ currentLine = line
+ postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
}
+ } catch {
+ case ioe: IOException =>
+ throw ioe
+ case e: Exception =>
+ logError("Exception in parsing Spark event log.", e)
+ logError("Malformed line: %s\n".format(currentLine))
}
- replayed = true
}
- /** If a compression codec is specified, wrap the given stream in a compression stream. */
- private def wrapForCompression(stream: InputStream): InputStream = {
- compressionCodec.map(_.compressedInputStream(stream)).getOrElse(stream)
- }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index b62b0c1312693..dd28ddb31de1f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable
import org.apache.spark.{Logging, TaskEndReason}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.{Distribution, Utils}
@@ -58,6 +59,7 @@ case class SparkListenerTaskEnd(
@DeveloperApi
case class SparkListenerJobStart(
jobId: Int,
+ time: Long,
stageInfos: Seq[StageInfo],
properties: Properties = null)
extends SparkListenerEvent {
@@ -67,7 +69,11 @@ case class SparkListenerJobStart(
}
@DeveloperApi
-case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
+case class SparkListenerJobEnd(
+ jobId: Int,
+ time: Long,
+ jobResult: JobResult)
+ extends SparkListenerEvent
@DeveloperApi
case class SparkListenerEnvironmentUpdate(environmentDetails: Map[String, Seq[(String, String)]])
@@ -84,6 +90,14 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan
@DeveloperApi
case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent
+@DeveloperApi
+case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo)
+ extends SparkListenerEvent
+
+@DeveloperApi
+case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String)
+ extends SparkListenerEvent
+
/**
* Periodic updates from executors.
* @param execId executor id
@@ -102,14 +116,12 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String],
@DeveloperApi
case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent
-/** An event used in the listener to shutdown the listener daemon thread. */
-private[spark] case object SparkListenerShutdown extends SparkListenerEvent
-
/**
* :: DeveloperApi ::
* Interface for listening to events from the Spark scheduler. Note that this is an internal
- * interface which might change in different Spark releases.
+ * interface which might change in different Spark releases. Java clients should extend
+ * {@link JavaSparkListener}
*/
@DeveloperApi
trait SparkListener {
@@ -183,6 +195,16 @@ trait SparkListener {
* Called when the driver receives task metrics from an executor in a heartbeat.
*/
def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) { }
+
+ /**
+ * Called when the driver registers a new executor.
+ */
+ def onExecutorAdded(executorAdded: SparkListenerExecutorAdded) { }
+
+ /**
+ * Called when the driver removes an executor.
+ */
+ def onExecutorRemoved(executorRemoved: SparkListenerExecutorRemoved) { }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index e79ffd7a3587d..fe8a19a2c0cb9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -17,74 +17,47 @@
package org.apache.spark.scheduler
-import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.ListenerBus
/**
- * A SparkListenerEvent bus that relays events to its listeners
+ * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners
*/
-private[spark] trait SparkListenerBus extends Logging {
-
- // SparkListeners attached to this event bus
- protected val sparkListeners = new ArrayBuffer[SparkListener]
- with mutable.SynchronizedBuffer[SparkListener]
-
- def addListener(listener: SparkListener) {
- sparkListeners += listener
- }
+private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] {
- /**
- * Post an event to all attached listeners.
- * This does nothing if the event is SparkListenerShutdown.
- */
- def postToAll(event: SparkListenerEvent) {
+ override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = {
event match {
case stageSubmitted: SparkListenerStageSubmitted =>
- foreachListener(_.onStageSubmitted(stageSubmitted))
+ listener.onStageSubmitted(stageSubmitted)
case stageCompleted: SparkListenerStageCompleted =>
- foreachListener(_.onStageCompleted(stageCompleted))
+ listener.onStageCompleted(stageCompleted)
case jobStart: SparkListenerJobStart =>
- foreachListener(_.onJobStart(jobStart))
+ listener.onJobStart(jobStart)
case jobEnd: SparkListenerJobEnd =>
- foreachListener(_.onJobEnd(jobEnd))
+ listener.onJobEnd(jobEnd)
case taskStart: SparkListenerTaskStart =>
- foreachListener(_.onTaskStart(taskStart))
+ listener.onTaskStart(taskStart)
case taskGettingResult: SparkListenerTaskGettingResult =>
- foreachListener(_.onTaskGettingResult(taskGettingResult))
+ listener.onTaskGettingResult(taskGettingResult)
case taskEnd: SparkListenerTaskEnd =>
- foreachListener(_.onTaskEnd(taskEnd))
+ listener.onTaskEnd(taskEnd)
case environmentUpdate: SparkListenerEnvironmentUpdate =>
- foreachListener(_.onEnvironmentUpdate(environmentUpdate))
+ listener.onEnvironmentUpdate(environmentUpdate)
case blockManagerAdded: SparkListenerBlockManagerAdded =>
- foreachListener(_.onBlockManagerAdded(blockManagerAdded))
+ listener.onBlockManagerAdded(blockManagerAdded)
case blockManagerRemoved: SparkListenerBlockManagerRemoved =>
- foreachListener(_.onBlockManagerRemoved(blockManagerRemoved))
+ listener.onBlockManagerRemoved(blockManagerRemoved)
case unpersistRDD: SparkListenerUnpersistRDD =>
- foreachListener(_.onUnpersistRDD(unpersistRDD))
+ listener.onUnpersistRDD(unpersistRDD)
case applicationStart: SparkListenerApplicationStart =>
- foreachListener(_.onApplicationStart(applicationStart))
+ listener.onApplicationStart(applicationStart)
case applicationEnd: SparkListenerApplicationEnd =>
- foreachListener(_.onApplicationEnd(applicationEnd))
+ listener.onApplicationEnd(applicationEnd)
case metricsUpdate: SparkListenerExecutorMetricsUpdate =>
- foreachListener(_.onExecutorMetricsUpdate(metricsUpdate))
- case SparkListenerShutdown =>
- }
- }
-
- /**
- * Apply the given function to all attached listeners, catching and logging any exception.
- */
- private def foreachListener(f: SparkListener => Unit): Unit = {
- sparkListeners.foreach { listener =>
- try {
- f(listener)
- } catch {
- case e: Exception =>
- logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e)
- }
+ listener.onExecutorMetricsUpdate(metricsUpdate)
+ case executorAdded: SparkListenerExecutorAdded =>
+ listener.onExecutorAdded(executorAdded)
+ case executorRemoved: SparkListenerExecutorRemoved =>
+ listener.onExecutorRemoved(executorRemoved)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2552d03d18d06..847a4912eec13 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -44,10 +44,18 @@ import org.apache.spark.util.Utils
*/
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
- final def run(attemptId: Long): T = {
- context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ /**
+ * Called by Executor to run this task.
+ *
+ * @param taskAttemptId an identifier for this task attempt that is unique within a SparkContext.
+ * @param attemptNumber how many times this task has been attempted (0 for the first attempt)
+ * @return the result of the task
+ */
+ final def run(taskAttemptId: Long, attemptNumber: Int): T = {
+ context = new TaskContextImpl(stageId = stageId, partitionId = partitionId,
+ taskAttemptId = taskAttemptId, attemptNumber = attemptNumber, runningLocally = false)
TaskContextHelper.setTaskContext(context)
- context.taskMetrics.hostname = Utils.localHostName()
+ context.taskMetrics.setHostname(Utils.localHostName())
taskThread = Thread.currentThread()
if (_killed) {
kill(interruptThread = false)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
index 4c96b9e5fef60..1c7c81c488c3a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala
@@ -27,6 +27,7 @@ import org.apache.spark.util.SerializableBuffer
*/
private[spark] class TaskDescription(
val taskId: Long,
+ val attemptNumber: Int,
val executorId: String,
val name: String,
val index: Int, // Index within this task's TaskSet
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 819b51e12ad8c..774f3d8cdb275 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import scala.language.existentials
import scala.util.control.NonFatal
import org.apache.spark._
@@ -76,7 +77,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
(deserializedResult, size)
}
- result.metrics.resultSize = size
+ result.metrics.setResultSize(size)
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index cd3c015321e85..79f84e70df9d5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -31,6 +31,7 @@ import scala.util.Random
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+import org.apache.spark.scheduler.TaskLocality.TaskLocality
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
@@ -167,7 +168,7 @@ private[spark] class TaskSchedulerImpl(
if (!hasLaunchedTask) {
logWarning("Initial job has not accepted any resources; " +
"check your cluster UI to ensure that workers are registered " +
- "and have sufficient memory")
+ "and have sufficient resources")
} else {
this.cancel()
}
@@ -209,6 +210,40 @@ private[spark] class TaskSchedulerImpl(
.format(manager.taskSet.id, manager.parent.name))
}
+ private def resourceOfferSingleTaskSet(
+ taskSet: TaskSetManager,
+ maxLocality: TaskLocality,
+ shuffledOffers: Seq[WorkerOffer],
+ availableCpus: Array[Int],
+ tasks: Seq[ArrayBuffer[TaskDescription]]) : Boolean = {
+ var launchedTask = false
+ for (i <- 0 until shuffledOffers.size) {
+ val execId = shuffledOffers(i).executorId
+ val host = shuffledOffers(i).host
+ if (availableCpus(i) >= CPUS_PER_TASK) {
+ try {
+ for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
+ tasks(i) += task
+ val tid = task.taskId
+ taskIdToTaskSetId(tid) = taskSet.taskSet.id
+ taskIdToExecutorId(tid) = execId
+ executorsByHost(host) += execId
+ availableCpus(i) -= CPUS_PER_TASK
+ assert(availableCpus(i) >= 0)
+ launchedTask = true
+ }
+ } catch {
+ case e: TaskNotSerializableException =>
+ logError(s"Resource offer failed, task set ${taskSet.name} was not serializable")
+ // Do not offer resources for this task, but don't throw an error to allow other
+ // task sets to be submitted.
+ return launchedTask
+ }
+ }
+ }
+ return launchedTask
+ }
+
/**
* Called by cluster manager to offer resources on slaves. We respond by asking our active task
* sets for tasks in order of priority. We fill each node with tasks in a round-robin manner so
@@ -251,23 +286,8 @@ private[spark] class TaskSchedulerImpl(
var launchedTask = false
for (taskSet <- sortedTaskSets; maxLocality <- taskSet.myLocalityLevels) {
do {
- launchedTask = false
- for (i <- 0 until shuffledOffers.size) {
- val execId = shuffledOffers(i).executorId
- val host = shuffledOffers(i).host
- if (availableCpus(i) >= CPUS_PER_TASK) {
- for (task <- taskSet.resourceOffer(execId, host, maxLocality)) {
- tasks(i) += task
- val tid = task.taskId
- taskIdToTaskSetId(tid) = taskSet.taskSet.id
- taskIdToExecutorId(tid) = execId
- executorsByHost(host) += execId
- availableCpus(i) -= CPUS_PER_TASK
- assert(availableCpus(i) >= 0)
- launchedTask = true
- }
- }
- }
+ launchedTask = resourceOfferSingleTaskSet(
+ taskSet, maxLocality, shuffledOffers, availableCpus, tasks)
} while (launchedTask)
}
@@ -341,7 +361,7 @@ private[spark] class TaskSchedulerImpl(
dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
}
- def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long) {
+ def handleTaskGettingResult(taskSetManager: TaskSetManager, tid: Long): Unit = synchronized {
taskSetManager.handleTaskGettingResult(tid)
}
@@ -394,9 +414,6 @@ private[spark] class TaskSchedulerImpl(
taskResultGetter.stop()
}
starvationTimer.cancel()
-
- // sleeping for an arbitrary 1 seconds to ensure that messages are sent out.
- Thread.sleep(1000L)
}
override def defaultParallelism() = backend.defaultParallelism()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index cabdc655f89bf..55024ecd55e61 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -18,12 +18,14 @@
package org.apache.spark.scheduler
import java.io.NotSerializableException
+import java.nio.ByteBuffer
import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
import scala.math.{min, max}
+import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.executor.TaskMetrics
@@ -249,7 +251,7 @@ private[spark] class TaskSetManager(
* This method also cleans up any tasks in the list that have already
* been launched, since we want that to happen lazily.
*/
- private def findTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
+ private def dequeueTaskFromList(execId: String, list: ArrayBuffer[Int]): Option[Int] = {
var indexOffset = list.size
while (indexOffset > 0) {
indexOffset -= 1
@@ -290,7 +292,7 @@ private[spark] class TaskSetManager(
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
- private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+ private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
@@ -366,22 +368,22 @@ private[spark] class TaskSetManager(
*
* @return An option containing (task index within the task set, locality, is speculative?)
*/
- private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value)
+ private def dequeueTask(execId: String, host: String, maxLocality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value, Boolean)] =
{
- for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
+ for (index <- dequeueTaskFromList(execId, getPendingTasksForExecutor(execId))) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
- for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) {
+ for (index <- dequeueTaskFromList(execId, getPendingTasksForHost(host))) {
return Some((index, TaskLocality.NODE_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
// Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
- for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
+ for (index <- dequeueTaskFromList(execId, pendingTasksWithNoPrefs)) {
return Some((index, TaskLocality.PROCESS_LOCAL, false))
}
}
@@ -389,20 +391,20 @@ private[spark] class TaskSetManager(
if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
for {
rack <- sched.getRackForHost(host)
- index <- findTaskFromList(execId, getPendingTasksForRack(rack))
+ index <- dequeueTaskFromList(execId, getPendingTasksForRack(rack))
} {
return Some((index, TaskLocality.RACK_LOCAL, false))
}
}
if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) {
- for (index <- findTaskFromList(execId, allPendingTasks)) {
+ for (index <- dequeueTaskFromList(execId, allPendingTasks)) {
return Some((index, TaskLocality.ANY, false))
}
}
// find a speculative task if all others tasks have been scheduled
- findSpeculativeTask(execId, host, maxLocality).map {
+ dequeueSpeculativeTask(execId, host, maxLocality).map {
case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
}
@@ -417,6 +419,7 @@ private[spark] class TaskSetManager(
* @param host the host Id of the offered resource
* @param maxLocality the maximum locality we want to schedule the tasks at
*/
+ @throws[TaskNotSerializableException]
def resourceOffer(
execId: String,
host: String,
@@ -436,7 +439,7 @@ private[spark] class TaskSetManager(
}
}
- findTask(execId, host, allowedLocality) match {
+ dequeueTask(execId, host, allowedLocality) match {
case Some((index, taskLocality, speculative)) => {
// Found a task; do some bookkeeping and return a task description
val task = tasks(index)
@@ -456,10 +459,17 @@ private[spark] class TaskSetManager(
}
// Serialize and return the task
val startTime = clock.getTime()
- // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here
- // we assume the task can be serialized without exceptions.
- val serializedTask = Task.serializeWithDependencies(
- task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ val serializedTask: ByteBuffer = try {
+ Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
+ } catch {
+ // If the task cannot be serialized, then there's no point to re-attempt the task,
+ // as it will always fail. So just abort the whole task-set.
+ case NonFatal(e) =>
+ val msg = s"Failed to serialize task $taskId, not attempting to retry it."
+ logError(msg, e)
+ abort(s"$msg Exception during serialization: $e")
+ throw new TaskNotSerializableException(e)
+ }
if (serializedTask.limit > TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024 &&
!emittedTaskSizeWarning) {
emittedTaskSizeWarning = true
@@ -477,7 +487,8 @@ private[spark] class TaskSetManager(
taskName, taskId, host, taskLocality, serializedTask.limit))
sched.dagScheduler.taskStarted(task, info)
- return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
+ return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId,
+ taskName, index, serializedTask))
}
case _ =>
}
@@ -495,13 +506,64 @@ private[spark] class TaskSetManager(
* Get the level we can launch tasks according to delay scheduling, based on current wait time.
*/
private def getAllowedLocalityLevel(curTime: Long): TaskLocality.TaskLocality = {
- while (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex) &&
- currentLocalityIndex < myLocalityLevels.length - 1)
- {
- // Jump to the next locality level, and remove our waiting time for the current one since
- // we don't want to count it again on the next one
- lastLaunchTime += localityWaits(currentLocalityIndex)
- currentLocalityIndex += 1
+ // Remove the scheduled or finished tasks lazily
+ def tasksNeedToBeScheduledFrom(pendingTaskIds: ArrayBuffer[Int]): Boolean = {
+ var indexOffset = pendingTaskIds.size
+ while (indexOffset > 0) {
+ indexOffset -= 1
+ val index = pendingTaskIds(indexOffset)
+ if (copiesRunning(index) == 0 && !successful(index)) {
+ return true
+ } else {
+ pendingTaskIds.remove(indexOffset)
+ }
+ }
+ false
+ }
+ // Walk through the list of tasks that can be scheduled at each location and returns true
+ // if there are any tasks that still need to be scheduled. Lazily cleans up tasks that have
+ // already been scheduled.
+ def moreTasksToRunIn(pendingTasks: HashMap[String, ArrayBuffer[Int]]): Boolean = {
+ val emptyKeys = new ArrayBuffer[String]
+ val hasTasks = pendingTasks.exists {
+ case (id: String, tasks: ArrayBuffer[Int]) =>
+ if (tasksNeedToBeScheduledFrom(tasks)) {
+ true
+ } else {
+ emptyKeys += id
+ false
+ }
+ }
+ // The key could be executorId, host or rackId
+ emptyKeys.foreach(id => pendingTasks.remove(id))
+ hasTasks
+ }
+
+ while (currentLocalityIndex < myLocalityLevels.length - 1) {
+ val moreTasks = myLocalityLevels(currentLocalityIndex) match {
+ case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor)
+ case TaskLocality.NODE_LOCAL => moreTasksToRunIn(pendingTasksForHost)
+ case TaskLocality.NO_PREF => pendingTasksWithNoPrefs.nonEmpty
+ case TaskLocality.RACK_LOCAL => moreTasksToRunIn(pendingTasksForRack)
+ }
+ if (!moreTasks) {
+ // This is a performance optimization: if there are no more tasks that can
+ // be scheduled at a particular locality level, there is no point in waiting
+ // for the locality wait timeout (SPARK-4939).
+ lastLaunchTime = curTime
+ logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " +
+ s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}")
+ currentLocalityIndex += 1
+ } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) {
+ // Jump to the next locality level, and reset lastLaunchTime so that the next locality
+ // wait timer doesn't immediately expire
+ lastLaunchTime += localityWaits(currentLocalityIndex)
+ currentLocalityIndex += 1
+ logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex)} after waiting for " +
+ s"${localityWaits(currentLocalityIndex)}ms")
+ } else {
+ return myLocalityLevels(currentLocalityIndex)
+ }
}
myLocalityLevels(currentLocalityIndex)
}
@@ -531,7 +593,7 @@ private[spark] class TaskSetManager(
/**
* Check whether has enough quota to fetch the result with `size` bytes
*/
- def canFetchMoreResults(size: Long): Boolean = synchronized {
+ def canFetchMoreResults(size: Long): Boolean = sched.synchronized {
totalResultSize += size
calculatedTasks += 1
if (maxResultSize > 0 && totalResultSize > maxResultSize) {
@@ -660,7 +722,7 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}
- def abort(message: String) {
+ def abort(message: String): Unit = sched.synchronized {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
isZombie = true
@@ -704,7 +766,7 @@ private[spark] class TaskSetManager(
// Re-enqueue pending tasks for this host based on the status of the cluster. Note
// that it's okay if we add a task to the same queue twice (if it had multiple preferred
- // locations), because findTaskFromList will skip already-running tasks.
+ // locations), because dequeueTaskFromList will skip already-running tasks.
for (index <- getPendingTasksForExecutor(execId)) {
addPendingTask(index, readding=true)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 1da6fe976da5b..9bf74f4be198d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -39,7 +39,11 @@ private[spark] object CoarseGrainedClusterMessages {
case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage
// Executors to driver
- case class RegisterExecutor(executorId: String, hostPort: String, cores: Int)
+ case class RegisterExecutor(
+ executorId: String,
+ hostPort: String,
+ cores: Int,
+ logUrls: Map[String, String])
extends CoarseGrainedClusterMessage {
Utils.checkHostPort(hostPort, "Expected host port")
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 88b196ac64368..9d2fb4f3b4729 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -27,8 +27,8 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
-import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
+import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
+import org.apache.spark.scheduler._
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
@@ -42,7 +42,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut
*/
private[spark]
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
- extends SchedulerBackend with Logging
+ extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
@@ -66,6 +66,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
// Number of executors requested from the cluster manager that have not registered yet
private var numPendingExecutors = 0
+ private val listenerBus = scheduler.sc.listenerBus
+
// Executors we have requested the cluster manager to kill that have not died yet
private val executorsPendingToRemove = new HashSet[String]
@@ -84,7 +86,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
}
def receiveWithLogging = {
- case RegisterExecutor(executorId, hostPort, cores) =>
+ case RegisterExecutor(executorId, hostPort, cores, logUrls) =>
Utils.checkHostPort(hostPort, "Host port expected " + hostPort)
if (executorDataMap.contains(executorId)) {
sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId)
@@ -96,7 +98,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
val (host, _) = Utils.parseHostPort(hostPort)
- val data = new ExecutorData(sender, sender.path.address, host, cores, cores)
+ val data = new ExecutorData(sender, sender.path.address, host, cores, cores, logUrls)
// This must be synchronized because variables mutated
// in this block are read when requesting executors
CoarseGrainedSchedulerBackend.this.synchronized {
@@ -106,6 +108,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
}
}
+ listenerBus.post(
+ SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data))
makeOffers()
}
@@ -213,6 +217,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
totalCoreCount.addAndGet(-executorInfo.totalCores)
totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
+ listenerBus.post(
+ SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason))
case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
@@ -307,7 +313,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* Request an additional number of executors from the cluster manager.
* Return whether the request is acknowledged.
*/
- final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
logDebug(s"Number of pending executors is now $numPendingExecutors")
numPendingExecutors += numAdditionalExecutors
@@ -334,7 +340,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* Request that the cluster manager kill the specified executors.
* Return whether the kill request is acknowledged.
*/
- final def killExecutors(executorIds: Seq[String]): Boolean = {
+ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
val filteredExecutorIds = new ArrayBuffer[String]
executorIds.foreach { id =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
index b71bd5783d6df..5e571efe76720 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala
@@ -31,7 +31,8 @@ import akka.actor.{Address, ActorRef}
private[cluster] class ExecutorData(
val executorActor: ActorRef,
val executorAddress: Address,
- val executorHost: String ,
+ override val executorHost: String,
var freeCores: Int,
- val totalCores: Int
-)
+ override val totalCores: Int,
+ override val logUrlMap: Map[String, String]
+) extends ExecutorInfo(executorHost, totalCores, logUrlMap)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
new file mode 100644
index 0000000000000..7f218566146a1
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorInfo.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Stores information about an executor to pass from the scheduler to SparkListeners.
+ */
+@DeveloperApi
+class ExecutorInfo(
+ val executorHost: String,
+ val totalCores: Int,
+ val logUrlMap: Map[String, String]) {
+
+ def canEqual(other: Any): Boolean = other.isInstanceOf[ExecutorInfo]
+
+ override def equals(other: Any): Boolean = other match {
+ case that: ExecutorInfo =>
+ (that canEqual this) &&
+ executorHost == that.executorHost &&
+ totalCores == that.totalCores &&
+ logUrlMap == that.logUrlMap
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ val state = Seq(executorHost, totalCores, logUrlMap)
+ state.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index ee10aa061f4e9..06786a59524e7 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -22,6 +22,7 @@ import org.apache.hadoop.fs.{Path, FileSystem}
import org.apache.spark.{Logging, SparkContext, SparkEnv}
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.util.AkkaUtils
private[spark] class SimrSchedulerBackend(
scheduler: TaskSchedulerImpl,
@@ -38,7 +39,8 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
sc.conf.get("spark.driver.host"),
sc.conf.get("spark.driver.port"),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index 8c7de75600b5f..d2e1680a5fd1b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -21,7 +21,7 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
import org.apache.spark.scheduler.{ExecutorExited, ExecutorLossReason, SlaveLost, TaskSchedulerImpl}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{AkkaUtils, Utils}
private[spark] class SparkDeploySchedulerBackend(
scheduler: TaskSchedulerImpl,
@@ -46,7 +46,8 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
@@ -55,19 +56,26 @@ private[spark] class SparkDeploySchedulerBackend(
"{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
- val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath").toSeq.flatMap { cp =>
- cp.split(java.io.File.pathSeparator)
- }
- val libraryPathEntries =
- sc.conf.getOption("spark.executor.extraLibraryPath").toSeq.flatMap { cp =>
- cp.split(java.io.File.pathSeparator)
+ val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+ val libraryPathEntries = sc.conf.getOption("spark.executor.extraLibraryPath")
+ .map(_.split(java.io.File.pathSeparator).toSeq).getOrElse(Nil)
+
+ // When testing, expose the parent class path to the child. This is processed by
+ // compute-classpath.{cmd,sh} and makes all needed jars available to child processes
+ // when the assembly is built with the "*-provided" profiles enabled.
+ val testingClassPath =
+ if (sys.props.contains("spark.testing")) {
+ sys.props("java.class.path").split(java.io.File.pathSeparator).toSeq
+ } else {
+ Nil
}
// Start executors with a few necessary configs for registering with the scheduler
val sparkJavaOpts = Utils.sparkJavaOpts(conf, SparkConf.isExecutorStartupConf)
val javaOpts = sparkJavaOpts ++ extraJavaOpts
val command = Command("org.apache.spark.executor.CoarseGrainedExecutorBackend",
- args, sc.executorEnvs, classPathEntries, libraryPathEntries, javaOpts)
+ args, sc.executorEnvs, classPathEntries ++ testingClassPath, libraryPathEntries, javaOpts)
val appUIAddress = sc.ui.map(_.appUIAddress).getOrElse("")
val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command,
appUIAddress, sc.eventLogDir)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 50721b9d6cd6c..f14aaeea0a25c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.cluster
+import scala.concurrent.{Future, ExecutionContext}
+
import akka.actor.{Actor, ActorRef, Props}
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
@@ -24,7 +26,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.ui.JettyUtils
-import org.apache.spark.util.AkkaUtils
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+import scala.util.control.NonFatal
/**
* Abstract Yarn scheduler backend that contains common logic
@@ -97,6 +101,9 @@ private[spark] abstract class YarnSchedulerBackend(
private class YarnSchedulerActor extends Actor {
private var amActor: Option[ActorRef] = None
+ implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+
override def preStart(): Unit = {
// Listen for disassociation events
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -110,7 +117,12 @@ private[spark] abstract class YarnSchedulerBackend(
case r: RequestExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to request executors before the AM has registered!")
sender ! false
@@ -119,7 +131,12 @@ private[spark] abstract class YarnSchedulerBackend(
case k: KillExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to kill executors before the AM has registered!")
sender ! false
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 5289661eb896b..0d1c2a916ca7f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -31,7 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas
import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Utils, AkkaUtils}
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -143,7 +143,8 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format(
+ val driverUrl = AkkaUtils.address(
+ AkkaUtils.protocol(sc.env.actorSystem),
SparkEnv.driverActorSystemName,
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 10e6886c16a4f..cfb6592e14aa8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -22,14 +22,17 @@ import java.util.{ArrayList => JArrayList, List => JList}
import java.util.Collections
import scala.collection.JavaConversions._
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet}
import org.apache.mesos.protobuf.ByteString
import org.apache.mesos.{Scheduler => MScheduler}
import org.apache.mesos._
-import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState, _}
+import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTaskState,
+ ExecutorInfo => MesosExecutorInfo, _}
+import org.apache.spark.executor.MesosExecutorBackend
import org.apache.spark.{Logging, SparkContext, SparkException, TaskState}
+import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler._
import org.apache.spark.util.Utils
@@ -62,6 +65,9 @@ private[spark] class MesosSchedulerBackend(
var classLoader: ClassLoader = null
+ // The listener bus to publish executor added/removed events.
+ val listenerBus = sc.listenerBus
+
@volatile var appId: String = _
override def start() {
@@ -87,7 +93,7 @@ private[spark] class MesosSchedulerBackend(
}
}
- def createExecutorInfo(execId: String): ExecutorInfo = {
+ def createExecutorInfo(execId: String): MesosExecutorInfo = {
val executorSparkHome = sc.conf.getOption("spark.mesos.executor.home")
.orElse(sc.getSparkHome()) // Fall back to driver Spark home for backward compatibility
.getOrElse {
@@ -118,14 +124,15 @@ private[spark] class MesosSchedulerBackend(
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
+ val executorBackendName = classOf[MesosExecutorBackend].getName
if (uri == null) {
- val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath
- command.setValue("%s %s".format(prefixEnv, executorPath))
+ val executorPath = new File(executorSparkHome, "/bin/spark-class").getCanonicalPath
+ command.setValue(s"$prefixEnv $executorPath $executorBackendName")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv))
+ command.setValue(s"cd ${basename}*; $prefixEnv ./bin/spark-class $executorBackendName")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val cpus = Resource.newBuilder()
@@ -141,7 +148,7 @@ private[spark] class MesosSchedulerBackend(
Value.Scalar.newBuilder()
.setValue(MemoryUtils.calculateTotalMemory(sc)).build())
.build()
- ExecutorInfo.newBuilder()
+ MesosExecutorInfo.newBuilder()
.setExecutorId(ExecutorID.newBuilder().setValue(execId).build())
.setCommand(command)
.setData(ByteString.copyFrom(createExecArg()))
@@ -237,6 +244,7 @@ private[spark] class MesosSchedulerBackend(
}
val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
+ val slaveIdToWorkerOffer = workerOffers.map(o => o.executorId -> o).toMap
val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
@@ -260,6 +268,11 @@ private[spark] class MesosSchedulerBackend(
val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
mesosTasks.foreach { case (slaveId, tasks) =>
+ slaveIdToWorkerOffer.get(slaveId).foreach(o =>
+ listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId,
+ // TODO: Add support for log urls for Mesos
+ new ExecutorInfo(o.host, o.cores, Map.empty)))
+ )
d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
@@ -296,7 +309,7 @@ private[spark] class MesosSchedulerBackend(
.setExecutor(createExecutorInfo(slaveId))
.setName(task.name)
.addResources(cpuResource)
- .setData(ByteString.copyFrom(task.serializedTask))
+ .setData(MesosTaskLaunchData(task.serializedTask, task.attemptNumber).toByteString)
.build()
}
@@ -315,7 +328,7 @@ private[spark] class MesosSchedulerBackend(
synchronized {
if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) {
// We lost the executor on this slave, so remember that it's gone
- slaveIdsWithExecutors -= taskIdToSlaveId(tid)
+ removeExecutor(taskIdToSlaveId(tid), "Lost executor")
}
if (isFinished(status.getState)) {
taskIdToSlaveId.remove(tid)
@@ -344,12 +357,20 @@ private[spark] class MesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
+ /**
+ * Remove executor associated with slaveId in a thread safe manner.
+ */
+ private def removeExecutor(slaveId: String, reason: String) = {
+ synchronized {
+ listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason))
+ slaveIdsWithExecutors -= slaveId
+ }
+ }
+
private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
inClassLoader() {
logInfo("Mesos slave lost: " + slaveId.getValue)
- synchronized {
- slaveIdsWithExecutors -= slaveId.getValue
- }
+ removeExecutor(slaveId.getValue, reason.toString)
scheduler.executorLost(slaveId.getValue, reason)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
new file mode 100644
index 0000000000000..5e7e6567a3e06
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosTaskLaunchData.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster.mesos
+
+import java.nio.ByteBuffer
+
+import org.apache.mesos.protobuf.ByteString
+
+import org.apache.spark.Logging
+
+/**
+ * Wrapper for serializing the data sent when launching Mesos tasks.
+ */
+private[spark] case class MesosTaskLaunchData(
+ serializedTask: ByteBuffer,
+ attemptNumber: Int) extends Logging {
+
+ def toByteString: ByteString = {
+ val dataBuffer = ByteBuffer.allocate(4 + serializedTask.limit)
+ dataBuffer.putInt(attemptNumber)
+ dataBuffer.put(serializedTask)
+ dataBuffer.rewind
+ logDebug(s"ByteBuffer size: [${dataBuffer.remaining}]")
+ ByteString.copyFrom(dataBuffer)
+ }
+}
+
+private[spark] object MesosTaskLaunchData extends Logging {
+ def fromByteString(byteString: ByteString): MesosTaskLaunchData = {
+ val byteBuffer = byteString.asReadOnlyByteBuffer()
+ logDebug(s"ByteBuffer size: [${byteBuffer.remaining}]")
+ val attemptNumber = byteBuffer.getInt // updates the position by 4 bytes
+ val serializedTask = byteBuffer.slice() // subsequence starting at the current position
+ MesosTaskLaunchData(serializedTask, attemptNumber)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index a2f1f14264a99..4676b828d3d89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -19,6 +19,8 @@ package org.apache.spark.scheduler.local
import java.nio.ByteBuffer
+import scala.concurrent.duration._
+
import akka.actor.{Actor, ActorRef, Props}
import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
@@ -41,17 +43,20 @@ private case class StopExecutor()
* and the TaskSchedulerImpl.
*/
private[spark] class LocalActor(
- scheduler: TaskSchedulerImpl,
- executorBackend: LocalBackend,
- private val totalCores: Int) extends Actor with ActorLogReceive with Logging {
+ scheduler: TaskSchedulerImpl,
+ executorBackend: LocalBackend,
+ private val totalCores: Int)
+ extends Actor with ActorLogReceive with Logging {
+
+ import context.dispatcher // to use Akka's scheduler.scheduleOnce()
private var freeCores = totalCores
private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
private val localExecutorHostname = "localhost"
- val executor = new Executor(
- localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true)
+ private val executor = new Executor(
+ localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true)
override def receiveWithLogging = {
case ReviveOffers =>
@@ -73,9 +78,15 @@ private[spark] class LocalActor(
def reviveOffers() {
val offers = Seq(new WorkerOffer(localExecutorId, localExecutorHostname, freeCores))
- for (task <- scheduler.resourceOffers(offers).flatten) {
+ val tasks = scheduler.resourceOffers(offers).flatten
+ for (task <- tasks) {
freeCores -= scheduler.CPUS_PER_TASK
- executor.launchTask(executorBackend, task.taskId, task.name, task.serializedTask)
+ executor.launchTask(executorBackend, taskId = task.taskId, attemptNumber = task.attemptNumber,
+ task.name, task.serializedTask)
+ }
+ if (tasks.isEmpty && scheduler.activeTaskSets.nonEmpty) {
+ // Try to reviveOffer after 1 second, because scheduler may wait for locality timeout
+ context.system.scheduler.scheduleOnce(1000 millis, self, ReviveOffers)
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 662a7b91248aa..1baa0e009f3ae 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -27,7 +27,8 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
import org.apache.spark.util.Utils
-private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int)
+private[spark] class JavaSerializationStream(
+ out: OutputStream, counterReset: Int, extraDebugInfo: Boolean)
extends SerializationStream {
private val objOut = new ObjectOutputStream(out)
private var counter = 0
@@ -39,7 +40,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
* the stream 'resets' object class descriptions have to be re-written)
*/
def writeObject[T: ClassTag](t: T): SerializationStream = {
- objOut.writeObject(t)
+ try {
+ objOut.writeObject(t)
+ } catch {
+ case e: NotSerializableException if extraDebugInfo =>
+ throw SerializationDebugger.improveException(t, e)
+ }
counter += 1
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
@@ -64,7 +70,8 @@ extends DeserializationStream {
}
-private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader)
+private[spark] class JavaSerializerInstance(
+ counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader)
extends SerializerInstance {
override def serialize[T: ClassTag](t: T): ByteBuffer = {
@@ -88,11 +95,11 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
}
override def serializeStream(s: OutputStream): SerializationStream = {
- new JavaSerializationStream(s, counterReset)
+ new JavaSerializationStream(s, counterReset, extraDebugInfo)
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
+ new JavaDeserializationStream(s, defaultClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
@@ -111,17 +118,20 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
@DeveloperApi
class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100)
+ private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true)
override def newInstance(): SerializerInstance = {
val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader)
- new JavaSerializerInstance(counterReset, classLoader)
+ new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader)
}
override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(counterReset)
+ out.writeBoolean(extraDebugInfo)
}
override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
counterReset = in.readInt()
+ extraDebugInfo = in.readBoolean()
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 621a951c27d07..02158aa0f866e 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,10 @@ import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializ
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.spark._
+import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
@@ -57,14 +58,6 @@ class KryoSerializer(conf: SparkConf)
private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
.split(',')
.filter(!_.isEmpty)
- .map { className =>
- try {
- Class.forName(className)
- } catch {
- case e: Exception =>
- throw new SparkException("Failed to load class to register with Kryo", e)
- }
- }
def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
@@ -90,12 +83,14 @@ class KryoSerializer(conf: SparkConf)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
+ kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
try {
// Use the default classloader when calling the user registrator.
Thread.currentThread.setContextClassLoader(classLoader)
// Register classes given through spark.kryo.classesToRegister.
- classesToRegister.foreach { clazz => kryo.register(clazz) }
+ classesToRegister
+ .foreach { className => kryo.register(Class.forName(className, true, classLoader)) }
// Allow the user to register their own classes by setting spark.kryo.registrator.
userRegistrator
.map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
@@ -205,7 +200,8 @@ private[serializer] object KryoSerializer {
classOf[PutBlock],
classOf[GotBlock],
classOf[GetBlock],
- classOf[MapStatus],
+ classOf[CompressedMapStatus],
+ classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Byte]],
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
new file mode 100644
index 0000000000000..cecb992579655
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField}
+import java.lang.reflect.{Field, Method}
+import java.security.AccessController
+
+import scala.annotation.tailrec
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+
+private[serializer] object SerializationDebugger extends Logging {
+
+ /**
+ * Improve the given NotSerializableException with the serialization path leading from the given
+ * object to the problematic object. This is turned off automatically if
+ * `sun.io.serialization.extendedDebugInfo` flag is turned on for the JVM.
+ */
+ def improveException(obj: Any, e: NotSerializableException): NotSerializableException = {
+ if (enableDebugging && reflect != null) {
+ new NotSerializableException(
+ e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n"))
+ } else {
+ e
+ }
+ }
+
+ /**
+ * Find the path leading to a not serializable object. This method is modeled after OpenJDK's
+ * serialization mechanism, and handles the following cases:
+ * - primitives
+ * - arrays of primitives
+ * - arrays of non-primitive objects
+ * - Serializable objects
+ * - Externalizable objects
+ * - writeReplace
+ *
+ * It does not yet handle writeObject override, but that shouldn't be too hard to do either.
+ */
+ def find(obj: Any): List[String] = {
+ new SerializationDebugger().visit(obj, List.empty)
+ }
+
+ private[serializer] var enableDebugging: Boolean = {
+ !AccessController.doPrivileged(new sun.security.action.GetBooleanAction(
+ "sun.io.serialization.extendedDebugInfo")).booleanValue()
+ }
+
+ private class SerializationDebugger {
+
+ /** A set to track the list of objects we have visited, to avoid cycles in the graph. */
+ private val visited = new mutable.HashSet[Any]
+
+ /**
+ * Visit the object and its fields and stop when we find an object that is not serializable.
+ * Return the path as a list. If everything can be serialized, return an empty list.
+ */
+ def visit(o: Any, stack: List[String]): List[String] = {
+ if (o == null) {
+ List.empty
+ } else if (visited.contains(o)) {
+ List.empty
+ } else {
+ visited += o
+ o match {
+ // Primitive value, string, and primitive arrays are always serializable
+ case _ if o.getClass.isPrimitive => List.empty
+ case _: String => List.empty
+ case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty
+
+ // Traverse non primitive array.
+ case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive =>
+ val elem = s"array (class ${a.getClass.getName}, size ${a.length})"
+ visitArray(o.asInstanceOf[Array[_]], elem :: stack)
+
+ case e: java.io.Externalizable =>
+ val elem = s"externalizable object (class ${e.getClass.getName}, $e)"
+ visitExternalizable(e, elem :: stack)
+
+ case s: Object with java.io.Serializable =>
+ val elem = s"object (class ${s.getClass.getName}, $s)"
+ visitSerializable(s, elem :: stack)
+
+ case _ =>
+ // Found an object that is not serializable!
+ s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack
+ }
+ }
+ }
+
+ private def visitArray(o: Array[_], stack: List[String]): List[String] = {
+ var i = 0
+ while (i < o.length) {
+ val childStack = visit(o(i), s"element of array (index: $i)" :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ i += 1
+ }
+ return List.empty
+ }
+
+ private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] =
+ {
+ val fieldList = new ListObjectOutput
+ o.writeExternal(fieldList)
+ val childObjects = fieldList.outputArray
+ var i = 0
+ while (i < childObjects.length) {
+ val childStack = visit(childObjects(i), "writeExternal data" :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ i += 1
+ }
+ return List.empty
+ }
+
+ private def visitSerializable(o: Object, stack: List[String]): List[String] = {
+ // An object contains multiple slots in serialization.
+ // Get the slots and visit fields in all of them.
+ val (finalObj, desc) = findObjectAndDescriptor(o)
+ val slotDescs = desc.getSlotDescs
+ var i = 0
+ while (i < slotDescs.length) {
+ val slotDesc = slotDescs(i)
+ if (slotDesc.hasWriteObjectMethod) {
+ // TODO: Handle classes that specify writeObject method.
+ } else {
+ val fields: Array[ObjectStreamField] = slotDesc.getFields
+ val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields)
+ val numPrims = fields.length - objFieldValues.length
+ desc.getObjFieldValues(finalObj, objFieldValues)
+
+ var j = 0
+ while (j < objFieldValues.length) {
+ val fieldDesc = fields(numPrims + j)
+ val elem = s"field (class: ${slotDesc.getName}" +
+ s", name: ${fieldDesc.getName}" +
+ s", type: ${fieldDesc.getType})"
+ val childStack = visit(objFieldValues(j), elem :: stack)
+ if (childStack.nonEmpty) {
+ return childStack
+ }
+ j += 1
+ }
+
+ }
+ i += 1
+ }
+ return List.empty
+ }
+ }
+
+ /**
+ * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles
+ * writeReplace in Serializable. It starts with the object itself, and keeps calling the
+ * writeReplace method until there is no more
+ */
+ @tailrec
+ private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = {
+ val cl = o.getClass
+ val desc = ObjectStreamClass.lookupAny(cl)
+ if (!desc.hasWriteReplaceMethod) {
+ (o, desc)
+ } else {
+ // write place
+ findObjectAndDescriptor(desc.invokeWriteReplace(o))
+ }
+ }
+
+ /**
+ * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal
+ * call, and returns them through `outputArray`.
+ */
+ private class ListObjectOutput extends ObjectOutput {
+ private val output = new mutable.ArrayBuffer[Any]
+ def outputArray: Array[Any] = output.toArray
+ override def writeObject(o: Any): Unit = output += o
+ override def flush(): Unit = {}
+ override def write(i: Int): Unit = {}
+ override def write(bytes: Array[Byte]): Unit = {}
+ override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {}
+ override def close(): Unit = {}
+ override def writeFloat(v: Float): Unit = {}
+ override def writeChars(s: String): Unit = {}
+ override def writeDouble(v: Double): Unit = {}
+ override def writeUTF(s: String): Unit = {}
+ override def writeShort(i: Int): Unit = {}
+ override def writeInt(i: Int): Unit = {}
+ override def writeBoolean(b: Boolean): Unit = {}
+ override def writeBytes(s: String): Unit = {}
+ override def writeChar(i: Int): Unit = {}
+ override def writeLong(l: Long): Unit = {}
+ override def writeByte(i: Int): Unit = {}
+ }
+
+ /** An implicit class that allows us to call private methods of ObjectStreamClass. */
+ implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal {
+ def getSlotDescs: Array[ObjectStreamClass] = {
+ reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map {
+ classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass]
+ }
+ }
+
+ def hasWriteObjectMethod: Boolean = {
+ reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean]
+ }
+
+ def hasWriteReplaceMethod: Boolean = {
+ reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean]
+ }
+
+ def invokeWriteReplace(obj: Object): Object = {
+ reflect.InvokeWriteReplace.invoke(desc, obj)
+ }
+
+ def getNumObjFields: Int = {
+ reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int]
+ }
+
+ def getObjFieldValues(obj: Object, out: Array[Object]): Unit = {
+ reflect.GetObjFieldValues.invoke(desc, obj, out)
+ }
+ }
+
+ /**
+ * Object to hold all the reflection objects. If we run on a JVM that we cannot understand,
+ * this field will be null and this the debug helper should be disabled.
+ */
+ private val reflect: ObjectStreamClassReflection = try {
+ new ObjectStreamClassReflection
+ } catch {
+ case e: Exception =>
+ logWarning("Cannot find private methods using reflection", e)
+ null
+ }
+
+ private class ObjectStreamClassReflection {
+ /** ObjectStreamClass.getClassDataLayout */
+ val GetClassDataLayout: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.hasWriteObjectMethod */
+ val HasWriteObjectMethod: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.hasWriteReplaceMethod */
+ val HasWriteReplaceMethod: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.invokeWriteReplace */
+ val InvokeWriteReplace: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object])
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.getNumObjFields */
+ val GetNumObjFields: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields")
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass.getObjFieldValues */
+ val GetObjFieldValues: Method = {
+ val f = classOf[ObjectStreamClass].getDeclaredMethod(
+ "getObjFieldValues", classOf[Object], classOf[Array[Object]])
+ f.setAccessible(true)
+ f
+ }
+
+ /** ObjectStreamClass$ClassDataSlot.desc field */
+ val DescField: Field = {
+ val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc")
+ f.setAccessible(true)
+ f
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 801ae54086053..a44a8e1249256 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -20,8 +20,8 @@ package org.apache.spark.shuffle
import org.apache.spark.{TaskContext, ShuffleDependency}
/**
- * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
- * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
+ * and on each executor, based on the spark.shuffle.manager setting. The driver registers shuffles
* with it, and executors (or tasks running locally in the driver) can ask to read and write data.
*
* NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index e3e7434df45b0..7a2c5ae32d98b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -86,6 +86,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
context.taskMetrics.updateShuffleReadMetrics()
})
- new InterruptibleIterator[T](context, completionIter)
+ new InterruptibleIterator[T](context, completionIter) {
+ val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ override def next(): T = {
+ readMetrics.incRecordsRead(1)
+ delegate.next()
+ }
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 5baf45db45c17..41bafabde05b9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -45,9 +45,9 @@ private[spark] class HashShuffleReader[K, C](
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
+
// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
@@ -59,8 +59,8 @@ private[spark] class HashShuffleReader[K, C](
// the ExternalSorter won't spill to disk.
val sorter = new ExternalSorter[K, C, C](ordering = Some(keyOrd), serializer = Some(ser))
sorter.insertAll(aggregatedIter)
- context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
- context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
+ context.taskMetrics.incMemoryBytesSpilled(sorter.memoryBytesSpilled)
+ context.taskMetrics.incDiskBytesSpilled(sorter.diskBytesSpilled)
sorter.iterator
case None =>
aggregatedIter
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 183a30373b28c..755f17d6aa15a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -56,9 +56,8 @@ private[spark] class HashShuffleWriter[K, V](
} else {
records
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
records
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index d75f9d7311fad..27496c5a289cb 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -50,9 +50,7 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
- if (!dep.aggregator.isDefined) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
- }
+ require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
sorter.insertAll(records)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 308c59eda594d..86dbd89f0ffb8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -34,10 +34,9 @@ import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
-import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
-import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
@@ -54,7 +53,7 @@ private[spark] class BlockResult(
readMethod: DataReadMethod.Value,
bytes: Long) {
val inputMetrics = new InputMetrics(readMethod)
- inputMetrics.bytesRead = bytes
+ inputMetrics.incBytesRead(bytes)
}
/**
@@ -120,7 +119,7 @@ private[spark] class BlockManager(
private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
- // standard BlockTranserService to directly connect to other Executors.
+ // standard BlockTransferService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
@@ -1014,8 +1013,10 @@ private[spark] class BlockManager(
// If we get here, the block write failed.
logWarning(s"Block $blockId was marked as failure. Nothing to drop")
return None
+ } else if (blockInfo.get(blockId).isEmpty) {
+ logWarning(s"Block $blockId was already dropped.")
+ return None
}
-
var blockIsUpdated = false
val level = info.level
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 9cbda41223a8b..64133464d8daa 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -52,8 +52,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs",
- math.max(conf.getInt("spark.executor.heartbeatInterval", 10000) * 3, 45000))
+ val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000)
val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 9c469370ffe1f..81164178b9e8e 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -29,7 +29,8 @@ import org.apache.spark.executor.ShuffleWriteMetrics
* appending data to an existing block, and can guarantee atomicity in the case of faults
* as it allows the caller to revert partial writes.
*
- * This interface does not support concurrent writes.
+ * This interface does not support concurrent writes. Also, once the writer has
+ * been opened, it cannot be reopened again.
*/
private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
@@ -95,6 +96,7 @@ private[spark] class DiskBlockObjectWriter(
private var ts: TimeTrackingOutputStream = null
private var objOut: SerializationStream = null
private var initialized = false
+ private var hasBeenClosed = false
/**
* Cursors used to represent positions in the file.
@@ -115,11 +117,16 @@ private[spark] class DiskBlockObjectWriter(
private var finalPosition: Long = -1
private var reportedPosition = initialPosition
- /** Calling channel.position() to update the write metrics can be a little bit expensive, so we
- * only call it every N writes */
- private var writesSinceMetricsUpdate = 0
+ /**
+ * Keep track of number of records written and also use this to periodically
+ * output bytes written since the latter is expensive to do for each record.
+ */
+ private var numRecordsWritten = 0
override def open(): BlockObjectWriter = {
+ if (hasBeenClosed) {
+ throw new IllegalStateException("Writer already closed. Cannot be reopened.")
+ }
fos = new FileOutputStream(file, true)
ts = new TimeTrackingOutputStream(fos)
channel = fos.getChannel()
@@ -145,6 +152,7 @@ private[spark] class DiskBlockObjectWriter(
ts = null
objOut = null
initialized = false
+ hasBeenClosed = true
}
}
@@ -160,14 +168,15 @@ private[spark] class DiskBlockObjectWriter(
}
finalPosition = file.length()
// In certain compression codecs, more bytes are written after close() is called
- writeMetrics.shuffleBytesWritten += (finalPosition - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(finalPosition - reportedPosition)
}
// Discard current writes. We do this by flushing the outstanding writes and then
// truncating the file to its initial position.
override def revertPartialWritesAndClose() {
try {
- writeMetrics.shuffleBytesWritten -= (reportedPosition - initialPosition)
+ writeMetrics.decShuffleBytesWritten(reportedPosition - initialPosition)
+ writeMetrics.decShuffleRecordsWritten(numRecordsWritten)
if (initialized) {
objOut.flush()
@@ -193,12 +202,11 @@ private[spark] class DiskBlockObjectWriter(
}
objOut.writeObject(value)
+ numRecordsWritten += 1
+ writeMetrics.incShuffleRecordsWritten(1)
- if (writesSinceMetricsUpdate == 32) {
- writesSinceMetricsUpdate = 0
+ if (numRecordsWritten % 32 == 0) {
updateBytesWritten()
- } else {
- writesSinceMetricsUpdate += 1
}
}
@@ -212,14 +220,14 @@ private[spark] class DiskBlockObjectWriter(
*/
private def updateBytesWritten() {
val pos = channel.position()
- writeMetrics.shuffleBytesWritten += (pos - reportedPosition)
+ writeMetrics.incShuffleBytesWritten(pos - reportedPosition)
reportedPosition = pos
}
private def callWithTiming(f: => Unit) = {
val start = System.nanoTime()
f
- writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+ writeMetrics.incShuffleWriteTime(System.nanoTime() - start)
}
// For testing
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 58fba54710510..53eaedacbf291 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -17,9 +17,8 @@
package org.apache.spark.storage
-import java.io.File
-import java.text.SimpleDateFormat
-import java.util.{Date, Random, UUID}
+import java.util.UUID
+import java.io.{IOException, File}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
@@ -37,7 +36,6 @@ import org.apache.spark.util.Utils
private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
extends Logging {
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private[spark]
val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
@@ -71,7 +69,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
old
} else {
val newDir = new File(localDirs(dirId), "%02x".format(subDirId))
- newDir.mkdir()
+ if (!newDir.exists() && !newDir.mkdir()) {
+ throw new IOException(s"Failed to create local dir in $newDir.")
+ }
subDirs(dirId)(subDirId) = newDir
newDir
}
@@ -121,33 +121,15 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def createLocalDirs(conf: SparkConf): Array[File] = {
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, s"spark-local-$localDirId")
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning(s"Attempt $tries to create local dir $localDir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir." +
- " Ignoring this directory.")
- None
- } else {
+ try {
+ val localDir = Utils.createDirectory(rootDir, "blockmgr")
logInfo(s"Created local directory at $localDir")
Some(localDir)
+ } catch {
+ case e: IOException =>
+ logError(s"Failed to create local dir in $rootDir. Ignoring this directory.", e)
+ None
}
}
}
@@ -164,7 +146,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
// Only perform cleanup if an external service is not serving our shuffle files.
- if (!blockManager.externalShuffleServiceEnabled) {
+ if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) {
localDirs.foreach { localDir =>
if (localDir.isDirectory() && localDir.exists()) {
try {
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index 8dadf6794039e..61ef5ff168791 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -31,7 +31,8 @@ import org.apache.spark.util.Utils
private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBlockManager)
extends BlockStore(blockManager) with Logging {
- val minMemoryMapBytes = blockManager.conf.getLong("spark.storage.memoryMapThreshold", 2 * 4096L)
+ val minMemoryMapBytes = blockManager.conf.getLong(
+ "spark.storage.memoryMapThreshold", 2 * 1024L * 1024L)
override def getSize(blockId: BlockId): Long = {
diskManager.getFile(blockId.name).length
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 2499c11a65b0e..ab9ee4f0096bf 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -156,8 +156,8 @@ final class ShuffleBlockFetcherIterator(
// This needs to be released after use.
buf.retain()
results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
- shuffleMetrics.remoteBytesRead += buf.size
- shuffleMetrics.remoteBlocksFetched += 1
+ shuffleMetrics.incRemoteBytesRead(buf.size)
+ shuffleMetrics.incRemoteBlocksFetched(1)
}
logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
@@ -233,7 +233,7 @@ final class ShuffleBlockFetcherIterator(
val blockId = iter.next()
try {
val buf = blockManager.getBlockData(blockId)
- shuffleMetrics.localBlocksFetched += 1
+ shuffleMetrics.incLocalBlocksFetched(1)
buf.retain()
results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
@@ -277,7 +277,7 @@ final class ShuffleBlockFetcherIterator(
currentResult = results.take()
val result = currentResult
val stopFetchWait = System.currentTimeMillis()
- shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
+ shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait)
result match {
case SuccessFetchResult(_, size, _) => bytesInFlight -= size
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 56edc4fe2e4ad..e5e1cf5a69a19 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -18,6 +18,7 @@
package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
+import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
@@ -220,8 +221,7 @@ object StorageLevel {
getCachedStorageLevel(obj)
}
- private[spark] val storageLevelCache =
- new java.util.concurrent.ConcurrentHashMap[StorageLevel, StorageLevel]()
+ private[spark] val storageLevelCache = new ConcurrentHashMap[StorageLevel, StorageLevel]()
private[spark] def getCachedStorageLevel(level: StorageLevel): StorageLevel = {
storageLevelCache.putIfAbsent(level, level)
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 2a27d49d2de05..bf4b24e98b134 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -62,17 +62,22 @@ private[spark] object JettyUtils extends Logging {
securityMgr: SecurityManager): HttpServlet = {
new HttpServlet {
override def doGet(request: HttpServletRequest, response: HttpServletResponse) {
- if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) {
- response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
- response.setStatus(HttpServletResponse.SC_OK)
- val result = servletParams.responder(request)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.getWriter.println(servletParams.extractFn(result))
- } else {
- response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
- response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
- "User is not authorized to access this page.")
+ try {
+ if (securityMgr.checkUIViewPermissions(request.getRemoteUser)) {
+ response.setContentType("%s;charset=utf-8".format(servletParams.contentType))
+ response.setStatus(HttpServletResponse.SC_OK)
+ val result = servletParams.responder(request)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.getWriter.println(servletParams.extractFn(result))
+ } else {
+ response.setStatus(HttpServletResponse.SC_UNAUTHORIZED)
+ response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
+ response.sendError(HttpServletResponse.SC_UNAUTHORIZED,
+ "User is not authorized to access this page.")
+ }
+ } catch {
+ case e: IllegalArgumentException =>
+ response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage)
}
}
}
@@ -201,7 +206,7 @@ private[spark] object JettyUtils extends Logging {
}
}
- val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName)
+ val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 176907dffa46a..0c24ad2760e08 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -57,8 +57,6 @@ private[spark] class SparkUI private (
attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath))
attachHandler(
createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest))
- // If the UI is live, then serve
- sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) }
}
initialize()
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 6f446c5a95a0a..3a15e603b1969 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,17 +24,20 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
- val TASK_DESERIALIZATION_TIME =
- """Time spent deserializating the task closure on the executor."""
+ val TASK_DESERIALIZATION_TIME = "Time spent deserializing the task closure on the executor."
- val INPUT = "Bytes read from Hadoop or from Spark storage."
+ val SHUFFLE_READ_BLOCKED_TIME =
+ "Time that the task spent blocked waiting for shuffle data to be read from remote machines."
- val OUTPUT = "Bytes written to Hadoop."
+ val INPUT = "Bytes and records read from Hadoop or from Spark storage."
- val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
+ val OUTPUT = "Bytes and records written to Hadoop."
+
+ val SHUFFLE_WRITE =
+ "Bytes and records written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
- """Bytes read from remote executors. Typically less than shuffle write bytes
+ """Bytes and records read from remote executors. Typically less than shuffle write bytes
because this does not include shuffle data read locally."""
val GETTING_RESULT_TIME =
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 315327c3c6b7c..b5022fe853c49 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -181,7 +181,9 @@ private[spark] object UIUtils extends Logging {
}
val helpButton: Seq[Node] = helpText.map { helpText =>
- (?)
+
+ (?)
+
}.getOrElse(Seq.empty)
@@ -192,9 +194,12 @@ private[spark] object UIUtils extends Logging {
-
-
-
+
{shortAppName} application UI
@@ -212,11 +217,6 @@ private[spark] object UIUtils extends Logging {
{content}
-
}
@@ -234,8 +234,9 @@ private[spark] object UIUtils extends Logging {
@@ -243,11 +244,6 @@ private[spark] object UIUtils extends Logging {
{content}
-
|