diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000..2b65f6fe3cc80
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+*.bat text eol=crlf
+*.cmd text eol=crlf
diff --git a/.rat-excludes b/.rat-excludes
index b14ad53720f32..20e3372464386 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -1,5 +1,6 @@
target
.gitignore
+.gitattributes
.project
.classpath
.mima-excludes
@@ -48,6 +49,7 @@ sbt-launch-lib.bash
plugins.sbt
work
.*\.q
+.*\.qv
golden
test.out/*
.*iml
diff --git a/LICENSE b/LICENSE
index a7eee041129cb..f1732fb47afc0 100644
--- a/LICENSE
+++ b/LICENSE
@@ -712,18 +712,6 @@ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-========================================================================
-For colt:
-========================================================================
-
-Copyright (c) 1999 CERN - European Organization for Nuclear Research.
-Permission to use, copy, modify, distribute and sell this software and its documentation for any purpose is hereby granted without fee, provided that the above copyright notice appear in all copies and that both that copyright notice and this permission notice appear in supporting documentation. CERN makes no representations about the suitability of this software for any purpose. It is provided "as is" without expressed or implied warranty.
-
-Packages hep.aida.*
-
-Written by Pavel Binko, Dino Ferrero Merlino, Wolfgang Hoschek, Tony Johnson, Andreas Pfeiffer, and others. Check the FreeHEP home page for more info. Permission to use and/or redistribute this work is granted under the terms of the LGPL License, with the exception that any usage related to military applications is expressly forbidden. The software and documentation made available under the terms of this license are provided with no warranty.
-
-
========================================================================
For SnapTree:
========================================================================
diff --git a/README.md b/README.md
index 8dd8b70696aa2..9916ac7b1ae8e 100644
--- a/README.md
+++ b/README.md
@@ -25,7 +25,7 @@ To build Spark and its example programs, run:
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
-["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
+["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
## Interactive Scala Shell
@@ -84,7 +84,7 @@ storage systems. Because the protocols have changed in different versions of
Hadoop, you must build Spark against the same version that your cluster runs.
Please refer to the build documentation at
-["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-spark.html#specifying-the-hadoop-version)
+["Specifying the Hadoop Version"](http://spark.apache.org/docs/latest/building-with-maven.html#specifying-the-hadoop-version)
for detailed guidance on building for a particular distribution of Hadoop, including
building for particular Hive and Hive Thriftserver distributions. See also
["Third Party Hadoop Distributions"](http://spark.apache.org/docs/latest/hadoop-third-party-distributions.html)
diff --git a/bin/compute-classpath.cmd b/bin/compute-classpath.cmd
index 3cd0579aea8d3..a4c099fb45b14 100644
--- a/bin/compute-classpath.cmd
+++ b/bin/compute-classpath.cmd
@@ -1,117 +1,117 @@
-@echo off
-
-rem
-rem Licensed to the Apache Software Foundation (ASF) under one or more
-rem contributor license agreements. See the NOTICE file distributed with
-rem this work for additional information regarding copyright ownership.
-rem The ASF licenses this file to You under the Apache License, Version 2.0
-rem (the "License"); you may not use this file except in compliance with
-rem the License. You may obtain a copy of the License at
-rem
-rem http://www.apache.org/licenses/LICENSE-2.0
-rem
-rem Unless required by applicable law or agreed to in writing, software
-rem distributed under the License is distributed on an "AS IS" BASIS,
-rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-rem See the License for the specific language governing permissions and
-rem limitations under the License.
-rem
-
-rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
-rem script and the ExecutorRunner in standalone cluster mode.
-
-rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
-rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
-rem need to set it here because we use !datanucleus_jars! below.
-if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
-setlocal enabledelayedexpansion
-:skip_delayed_expansion
-
-set SCALA_VERSION=2.10
-
-rem Figure out where the Spark framework is installed
-set FWDIR=%~dp0..\
-
-rem Load environment variables from conf\spark-env.cmd, if it exists
-if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
-
-rem Build up classpath
-set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
-
-if not "x%SPARK_CONF_DIR%"=="x" (
- set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
-) else (
- set CLASSPATH=%CLASSPATH%;%FWDIR%conf
-)
-
-if exist "%FWDIR%RELEASE" (
- for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-) else (
- for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
- set ASSEMBLY_JAR=%%d
- )
-)
-
-set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
-
-rem When Hive support is needed, Datanucleus jars must be included on the classpath.
-rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
-rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
-rem built with Hive, so look for them there.
-if exist "%FWDIR%RELEASE" (
- set datanucleus_dir=%FWDIR%lib
-) else (
- set datanucleus_dir=%FWDIR%lib_managed\jars
-)
-set "datanucleus_jars="
-for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
- set datanucleus_jars=!datanucleus_jars!;%%d
-)
-set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
-
-set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
-set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
-
-set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
-set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
-
-if "x%SPARK_TESTING%"=="x1" (
- rem Add test clases to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
- rem so that local compilation takes precedence over assembled jar
- set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
-)
-
-rem Add hadoop conf dir - else FileSystem.*, etc fail
-rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
-rem the configurtion files.
-if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
- set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
-:no_hadoop_conf_dir
-
-if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
- set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
-:no_yarn_conf_dir
-
-rem A bit of a hack to allow calling this script within run2.cmd without seeing output
-if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
-
-echo %CLASSPATH%
-
-:exit
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+rem This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
+rem script and the ExecutorRunner in standalone cluster mode.
+
+rem If we're called from spark-class2.cmd, it already set enabledelayedexpansion and setting
+rem it here would stop us from affecting its copy of the CLASSPATH variable; otherwise we
+rem need to set it here because we use !datanucleus_jars! below.
+if "%DONT_PRINT_CLASSPATH%"=="1" goto skip_delayed_expansion
+setlocal enabledelayedexpansion
+:skip_delayed_expansion
+
+set SCALA_VERSION=2.10
+
+rem Figure out where the Spark framework is installed
+set FWDIR=%~dp0..\
+
+rem Load environment variables from conf\spark-env.cmd, if it exists
+if exist "%FWDIR%conf\spark-env.cmd" call "%FWDIR%conf\spark-env.cmd"
+
+rem Build up classpath
+set CLASSPATH=%SPARK_CLASSPATH%;%SPARK_SUBMIT_CLASSPATH%
+
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set CLASSPATH=%CLASSPATH%;%SPARK_CONF_DIR%
+) else (
+ set CLASSPATH=%CLASSPATH%;%FWDIR%conf
+)
+
+if exist "%FWDIR%RELEASE" (
+ for %%d in ("%FWDIR%lib\spark-assembly*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+) else (
+ for %%d in ("%FWDIR%assembly\target\scala-%SCALA_VERSION%\spark-assembly*hadoop*.jar") do (
+ set ASSEMBLY_JAR=%%d
+ )
+)
+
+set CLASSPATH=%CLASSPATH%;%ASSEMBLY_JAR%
+
+rem When Hive support is needed, Datanucleus jars must be included on the classpath.
+rem Datanucleus jars do not work if only included in the uber jar as plugin.xml metadata is lost.
+rem Both sbt and maven will populate "lib_managed/jars/" with the datanucleus jars when Spark is
+rem built with Hive, so look for them there.
+if exist "%FWDIR%RELEASE" (
+ set datanucleus_dir=%FWDIR%lib
+) else (
+ set datanucleus_dir=%FWDIR%lib_managed\jars
+)
+set "datanucleus_jars="
+for %%d in ("%datanucleus_dir%\datanucleus-*.jar") do (
+ set datanucleus_jars=!datanucleus_jars!;%%d
+)
+set CLASSPATH=%CLASSPATH%;%datanucleus_jars%
+
+set SPARK_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%tools\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\classes
+set SPARK_CLASSES=%SPARK_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\classes
+
+set SPARK_TEST_CLASSES=%FWDIR%core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%repl\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%mllib\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%bagel\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%graphx\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%streaming\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\catalyst\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\core\target\scala-%SCALA_VERSION%\test-classes
+set SPARK_TEST_CLASSES=%SPARK_TEST_CLASSES%;%FWDIR%sql\hive\target\scala-%SCALA_VERSION%\test-classes
+
+if "x%SPARK_TESTING%"=="x1" (
+ rem Add test classes to path - note, add SPARK_CLASSES and SPARK_TEST_CLASSES before CLASSPATH
+ rem so that local compilation takes precedence over assembled jar
+ set CLASSPATH=%SPARK_CLASSES%;%SPARK_TEST_CLASSES%;%CLASSPATH%
+)
+
+rem Add hadoop conf dir - else FileSystem.*, etc fail
+rem Note, this assumes that there is either a HADOOP_CONF_DIR or YARN_CONF_DIR which hosts
+ rem the configuration files.
+if "x%HADOOP_CONF_DIR%"=="x" goto no_hadoop_conf_dir
+ set CLASSPATH=%CLASSPATH%;%HADOOP_CONF_DIR%
+:no_hadoop_conf_dir
+
+if "x%YARN_CONF_DIR%"=="x" goto no_yarn_conf_dir
+ set CLASSPATH=%CLASSPATH%;%YARN_CONF_DIR%
+:no_yarn_conf_dir
+
+rem A bit of a hack to allow calling this script within run2.cmd without seeing output
+if "%DONT_PRINT_CLASSPATH%"=="1" goto exit
+
+echo %CLASSPATH%
+
+:exit
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index a0e66abcc26c9..59415e9bdec2c 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -59,7 +59,12 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do (
)
if [%PYTHON_FILE%] == [] (
- %PYSPARK_PYTHON%
+ set PYSPARK_SHELL=1
+ if [%IPYTHON%] == [1] (
+ ipython %IPYTHON_OPTS%
+ ) else (
+ %PYSPARK_PYTHON%
+ )
) else (
echo.
echo WARNING: Running python applications through ./bin/pyspark.cmd is deprecated as of Spark 1.0.
diff --git a/bin/spark-class b/bin/spark-class
index 91d858bc063d0..925367b0dd187 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -81,7 +81,11 @@ case "$1" in
OUR_JAVA_OPTS="$SPARK_JAVA_OPTS $SPARK_SUBMIT_OPTS"
OUR_JAVA_MEM=${SPARK_DRIVER_MEMORY:-$DEFAULT_MEM}
if [ -n "$SPARK_SUBMIT_LIBRARY_PATH" ]; then
- OUR_JAVA_OPTS="$OUR_JAVA_OPTS -Djava.library.path=$SPARK_SUBMIT_LIBRARY_PATH"
+ if [[ $OSTYPE == darwin* ]]; then
+ export DYLD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$DYLD_LIBRARY_PATH"
+ else
+ export LD_LIBRARY_PATH="$SPARK_SUBMIT_LIBRARY_PATH:$LD_LIBRARY_PATH"
+ fi
fi
if [ -n "$SPARK_SUBMIT_DRIVER_MEMORY" ]; then
OUR_JAVA_MEM="$SPARK_SUBMIT_DRIVER_MEMORY"
diff --git a/core/pom.xml b/core/pom.xml
index a5a178079bc57..41296e0eca330 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -44,6 +44,16 @@
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-network-common_2.10</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-network-shuffle_2.10</artifactId>
+      <version>${project.version}</version>
+    </dependency>
    <dependency>
      <groupId>net.java.dev.jets3t</groupId>
      <artifactId>jets3t</artifactId>
@@ -85,8 +95,6 @@
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-math3</artifactId>
-      <version>3.3</version>
-      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>com.google.code.findbugs</groupId>
@@ -132,6 +140,10 @@
      <groupId>com.twitter</groupId>
      <artifactId>chill-java</artifactId>
    </dependency>
+    <dependency>
+      <groupId>org.roaringbitmap</groupId>
+      <artifactId>RoaringBitmap</artifactId>
+    </dependency>
    <dependency>
      <groupId>commons-net</groupId>
      <artifactId>commons-net</artifactId>
@@ -158,10 +170,6 @@
      <artifactId>json4s-jackson_${scala.binary.version}</artifactId>
      <version>3.2.10</version>
    </dependency>
-    <dependency>
-      <groupId>colt</groupId>
-      <artifactId>colt</artifactId>
-    </dependency>
    <dependency>
      <groupId>org.apache.mesos</groupId>
      <artifactId>mesos</artifactId>
@@ -243,6 +251,11 @@
+    <dependency>
+      <groupId>org.seleniumhq.selenium</groupId>
+      <artifactId>selenium-java</artifactId>
+      <scope>test</scope>
+    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/core/src/main/java/org/apache/spark/JobExecutionStatus.java b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
new file mode 100644
index 0000000000000..6e161313702bb
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/JobExecutionStatus.java
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+public enum JobExecutionStatus {
+ RUNNING,
+ SUCCEEDED,
+ FAILED,
+ UNKNOWN
+}
diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java
new file mode 100644
index 0000000000000..4e3c983b1170a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+/**
+ * Exposes information about Spark Jobs.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkJobInfo {
+ int jobId();
+ int[] stageIds();
+ JobExecutionStatus status();
+}
diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java
new file mode 100644
index 0000000000000..04e2247210ecc
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark;
+
+/**
+ * Exposes information about Spark Stages.
+ *
+ * This interface is not designed to be implemented outside of Spark. We may add additional methods
+ * which may break binary compatibility with outside implementations.
+ */
+public interface SparkStageInfo {
+ int stageId();
+ int currentAttemptId();
+ String name();
+ int numTasks();
+ int numActiveTasks();
+ int numCompletedTasks();
+ int numFailedTasks();
+}
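
The three status types added above (JobExecutionStatus, SparkJobInfo, SparkStageInfo) are read-only views, so callers only ever consume them. A minimal sketch of how client code might do so; the JobProgressPrinter helper is hypothetical and exercises only the methods declared in this diff (how the info objects are obtained is not shown here):

```java
import org.apache.spark.JobExecutionStatus;
import org.apache.spark.SparkJobInfo;
import org.apache.spark.SparkStageInfo;

// Hypothetical helper: prints a progress summary using only the interfaces added above.
public class JobProgressPrinter {
  /** Returns true once the job has reached a terminal state. */
  public static boolean report(SparkJobInfo job, SparkStageInfo[] stages) {
    System.out.println("Job " + job.jobId() + " is " + job.status());
    for (SparkStageInfo stage : stages) {
      System.out.println(String.format(
          "  stage %d (%s): %d/%d tasks completed, %d active, %d failed",
          stage.stageId(), stage.name(), stage.numCompletedTasks(), stage.numTasks(),
          stage.numActiveTasks(), stage.numFailedTasks()));
    }
    return job.status() == JobExecutionStatus.SUCCEEDED
        || job.status() == JobExecutionStatus.FAILED;
  }
}
```
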
diff --git a/core/src/main/java/org/apache/spark/TaskContext.java b/core/src/main/java/org/apache/spark/TaskContext.java
index 2d998d4c7a5d9..0d6973203eba1 100644
--- a/core/src/main/java/org/apache/spark/TaskContext.java
+++ b/core/src/main/java/org/apache/spark/TaskContext.java
@@ -71,7 +71,6 @@ static void unset() {
/**
* Add a (Java friendly) listener to be executed on task completion.
* This will be called in all situation - success, failure, or cancellation.
- *
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
public abstract TaskContext addTaskCompletionListener(TaskCompletionListener listener);
@@ -79,7 +78,6 @@ static void unset() {
/**
* Add a listener in the form of a Scala closure to be executed on task completion.
* This will be called in all situations - success, failure, or cancellation.
- *
* An example use is for HadoopRDD to register a callback to close the input stream.
*/
public abstract TaskContext addTaskCompletionListener(final Function1<TaskContext, Unit> f);
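
As the javadoc above notes, the typical use of a completion listener is closing a per-task resource (e.g. HadoopRDD closing its input stream). A short sketch of the Java-friendly overload, assuming the usual org.apache.spark.util.TaskCompletionListener callback with a single onTaskCompletion(TaskContext) method:

```java
import java.io.IOException;
import java.io.InputStream;

import org.apache.spark.TaskContext;
import org.apache.spark.util.TaskCompletionListener;

public class StreamCleanup {
  /** Registers a listener that closes the stream when the task ends, whatever the outcome. */
  public static void closeOnTaskCompletion(TaskContext context, final InputStream in) {
    context.addTaskCompletionListener(new TaskCompletionListener() {
      @Override
      public void onTaskCompletion(TaskContext ctx) {
        try {
          in.close();  // invoked on success, failure, or cancellation
        } catch (IOException e) {
          // best-effort cleanup; ignore errors from close()
        }
      }
    });
  }
}
```
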
diff --git a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
index abd9bcc07ac61..99bf240a17225 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/PairFunction.java
@@ -22,7 +22,8 @@
import scala.Tuple2;
/**
- * A function that returns key-value pairs (Tuple2<K, V>), and can be used to construct PairRDDs.
+ * A function that returns key-value pairs (Tuple2&lt;K, V&gt;), and can be used to
+ * construct PairRDDs.
*/
public interface PairFunction<T, K, V> extends Serializable {
  public Tuple2<K, V> call(T t) throws Exception;
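
For reference, a minimal implementation of the interface documented above, pairing each input string with its length (the class name is illustrative only); such a function can then be handed to the Java API's pair-producing transformations:

```java
import scala.Tuple2;

import org.apache.spark.api.java.function.PairFunction;

/** Maps a string to a (string, length) pair. */
public class StringLengthPair implements PairFunction<String, String, Integer> {
  @Override
  public Tuple2<String, Integer> call(String s) throws Exception {
    return new Tuple2<String, Integer>(s, s.length());
  }
}
```
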
diff --git a/core/src/main/java/org/apache/spark/util/collection/Sorter.java b/core/src/main/java/org/apache/spark/util/collection/Sorter.java
deleted file mode 100644
index 64ad18c0e463a..0000000000000
--- a/core/src/main/java/org/apache/spark/util/collection/Sorter.java
+++ /dev/null
@@ -1,915 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.util.collection;
-
-import java.util.Comparator;
-
-/**
- * A port of the Android Timsort class, which utilizes a "stable, adaptive, iterative mergesort."
- * See the method comment on sort() for more details.
- *
- * This has been kept in Java with the original style in order to match very closely with the
- * Anroid source code, and thus be easy to verify correctness.
- *
- * The purpose of the port is to generalize the interface to the sort to accept input data formats
- * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
- * uses this to sort an Array with alternating elements of the form [key, value, key, value].
- * This generalization comes with minimal overhead -- see SortDataFormat for more information.
- */
-class Sorter<K, Buffer> {
-
- /**
- * This is the minimum sized sequence that will be merged. Shorter
- * sequences will be lengthened by calling binarySort. If the entire
- * array is less than this length, no merges will be performed.
- *
- * This constant should be a power of two. It was 64 in Tim Peter's C
- * implementation, but 32 was empirically determined to work better in
- * this implementation. In the unlikely event that you set this constant
- * to be a number that's not a power of two, you'll need to change the
- * minRunLength computation.
- *
- * If you decrease this constant, you must change the stackLen
- * computation in the TimSort constructor, or you risk an
- * ArrayOutOfBounds exception. See listsort.txt for a discussion
- * of the minimum stack length required as a function of the length
- * of the array being sorted and the minimum merge sequence length.
- */
- private static final int MIN_MERGE = 32;
-
- private final SortDataFormat<K, Buffer> s;
-
- public Sorter(SortDataFormat<K, Buffer> sortDataFormat) {
- this.s = sortDataFormat;
- }
-
- /**
- * A stable, adaptive, iterative mergesort that requires far fewer than
- * n lg(n) comparisons when running on partially sorted arrays, while
- * offering performance comparable to a traditional mergesort when run
- * on random arrays. Like all proper mergesorts, this sort is stable and
- * runs O(n log n) time (worst case). In the worst case, this sort requires
- * temporary storage space for n/2 object references; in the best case,
- * it requires only a small constant amount of space.
- *
- * This implementation was adapted from Tim Peters's list sort for
- * Python, which is described in detail here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
- *
- * Tim's C code may be found here:
- *
- * http://svn.python.org/projects/python/trunk/Objects/listobject.c
- *
- * The underlying techniques are described in this paper (and may have
- * even earlier origins):
- *
- * "Optimistic Sorting and Information Theoretic Complexity"
- * Peter McIlroy
- * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
- * pp 467-474, Austin, Texas, 25-27 January 1993.
- *
- * While the API to this class consists solely of static methods, it is
- * (privately) instantiable; a TimSort instance holds the state of an ongoing
- * sort, assuming the input array is large enough to warrant the full-blown
- * TimSort. Small arrays are sorted in place, using a binary insertion sort.
- *
- * @author Josh Bloch
- */
- void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
- assert c != null;
-
- int nRemaining = hi - lo;
- if (nRemaining < 2)
- return; // Arrays of size 0 and 1 are always sorted
-
- // If array is small, do a "mini-TimSort" with no merges
- if (nRemaining < MIN_MERGE) {
- int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
- binarySort(a, lo, hi, lo + initRunLen, c);
- return;
- }
-
- /**
- * March over the array once, left to right, finding natural runs,
- * extending short natural runs to minRun elements, and merging runs
- * to maintain stack invariant.
- */
- SortState sortState = new SortState(a, c, hi - lo);
- int minRun = minRunLength(nRemaining);
- do {
- // Identify next run
- int runLen = countRunAndMakeAscending(a, lo, hi, c);
-
- // If run is short, extend to min(minRun, nRemaining)
- if (runLen < minRun) {
- int force = nRemaining <= minRun ? nRemaining : minRun;
- binarySort(a, lo, lo + force, lo + runLen, c);
- runLen = force;
- }
-
- // Push run onto pending-run stack, and maybe merge
- sortState.pushRun(lo, runLen);
- sortState.mergeCollapse();
-
- // Advance to find next run
- lo += runLen;
- nRemaining -= runLen;
- } while (nRemaining != 0);
-
- // Merge all remaining runs to complete sort
- assert lo == hi;
- sortState.mergeForceCollapse();
- assert sortState.stackSize == 1;
- }
-
- /**
- * Sorts the specified portion of the specified array using a binary
- * insertion sort. This is the best method for sorting small numbers
- * of elements. It requires O(n log n) compares, but O(n^2) data
- * movement (worst case).
- *
- * If the initial part of the specified range is already sorted,
- * this method can take advantage of it: the method assumes that the
- * elements from index {@code lo}, inclusive, to {@code start},
- * exclusive are already sorted.
- *
- * @param a the array in which a range is to be sorted
- * @param lo the index of the first element in the range to be sorted
- * @param hi the index after the last element in the range to be sorted
- * @param start the index of the first element in the range that is
- * not already known to be sorted ({@code lo <= start <= hi})
- * @param c comparator to used for the sort
- */
- @SuppressWarnings("fallthrough")
- private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
- assert lo <= start && start <= hi;
- if (start == lo)
- start++;
-
- Buffer pivotStore = s.allocate(1);
- for ( ; start < hi; start++) {
- s.copyElement(a, start, pivotStore, 0);
- K pivot = s.getKey(pivotStore, 0);
-
- // Set left (and right) to the index where a[start] (pivot) belongs
- int left = lo;
- int right = start;
- assert left <= right;
- /*
- * Invariants:
- * pivot >= all in [lo, left).
- * pivot < all in [right, start).
- */
- while (left < right) {
- int mid = (left + right) >>> 1;
- if (c.compare(pivot, s.getKey(a, mid)) < 0)
- right = mid;
- else
- left = mid + 1;
- }
- assert left == right;
-
- /*
- * The invariants still hold: pivot >= all in [lo, left) and
- * pivot < all in [left, start), so pivot belongs at left. Note
- * that if there are elements equal to pivot, left points to the
- * first slot after them -- that's why this sort is stable.
- * Slide elements over to make room for pivot.
- */
- int n = start - left; // The number of elements to move
- // Switch is just an optimization for arraycopy in default case
- switch (n) {
- case 2: s.copyElement(a, left + 1, a, left + 2);
- case 1: s.copyElement(a, left, a, left + 1);
- break;
- default: s.copyRange(a, left, a, left + 1, n);
- }
- s.copyElement(pivotStore, 0, a, left);
- }
- }
-
- /**
- * Returns the length of the run beginning at the specified position in
- * the specified array and reverses the run if it is descending (ensuring
- * that the run will always be ascending when the method returns).
- *
- * A run is the longest ascending sequence with:
- *
- * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
- *
- * or the longest descending sequence with:
- *
- * a[lo] > a[lo + 1] > a[lo + 2] > ...
- *
- * For its intended use in a stable mergesort, the strictness of the
- * definition of "descending" is needed so that the call can safely
- * reverse a descending sequence without violating stability.
- *
- * @param a the array in which a run is to be counted and possibly reversed
- * @param lo index of the first element in the run
- * @param hi index after the last element that may be contained in the run.
- It is required that {@code lo < hi}.
- * @param c the comparator to used for the sort
- * @return the length of the run beginning at the specified position in
- * the specified array
- */
- private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
- assert lo < hi;
- int runHi = lo + 1;
- if (runHi == hi)
- return 1;
-
- // Find end of run, and reverse range if descending
- if (c.compare(s.getKey(a, runHi++), s.getKey(a, lo)) < 0) { // Descending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) < 0)
- runHi++;
- reverseRange(a, lo, runHi);
- } else { // Ascending
- while (runHi < hi && c.compare(s.getKey(a, runHi), s.getKey(a, runHi - 1)) >= 0)
- runHi++;
- }
-
- return runHi - lo;
- }
-
- /**
- * Reverse the specified range of the specified array.
- *
- * @param a the array in which a range is to be reversed
- * @param lo the index of the first element in the range to be reversed
- * @param hi the index after the last element in the range to be reversed
- */
- private void reverseRange(Buffer a, int lo, int hi) {
- hi--;
- while (lo < hi) {
- s.swap(a, lo, hi);
- lo++;
- hi--;
- }
- }
-
- /**
- * Returns the minimum acceptable run length for an array of the specified
- * length. Natural runs shorter than this will be extended with
- * {@link #binarySort}.
- *
- * Roughly speaking, the computation is:
- *
- * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
- * Else if n is an exact power of 2, return MIN_MERGE/2.
- * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
- * is close to, but strictly less than, an exact power of 2.
- *
- * For the rationale, see listsort.txt.
- *
- * @param n the length of the array to be sorted
- * @return the length of the minimum run to be merged
- */
- private int minRunLength(int n) {
- assert n >= 0;
- int r = 0; // Becomes 1 if any 1 bits are shifted off
- while (n >= MIN_MERGE) {
- r |= (n & 1);
- n >>= 1;
- }
- return n + r;
- }
-
- private class SortState {
-
- /**
- * The Buffer being sorted.
- */
- private final Buffer a;
-
- /**
- * Length of the sort Buffer.
- */
- private final int aLength;
-
- /**
- * The comparator for this sort.
- */
- private final Comparator<? super K> c;
-
- /**
- * When we get into galloping mode, we stay there until both runs win less
- * often than MIN_GALLOP consecutive times.
- */
- private static final int MIN_GALLOP = 7;
-
- /**
- * This controls when we get *into* galloping mode. It is initialized
- * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
- * random data, and lower for highly structured data.
- */
- private int minGallop = MIN_GALLOP;
-
- /**
- * Maximum initial size of tmp array, which is used for merging. The array
- * can grow to accommodate demand.
- *
- * Unlike Tim's original C version, we do not allocate this much storage
- * when sorting smaller arrays. This change was required for performance.
- */
- private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
-
- /**
- * Temp storage for merges.
- */
- private Buffer tmp; // Actual runtime type will be Object[], regardless of T
-
- /**
- * Length of the temp storage.
- */
- private int tmpLength = 0;
-
- /**
- * A stack of pending runs yet to be merged. Run i starts at
- * address base[i] and extends for len[i] elements. It's always
- * true (so long as the indices are in bounds) that:
- *
- * runBase[i] + runLen[i] == runBase[i + 1]
- *
- * so we could cut the storage for this, but it's a minor amount,
- * and keeping all the info explicit simplifies the code.
- */
- private int stackSize = 0; // Number of pending runs on stack
- private final int[] runBase;
- private final int[] runLen;
-
- /**
- * Creates a TimSort instance to maintain the state of an ongoing sort.
- *
- * @param a the array to be sorted
- * @param c the comparator to determine the order of the sort
- */
- private SortState(Buffer a, Comparator<? super K> c, int len) {
- this.aLength = len;
- this.a = a;
- this.c = c;
-
- // Allocate temp storage (which may be increased later if necessary)
- tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
- tmp = s.allocate(tmpLength);
-
- /*
- * Allocate runs-to-be-merged stack (which cannot be expanded). The
- * stack length requirements are described in listsort.txt. The C
- * version always uses the same stack length (85), but this was
- * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
- * 100 elements) in Java. Therefore, we use smaller (but sufficiently
- * large) stack lengths for smaller arrays. The "magic numbers" in the
- * computation below must be changed if MIN_MERGE is decreased. See
- * the MIN_MERGE declaration above for more information.
- */
- int stackLen = (len < 120 ? 5 :
- len < 1542 ? 10 :
- len < 119151 ? 19 : 40);
- runBase = new int[stackLen];
- runLen = new int[stackLen];
- }
-
- /**
- * Pushes the specified run onto the pending-run stack.
- *
- * @param runBase index of the first element in the run
- * @param runLen the number of elements in the run
- */
- private void pushRun(int runBase, int runLen) {
- this.runBase[stackSize] = runBase;
- this.runLen[stackSize] = runLen;
- stackSize++;
- }
-
- /**
- * Examines the stack of runs waiting to be merged and merges adjacent runs
- * until the stack invariants are reestablished:
- *
- * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
- * 2. runLen[i - 2] > runLen[i - 1]
- *
- * This method is called each time a new run is pushed onto the stack,
- * so the invariants are guaranteed to hold for i < stackSize upon
- * entry to the method.
- */
- private void mergeCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
- if (runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- } else if (runLen[n] <= runLen[n + 1]) {
- mergeAt(n);
- } else {
- break; // Invariant is established
- }
- }
- }
-
- /**
- * Merges all runs on the stack until only one remains. This method is
- * called once, to complete the sort.
- */
- private void mergeForceCollapse() {
- while (stackSize > 1) {
- int n = stackSize - 2;
- if (n > 0 && runLen[n - 1] < runLen[n + 1])
- n--;
- mergeAt(n);
- }
- }
-
- /**
- * Merges the two runs at stack indices i and i+1. Run i must be
- * the penultimate or antepenultimate run on the stack. In other words,
- * i must be equal to stackSize-2 or stackSize-3.
- *
- * @param i stack index of the first of the two runs to merge
- */
- private void mergeAt(int i) {
- assert stackSize >= 2;
- assert i >= 0;
- assert i == stackSize - 2 || i == stackSize - 3;
-
- int base1 = runBase[i];
- int len1 = runLen[i];
- int base2 = runBase[i + 1];
- int len2 = runLen[i + 1];
- assert len1 > 0 && len2 > 0;
- assert base1 + len1 == base2;
-
- /*
- * Record the length of the combined runs; if i is the 3rd-last
- * run now, also slide over the last run (which isn't involved
- * in this merge). The current run (i+1) goes away in any case.
- */
- runLen[i] = len1 + len2;
- if (i == stackSize - 3) {
- runBase[i + 1] = runBase[i + 2];
- runLen[i + 1] = runLen[i + 2];
- }
- stackSize--;
-
- /*
- * Find where the first element of run2 goes in run1. Prior elements
- * in run1 can be ignored (because they're already in place).
- */
- int k = gallopRight(s.getKey(a, base2), a, base1, len1, 0, c);
- assert k >= 0;
- base1 += k;
- len1 -= k;
- if (len1 == 0)
- return;
-
- /*
- * Find where the last element of run1 goes in run2. Subsequent elements
- * in run2 can be ignored (because they're already in place).
- */
- len2 = gallopLeft(s.getKey(a, base1 + len1 - 1), a, base2, len2, len2 - 1, c);
- assert len2 >= 0;
- if (len2 == 0)
- return;
-
- // Merge remaining runs, using tmp array with min(len1, len2) elements
- if (len1 <= len2)
- mergeLo(base1, len1, base2, len2);
- else
- mergeHi(base1, len1, base2, len2);
- }
-
- /**
- * Locates the position at which to insert the specified key into the
- * specified sorted range; if the range contains an element equal to key,
- * returns the index of the leftmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
- * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
- * In other words, key belongs at index b + k; or in other words,
- * the first k elements of a should precede key, and the last n - k
- * should follow it.
- */
- private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
- int lastOfs = 0;
- int ofs = 1;
- if (c.compare(key, s.getKey(a, base + hint)) > 0) {
- // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) > 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- lastOfs += hint;
- ofs += hint;
- } else { // key <= a[base + hint]
- // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
- final int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) <= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to base
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
- * to the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) > 0)
- lastOfs = m + 1; // a[base + m] < key
- else
- ofs = m; // key <= a[base + m]
- }
- assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
- return ofs;
- }
-
- /**
- * Like gallopLeft, except that if the range contains an element equal to
- * key, gallopRight returns the index after the rightmost equal element.
- *
- * @param key the key whose insertion point to search for
- * @param a the array in which to search
- * @param base the index of the first element in the range
- * @param len the length of the range; must be > 0
- * @param hint the index at which to begin the search, 0 <= hint < n.
- * The closer hint is to the result, the faster this method will run.
- * @param c the comparator used to order the range, and to search
- * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
- */
- private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
- assert len > 0 && hint >= 0 && hint < len;
-
- int ofs = 1;
- int lastOfs = 0;
- if (c.compare(key, s.getKey(a, base + hint)) < 0) {
- // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
- int maxOfs = hint + 1;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs)) < 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- int tmp = lastOfs;
- lastOfs = hint - ofs;
- ofs = hint - tmp;
- } else { // a[b + hint] <= key
- // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
- int maxOfs = len - hint;
- while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs)) >= 0) {
- lastOfs = ofs;
- ofs = (ofs << 1) + 1;
- if (ofs <= 0) // int overflow
- ofs = maxOfs;
- }
- if (ofs > maxOfs)
- ofs = maxOfs;
-
- // Make offsets relative to b
- lastOfs += hint;
- ofs += hint;
- }
- assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
-
- /*
- * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
- * the right of lastOfs but no farther right than ofs. Do a binary
- * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
- */
- lastOfs++;
- while (lastOfs < ofs) {
- int m = lastOfs + ((ofs - lastOfs) >>> 1);
-
- if (c.compare(key, s.getKey(a, base + m)) < 0)
- ofs = m; // key < a[b + m]
- else
- lastOfs = m + 1; // a[b + m] <= key
- }
- assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
- return ofs;
- }
-
- /**
- * Merges two adjacent runs in place, in a stable fashion. The first
- * element of the first run must be greater than the first element of the
- * second run (a[base1] > a[base2]), and the last element of the first run
- * (a[base1 + len1-1]) must be greater than all elements of the second run.
- *
- * For performance, this method should be called only when len1 <= len2;
- * its twin, mergeHi should be called if len1 >= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeLo(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy first run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len1);
- s.copyRange(a, base1, tmp, 0, len1);
-
- int cursor1 = 0; // Indexes into tmp array
- int cursor2 = base2; // Indexes int a
- int dest = base1; // Indexes int a
-
- // Move first element of second run and deal with degenerate cases
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0) {
- s.copyRange(tmp, cursor1, a, dest, len1);
- return;
- }
- if (len1 == 1) {
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- return;
- }
-
- Comparator<? super K> c = this.c; // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run starts
- * winning consistently.
- */
- do {
- assert len1 > 1 && len2 > 0;
- if (c.compare(s.getKey(a, cursor2), s.getKey(tmp, cursor1)) < 0) {
- s.copyElement(a, cursor2++, a, dest++);
- count2++;
- count1 = 0;
- if (--len2 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor1++, a, dest++);
- count1++;
- count2 = 0;
- if (--len1 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 1 && len2 > 0;
- count1 = gallopRight(s.getKey(a, cursor2), tmp, cursor1, len1, 0, c);
- if (count1 != 0) {
- s.copyRange(tmp, cursor1, a, dest, count1);
- dest += count1;
- cursor1 += count1;
- len1 -= count1;
- if (len1 <= 1) // len1 == 1 || len1 == 0
- break outer;
- }
- s.copyElement(a, cursor2++, a, dest++);
- if (--len2 == 0)
- break outer;
-
- count2 = gallopLeft(s.getKey(tmp, cursor1), a, cursor2, len2, 0, c);
- if (count2 != 0) {
- s.copyRange(a, cursor2, a, dest, count2);
- dest += count2;
- cursor2 += count2;
- len2 -= count2;
- if (len2 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor1++, a, dest++);
- if (--len1 == 1)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len1 == 1) {
- assert len2 > 0;
- s.copyRange(a, cursor2, a, dest, len2);
- s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
- } else if (len1 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len2 == 0;
- assert len1 > 1;
- s.copyRange(tmp, cursor1, a, dest, len1);
- }
- }
-
- /**
- * Like mergeLo, except that this method should be called only if
- * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
- * may be called if len1 == len2.)
- *
- * @param base1 index of first element in first run to be merged
- * @param len1 length of first run to be merged (must be > 0)
- * @param base2 index of first element in second run to be merged
- * (must be aBase + aLen)
- * @param len2 length of second run to be merged (must be > 0)
- */
- private void mergeHi(int base1, int len1, int base2, int len2) {
- assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
-
- // Copy second run into temp array
- Buffer a = this.a; // For performance
- Buffer tmp = ensureCapacity(len2);
- s.copyRange(a, base2, tmp, 0, len2);
-
- int cursor1 = base1 + len1 - 1; // Indexes into a
- int cursor2 = len2 - 1; // Indexes into tmp array
- int dest = base2 + len2 - 1; // Indexes into a
-
- // Move last element of first run and deal with degenerate cases
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0) {
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- return;
- }
- if (len2 == 1) {
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest);
- return;
- }
-
- Comparator<? super K> c = this.c; // Use local variable for performance
- int minGallop = this.minGallop; // " " " " "
- outer:
- while (true) {
- int count1 = 0; // Number of times in a row that first run won
- int count2 = 0; // Number of times in a row that second run won
-
- /*
- * Do the straightforward thing until (if ever) one run
- * appears to win consistently.
- */
- do {
- assert len1 > 0 && len2 > 1;
- if (c.compare(s.getKey(tmp, cursor2), s.getKey(a, cursor1)) < 0) {
- s.copyElement(a, cursor1--, a, dest--);
- count1++;
- count2 = 0;
- if (--len1 == 0)
- break outer;
- } else {
- s.copyElement(tmp, cursor2--, a, dest--);
- count2++;
- count1 = 0;
- if (--len2 == 1)
- break outer;
- }
- } while ((count1 | count2) < minGallop);
-
- /*
- * One run is winning so consistently that galloping may be a
- * huge win. So try that, and continue galloping until (if ever)
- * neither run appears to be winning consistently anymore.
- */
- do {
- assert len1 > 0 && len2 > 1;
- count1 = len1 - gallopRight(s.getKey(tmp, cursor2), a, base1, len1, len1 - 1, c);
- if (count1 != 0) {
- dest -= count1;
- cursor1 -= count1;
- len1 -= count1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
- if (len1 == 0)
- break outer;
- }
- s.copyElement(tmp, cursor2--, a, dest--);
- if (--len2 == 1)
- break outer;
-
- count2 = len2 - gallopLeft(s.getKey(a, cursor1), tmp, 0, len2, len2 - 1, c);
- if (count2 != 0) {
- dest -= count2;
- cursor2 -= count2;
- len2 -= count2;
- s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
- if (len2 <= 1) // len2 == 1 || len2 == 0
- break outer;
- }
- s.copyElement(a, cursor1--, a, dest--);
- if (--len1 == 0)
- break outer;
- minGallop--;
- } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
- if (minGallop < 0)
- minGallop = 0;
- minGallop += 2; // Penalize for leaving gallop mode
- } // End of "outer" loop
- this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
-
- if (len2 == 1) {
- assert len1 > 0;
- dest -= len1;
- cursor1 -= len1;
- s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
- s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
- } else if (len2 == 0) {
- throw new IllegalArgumentException(
- "Comparison method violates its general contract!");
- } else {
- assert len1 == 0;
- assert len2 > 0;
- s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
- }
- }
-
- /**
- * Ensures that the external array tmp has at least the specified
- * number of elements, increasing its size if necessary. The size
- * increases exponentially to ensure amortized linear time complexity.
- *
- * @param minCapacity the minimum required capacity of the tmp array
- * @return tmp, whether or not it grew
- */
- private Buffer ensureCapacity(int minCapacity) {
- if (tmpLength < minCapacity) {
- // Compute smallest power of 2 > minCapacity
- int newSize = minCapacity;
- newSize |= newSize >> 1;
- newSize |= newSize >> 2;
- newSize |= newSize >> 4;
- newSize |= newSize >> 8;
- newSize |= newSize >> 16;
- newSize++;
-
- if (newSize < 0) // Not bloody likely!
- newSize = minCapacity;
- else
- newSize = Math.min(newSize, aLength >>> 1);
-
- tmp = s.allocate(newSize);
- tmpLength = newSize;
- }
- return tmp;
- }
- }
-}
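
The minRunLength logic documented in the class removed above is carried over unchanged into the new TimSort below. A standalone restatement of that computation, purely to illustrate the MIN_MERGE discussion:

```java
// Standalone copy of the minRunLength computation described above, for illustration only.
public class MinRunLengthDemo {
  private static final int MIN_MERGE = 32;

  /** For n >= MIN_MERGE this returns a value in [MIN_MERGE/2, MIN_MERGE]; otherwise it returns n. */
  static int minRunLength(int n) {
    assert n >= 0;
    int r = 0;               // becomes 1 if any 1 bits are shifted off
    while (n >= MIN_MERGE) {
      r |= (n & 1);
      n >>= 1;
    }
    return n + r;
  }

  public static void main(String[] args) {
    System.out.println(minRunLength(32));   // exact power of two -> MIN_MERGE/2 = 16
    System.out.println(minRunLength(200));  // 25
    System.out.println(minRunLength(201));  // 26, a shifted-off 1 bit bumps the result
  }
}
```
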
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
new file mode 100644
index 0000000000000..409e1a41c5d49
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -0,0 +1,940 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection;
+
+import java.util.Comparator;
+
+/**
+ * A port of the Android TimSort class, which utilizes a "stable, adaptive, iterative mergesort."
+ * See the method comment on sort() for more details.
+ *
+ * This has been kept in Java with the original style in order to match very closely with the
+ * Android source code, and thus be easy to verify correctness. The class is package private. We put
+ * a simple Scala wrapper {@link org.apache.spark.util.collection.Sorter}, which is available to
+ * package org.apache.spark.
+ *
+ * The purpose of the port is to generalize the interface to the sort to accept input data formats
+ * besides simple arrays where every element is sorted individually. For instance, the AppendOnlyMap
+ * uses this to sort an Array with alternating elements of the form [key, value, key, value].
+ * This generalization comes with minimal overhead -- see SortDataFormat for more information.
+ *
+ * We allow key reuse to prevent creating many key objects -- see SortDataFormat.
+ *
+ * @see org.apache.spark.util.collection.SortDataFormat
+ * @see org.apache.spark.util.collection.Sorter
+ */
+class TimSort<K, Buffer> {
+
+ /**
+ * This is the minimum sized sequence that will be merged. Shorter
+ * sequences will be lengthened by calling binarySort. If the entire
+ * array is less than this length, no merges will be performed.
+ *
+ * This constant should be a power of two. It was 64 in Tim Peter's C
+ * implementation, but 32 was empirically determined to work better in
+ * this implementation. In the unlikely event that you set this constant
+ * to be a number that's not a power of two, you'll need to change the
+ * minRunLength computation.
+ *
+ * If you decrease this constant, you must change the stackLen
+ * computation in the TimSort constructor, or you risk an
+ * ArrayOutOfBounds exception. See listsort.txt for a discussion
+ * of the minimum stack length required as a function of the length
+ * of the array being sorted and the minimum merge sequence length.
+ */
+ private static final int MIN_MERGE = 32;
+
+ private final SortDataFormat<K, Buffer> s;
+
+ public TimSort(SortDataFormat<K, Buffer> sortDataFormat) {
+ this.s = sortDataFormat;
+ }
+
+ /**
+ * A stable, adaptive, iterative mergesort that requires far fewer than
+ * n lg(n) comparisons when running on partially sorted arrays, while
+ * offering performance comparable to a traditional mergesort when run
+ * on random arrays. Like all proper mergesorts, this sort is stable and
+ * runs O(n log n) time (worst case). In the worst case, this sort requires
+ * temporary storage space for n/2 object references; in the best case,
+ * it requires only a small constant amount of space.
+ *
+ * This implementation was adapted from Tim Peters's list sort for
+ * Python, which is described in detail here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listsort.txt
+ *
+ * Tim's C code may be found here:
+ *
+ * http://svn.python.org/projects/python/trunk/Objects/listobject.c
+ *
+ * The underlying techniques are described in this paper (and may have
+ * even earlier origins):
+ *
+ * "Optimistic Sorting and Information Theoretic Complexity"
+ * Peter McIlroy
+ * SODA (Fourth Annual ACM-SIAM Symposium on Discrete Algorithms),
+ * pp 467-474, Austin, Texas, 25-27 January 1993.
+ *
+ * While the API to this class consists solely of static methods, it is
+ * (privately) instantiable; a TimSort instance holds the state of an ongoing
+ * sort, assuming the input array is large enough to warrant the full-blown
+ * TimSort. Small arrays are sorted in place, using a binary insertion sort.
+ *
+ * @author Josh Bloch
+ */
+ public void sort(Buffer a, int lo, int hi, Comparator<? super K> c) {
+ assert c != null;
+
+ int nRemaining = hi - lo;
+ if (nRemaining < 2)
+ return; // Arrays of size 0 and 1 are always sorted
+
+ // If array is small, do a "mini-TimSort" with no merges
+ if (nRemaining < MIN_MERGE) {
+ int initRunLen = countRunAndMakeAscending(a, lo, hi, c);
+ binarySort(a, lo, hi, lo + initRunLen, c);
+ return;
+ }
+
+ /**
+ * March over the array once, left to right, finding natural runs,
+ * extending short natural runs to minRun elements, and merging runs
+ * to maintain stack invariant.
+ */
+ SortState sortState = new SortState(a, c, hi - lo);
+ int minRun = minRunLength(nRemaining);
+ do {
+ // Identify next run
+ int runLen = countRunAndMakeAscending(a, lo, hi, c);
+
+ // If run is short, extend to min(minRun, nRemaining)
+ if (runLen < minRun) {
+ int force = nRemaining <= minRun ? nRemaining : minRun;
+ binarySort(a, lo, lo + force, lo + runLen, c);
+ runLen = force;
+ }
+
+ // Push run onto pending-run stack, and maybe merge
+ sortState.pushRun(lo, runLen);
+ sortState.mergeCollapse();
+
+ // Advance to find next run
+ lo += runLen;
+ nRemaining -= runLen;
+ } while (nRemaining != 0);
+
+ // Merge all remaining runs to complete sort
+ assert lo == hi;
+ sortState.mergeForceCollapse();
+ assert sortState.stackSize == 1;
+ }
+
+ /**
+ * Sorts the specified portion of the specified array using a binary
+ * insertion sort. This is the best method for sorting small numbers
+ * of elements. It requires O(n log n) compares, but O(n^2) data
+ * movement (worst case).
+ *
+ * If the initial part of the specified range is already sorted,
+ * this method can take advantage of it: the method assumes that the
+ * elements from index {@code lo}, inclusive, to {@code start},
+ * exclusive are already sorted.
+ *
+ * @param a the array in which a range is to be sorted
+ * @param lo the index of the first element in the range to be sorted
+ * @param hi the index after the last element in the range to be sorted
+ * @param start the index of the first element in the range that is
+ * not already known to be sorted ({@code lo <= start <= hi})
+   * @param c comparator to be used for the sort
+ */
+ @SuppressWarnings("fallthrough")
+  private void binarySort(Buffer a, int lo, int hi, int start, Comparator<? super K> c) {
+ assert lo <= start && start <= hi;
+ if (start == lo)
+ start++;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ Buffer pivotStore = s.allocate(1);
+ for ( ; start < hi; start++) {
+ s.copyElement(a, start, pivotStore, 0);
+ K pivot = s.getKey(pivotStore, 0, key0);
+
+ // Set left (and right) to the index where a[start] (pivot) belongs
+ int left = lo;
+ int right = start;
+ assert left <= right;
+ /*
+ * Invariants:
+ * pivot >= all in [lo, left).
+ * pivot < all in [right, start).
+ */
+ while (left < right) {
+ int mid = (left + right) >>> 1;
+ if (c.compare(pivot, s.getKey(a, mid, key1)) < 0)
+ right = mid;
+ else
+ left = mid + 1;
+ }
+ assert left == right;
+
+ /*
+ * The invariants still hold: pivot >= all in [lo, left) and
+ * pivot < all in [left, start), so pivot belongs at left. Note
+ * that if there are elements equal to pivot, left points to the
+ * first slot after them -- that's why this sort is stable.
+ * Slide elements over to make room for pivot.
+ */
+ int n = start - left; // The number of elements to move
+ // Switch is just an optimization for arraycopy in default case
+ switch (n) {
+ case 2: s.copyElement(a, left + 1, a, left + 2);
+ case 1: s.copyElement(a, left, a, left + 1);
+ break;
+ default: s.copyRange(a, left, a, left + 1, n);
+ }
+ s.copyElement(pivotStore, 0, a, left);
+ }
+ }
+
+ /**
+ * Returns the length of the run beginning at the specified position in
+ * the specified array and reverses the run if it is descending (ensuring
+ * that the run will always be ascending when the method returns).
+ *
+ * A run is the longest ascending sequence with:
+ *
+ * a[lo] <= a[lo + 1] <= a[lo + 2] <= ...
+ *
+ * or the longest descending sequence with:
+ *
+ * a[lo] > a[lo + 1] > a[lo + 2] > ...
+ *
+ * For its intended use in a stable mergesort, the strictness of the
+ * definition of "descending" is needed so that the call can safely
+ * reverse a descending sequence without violating stability.
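+   * For example, equal adjacent elements are never part of a descending run, so reversing a
+   * descending run never reorders equal elements and stability is preserved.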
+ *
+ * @param a the array in which a run is to be counted and possibly reversed
+ * @param lo index of the first element in the run
+ * @param hi index after the last element that may be contained in the run.
+ It is required that {@code lo < hi}.
+   * @param c the comparator to be used for the sort
+ * @return the length of the run beginning at the specified position in
+ * the specified array
+ */
+  private int countRunAndMakeAscending(Buffer a, int lo, int hi, Comparator<? super K> c) {
+ assert lo < hi;
+ int runHi = lo + 1;
+ if (runHi == hi)
+ return 1;
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Find end of run, and reverse range if descending
+ if (c.compare(s.getKey(a, runHi++, key0), s.getKey(a, lo, key1)) < 0) { // Descending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) < 0)
+ runHi++;
+ reverseRange(a, lo, runHi);
+ } else { // Ascending
+ while (runHi < hi && c.compare(s.getKey(a, runHi, key0), s.getKey(a, runHi - 1, key1)) >= 0)
+ runHi++;
+ }
+
+ return runHi - lo;
+ }
+
+ /**
+ * Reverse the specified range of the specified array.
+ *
+ * @param a the array in which a range is to be reversed
+ * @param lo the index of the first element in the range to be reversed
+ * @param hi the index after the last element in the range to be reversed
+ */
+ private void reverseRange(Buffer a, int lo, int hi) {
+ hi--;
+ while (lo < hi) {
+ s.swap(a, lo, hi);
+ lo++;
+ hi--;
+ }
+ }
+
+ /**
+ * Returns the minimum acceptable run length for an array of the specified
+ * length. Natural runs shorter than this will be extended with
+ * {@link #binarySort}.
+ *
+ * Roughly speaking, the computation is:
+ *
+ * If n < MIN_MERGE, return n (it's too small to bother with fancy stuff).
+ * Else if n is an exact power of 2, return MIN_MERGE/2.
+ * Else return an int k, MIN_MERGE/2 <= k <= MIN_MERGE, such that n/k
+ * is close to, but strictly less than, an exact power of 2.
+ *
+ * For the rationale, see listsort.txt.
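+   *
+   * As a worked example, minRunLength(10000) returns 20: repeatedly halving n until it drops
+   * below MIN_MERGE leaves n = 19 with r = 1 (a one bit was shifted off along the way), and
+   * 10000 / 20 = 500, just under the power of two 512.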
+ *
+ * @param n the length of the array to be sorted
+ * @return the length of the minimum run to be merged
+ */
+ private int minRunLength(int n) {
+ assert n >= 0;
+ int r = 0; // Becomes 1 if any 1 bits are shifted off
+ while (n >= MIN_MERGE) {
+ r |= (n & 1);
+ n >>= 1;
+ }
+ return n + r;
+ }
+
+ private class SortState {
+
+ /**
+ * The Buffer being sorted.
+ */
+ private final Buffer a;
+
+ /**
+ * Length of the sort Buffer.
+ */
+ private final int aLength;
+
+ /**
+ * The comparator for this sort.
+ */
+    private final Comparator<? super K> c;
+
+ /**
+ * When we get into galloping mode, we stay there until both runs win less
+ * often than MIN_GALLOP consecutive times.
+ */
+ private static final int MIN_GALLOP = 7;
+
+ /**
+ * This controls when we get *into* galloping mode. It is initialized
+ * to MIN_GALLOP. The mergeLo and mergeHi methods nudge it higher for
+ * random data, and lower for highly structured data.
+ */
+ private int minGallop = MIN_GALLOP;
+
+ /**
+ * Maximum initial size of tmp array, which is used for merging. The array
+ * can grow to accommodate demand.
+ *
+ * Unlike Tim's original C version, we do not allocate this much storage
+ * when sorting smaller arrays. This change was required for performance.
+ */
+ private static final int INITIAL_TMP_STORAGE_LENGTH = 256;
+
+ /**
+ * Temp storage for merges.
+ */
+ private Buffer tmp; // Actual runtime type will be Object[], regardless of T
+
+ /**
+ * Length of the temp storage.
+ */
+ private int tmpLength = 0;
+
+ /**
+ * A stack of pending runs yet to be merged. Run i starts at
+ * address base[i] and extends for len[i] elements. It's always
+ * true (so long as the indices are in bounds) that:
+ *
+ * runBase[i] + runLen[i] == runBase[i + 1]
+ *
+ * so we could cut the storage for this, but it's a minor amount,
+ * and keeping all the info explicit simplifies the code.
+ */
+ private int stackSize = 0; // Number of pending runs on stack
+ private final int[] runBase;
+ private final int[] runLen;
+
+ /**
+ * Creates a TimSort instance to maintain the state of an ongoing sort.
+ *
+ * @param a the array to be sorted
+ * @param c the comparator to determine the order of the sort
+ */
+    private SortState(Buffer a, Comparator<? super K> c, int len) {
+ this.aLength = len;
+ this.a = a;
+ this.c = c;
+
+ // Allocate temp storage (which may be increased later if necessary)
+ tmpLength = len < 2 * INITIAL_TMP_STORAGE_LENGTH ? len >>> 1 : INITIAL_TMP_STORAGE_LENGTH;
+ tmp = s.allocate(tmpLength);
+
+ /*
+ * Allocate runs-to-be-merged stack (which cannot be expanded). The
+ * stack length requirements are described in listsort.txt. The C
+ * version always uses the same stack length (85), but this was
+ * measured to be too expensive when sorting "mid-sized" arrays (e.g.,
+ * 100 elements) in Java. Therefore, we use smaller (but sufficiently
+ * large) stack lengths for smaller arrays. The "magic numbers" in the
+ * computation below must be changed if MIN_MERGE is decreased. See
+ * the MIN_MERGE declaration above for more information.
+ */
+ int stackLen = (len < 120 ? 5 :
+ len < 1542 ? 10 :
+ len < 119151 ? 19 : 40);
+ runBase = new int[stackLen];
+ runLen = new int[stackLen];
+ }
+
+ /**
+ * Pushes the specified run onto the pending-run stack.
+ *
+ * @param runBase index of the first element in the run
+ * @param runLen the number of elements in the run
+ */
+ private void pushRun(int runBase, int runLen) {
+ this.runBase[stackSize] = runBase;
+ this.runLen[stackSize] = runLen;
+ stackSize++;
+ }
+
+ /**
+ * Examines the stack of runs waiting to be merged and merges adjacent runs
+ * until the stack invariants are reestablished:
+ *
+ * 1. runLen[i - 3] > runLen[i - 2] + runLen[i - 1]
+ * 2. runLen[i - 2] > runLen[i - 1]
+ *
+ * This method is called each time a new run is pushed onto the stack,
+ * so the invariants are guaranteed to hold for i < stackSize upon
+ * entry to the method.
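+     * Together these invariants keep the run lengths on the stack growing at least as fast as
+     * consecutive Fibonacci numbers, which bounds the pending-run stack depth to O(log n).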
+ */
+ private void mergeCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
+ if (runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ } else if (runLen[n] <= runLen[n + 1]) {
+ mergeAt(n);
+ } else {
+ break; // Invariant is established
+ }
+ }
+ }
+
+ /**
+ * Merges all runs on the stack until only one remains. This method is
+ * called once, to complete the sort.
+ */
+ private void mergeForceCollapse() {
+ while (stackSize > 1) {
+ int n = stackSize - 2;
+ if (n > 0 && runLen[n - 1] < runLen[n + 1])
+ n--;
+ mergeAt(n);
+ }
+ }
+
+ /**
+ * Merges the two runs at stack indices i and i+1. Run i must be
+ * the penultimate or antepenultimate run on the stack. In other words,
+ * i must be equal to stackSize-2 or stackSize-3.
+ *
+ * @param i stack index of the first of the two runs to merge
+ */
+ private void mergeAt(int i) {
+ assert stackSize >= 2;
+ assert i >= 0;
+ assert i == stackSize - 2 || i == stackSize - 3;
+
+ int base1 = runBase[i];
+ int len1 = runLen[i];
+ int base2 = runBase[i + 1];
+ int len2 = runLen[i + 1];
+ assert len1 > 0 && len2 > 0;
+ assert base1 + len1 == base2;
+
+ /*
+ * Record the length of the combined runs; if i is the 3rd-last
+ * run now, also slide over the last run (which isn't involved
+ * in this merge). The current run (i+1) goes away in any case.
+ */
+ runLen[i] = len1 + len2;
+ if (i == stackSize - 3) {
+ runBase[i + 1] = runBase[i + 2];
+ runLen[i + 1] = runLen[i + 2];
+ }
+ stackSize--;
+
+ K key0 = s.newKey();
+
+ /*
+ * Find where the first element of run2 goes in run1. Prior elements
+ * in run1 can be ignored (because they're already in place).
+ */
+ int k = gallopRight(s.getKey(a, base2, key0), a, base1, len1, 0, c);
+ assert k >= 0;
+ base1 += k;
+ len1 -= k;
+ if (len1 == 0)
+ return;
+
+ /*
+ * Find where the last element of run1 goes in run2. Subsequent elements
+ * in run2 can be ignored (because they're already in place).
+ */
+ len2 = gallopLeft(s.getKey(a, base1 + len1 - 1, key0), a, base2, len2, len2 - 1, c);
+ assert len2 >= 0;
+ if (len2 == 0)
+ return;
+
+ // Merge remaining runs, using tmp array with min(len1, len2) elements
+ if (len1 <= len2)
+ mergeLo(base1, len1, base2, len2);
+ else
+ mergeHi(base1, len1, base2, len2);
+ }
+
+ /**
+ * Locates the position at which to insert the specified key into the
+ * specified sorted range; if the range contains an element equal to key,
+ * returns the index of the leftmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] < key <= a[b + k],
+ * pretending that a[b - 1] is minus infinity and a[b + n] is infinity.
+ * In other words, key belongs at index b + k; or in other words,
+ * the first k elements of a should precede key, and the last n - k
+ * should follow it.
+ */
+    private int gallopLeft(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+ int lastOfs = 0;
+ int ofs = 1;
+ K key0 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key0)) > 0) {
+ // Gallop right until a[base+hint+lastOfs] < key <= a[base+hint+ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key0)) > 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ lastOfs += hint;
+ ofs += hint;
+ } else { // key <= a[base + hint]
+ // Gallop left until a[base+hint-ofs] < key <= a[base+hint-lastOfs]
+ final int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key0)) <= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to base
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[base+lastOfs] < key <= a[base+ofs], so key belongs somewhere
+ * to the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[base + lastOfs - 1] < key <= a[base + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key0)) > 0)
+ lastOfs = m + 1; // a[base + m] < key
+ else
+ ofs = m; // key <= a[base + m]
+ }
+ assert lastOfs == ofs; // so a[base + ofs - 1] < key <= a[base + ofs]
+ return ofs;
+ }
+
+ /**
+ * Like gallopLeft, except that if the range contains an element equal to
+ * key, gallopRight returns the index after the rightmost equal element.
+ *
+ * @param key the key whose insertion point to search for
+ * @param a the array in which to search
+ * @param base the index of the first element in the range
+ * @param len the length of the range; must be > 0
+ * @param hint the index at which to begin the search, 0 <= hint < n.
+ * The closer hint is to the result, the faster this method will run.
+ * @param c the comparator used to order the range, and to search
+ * @return the int k, 0 <= k <= n such that a[b + k - 1] <= key < a[b + k]
+ */
+    private int gallopRight(K key, Buffer a, int base, int len, int hint, Comparator<? super K> c) {
+ assert len > 0 && hint >= 0 && hint < len;
+
+ int ofs = 1;
+ int lastOfs = 0;
+ K key1 = s.newKey();
+
+ if (c.compare(key, s.getKey(a, base + hint, key1)) < 0) {
+ // Gallop left until a[b+hint - ofs] <= key < a[b+hint - lastOfs]
+ int maxOfs = hint + 1;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint - ofs, key1)) < 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ int tmp = lastOfs;
+ lastOfs = hint - ofs;
+ ofs = hint - tmp;
+ } else { // a[b + hint] <= key
+ // Gallop right until a[b+hint + lastOfs] <= key < a[b+hint + ofs]
+ int maxOfs = len - hint;
+ while (ofs < maxOfs && c.compare(key, s.getKey(a, base + hint + ofs, key1)) >= 0) {
+ lastOfs = ofs;
+ ofs = (ofs << 1) + 1;
+ if (ofs <= 0) // int overflow
+ ofs = maxOfs;
+ }
+ if (ofs > maxOfs)
+ ofs = maxOfs;
+
+ // Make offsets relative to b
+ lastOfs += hint;
+ ofs += hint;
+ }
+ assert -1 <= lastOfs && lastOfs < ofs && ofs <= len;
+
+ /*
+ * Now a[b + lastOfs] <= key < a[b + ofs], so key belongs somewhere to
+ * the right of lastOfs but no farther right than ofs. Do a binary
+ * search, with invariant a[b + lastOfs - 1] <= key < a[b + ofs].
+ */
+ lastOfs++;
+ while (lastOfs < ofs) {
+ int m = lastOfs + ((ofs - lastOfs) >>> 1);
+
+ if (c.compare(key, s.getKey(a, base + m, key1)) < 0)
+ ofs = m; // key < a[b + m]
+ else
+ lastOfs = m + 1; // a[b + m] <= key
+ }
+ assert lastOfs == ofs; // so a[b + ofs - 1] <= key < a[b + ofs]
+ return ofs;
+ }
+
+ /**
+ * Merges two adjacent runs in place, in a stable fashion. The first
+ * element of the first run must be greater than the first element of the
+ * second run (a[base1] > a[base2]), and the last element of the first run
+ * (a[base1 + len1-1]) must be greater than all elements of the second run.
+ *
+ * For performance, this method should be called only when len1 <= len2;
+ * its twin, mergeHi should be called if len1 >= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeLo(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy first run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len1);
+ s.copyRange(a, base1, tmp, 0, len1);
+
+ int cursor1 = 0; // Indexes into tmp array
+      int cursor2 = base2;  // Indexes into a
+      int dest = base1;     // Indexes into a
+
+ // Move first element of second run and deal with degenerate cases
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0) {
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ return;
+ }
+ if (len1 == 1) {
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ return;
+ }
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+      Comparator<? super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run starts
+ * winning consistently.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ if (c.compare(s.getKey(a, cursor2, key0), s.getKey(tmp, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor2++, a, dest++);
+ count2++;
+ count1 = 0;
+ if (--len2 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor1++, a, dest++);
+ count1++;
+ count2 = 0;
+ if (--len1 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 1 && len2 > 0;
+ count1 = gallopRight(s.getKey(a, cursor2, key0), tmp, cursor1, len1, 0, c);
+ if (count1 != 0) {
+ s.copyRange(tmp, cursor1, a, dest, count1);
+ dest += count1;
+ cursor1 += count1;
+ len1 -= count1;
+ if (len1 <= 1) // len1 == 1 || len1 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor2++, a, dest++);
+ if (--len2 == 0)
+ break outer;
+
+ count2 = gallopLeft(s.getKey(tmp, cursor1, key0), a, cursor2, len2, 0, c);
+ if (count2 != 0) {
+ s.copyRange(a, cursor2, a, dest, count2);
+ dest += count2;
+ cursor2 += count2;
+ len2 -= count2;
+ if (len2 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor1++, a, dest++);
+ if (--len1 == 1)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len1 == 1) {
+ assert len2 > 0;
+ s.copyRange(a, cursor2, a, dest, len2);
+ s.copyElement(tmp, cursor1, a, dest + len2); // Last elt of run 1 to end of merge
+ } else if (len1 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len2 == 0;
+ assert len1 > 1;
+ s.copyRange(tmp, cursor1, a, dest, len1);
+ }
+ }
+
+ /**
+ * Like mergeLo, except that this method should be called only if
+ * len1 >= len2; mergeLo should be called if len1 <= len2. (Either method
+ * may be called if len1 == len2.)
+ *
+ * @param base1 index of first element in first run to be merged
+ * @param len1 length of first run to be merged (must be > 0)
+ * @param base2 index of first element in second run to be merged
+ * (must be aBase + aLen)
+ * @param len2 length of second run to be merged (must be > 0)
+ */
+ private void mergeHi(int base1, int len1, int base2, int len2) {
+ assert len1 > 0 && len2 > 0 && base1 + len1 == base2;
+
+ // Copy second run into temp array
+ Buffer a = this.a; // For performance
+ Buffer tmp = ensureCapacity(len2);
+ s.copyRange(a, base2, tmp, 0, len2);
+
+ int cursor1 = base1 + len1 - 1; // Indexes into a
+ int cursor2 = len2 - 1; // Indexes into tmp array
+ int dest = base2 + len2 - 1; // Indexes into a
+
+ K key0 = s.newKey();
+ K key1 = s.newKey();
+
+ // Move last element of first run and deal with degenerate cases
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0) {
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ return;
+ }
+ if (len2 == 1) {
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest);
+ return;
+ }
+
+      Comparator<? super K> c = this.c; // Use local variable for performance
+ int minGallop = this.minGallop; // " " " " "
+ outer:
+ while (true) {
+ int count1 = 0; // Number of times in a row that first run won
+ int count2 = 0; // Number of times in a row that second run won
+
+ /*
+ * Do the straightforward thing until (if ever) one run
+ * appears to win consistently.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ if (c.compare(s.getKey(tmp, cursor2, key0), s.getKey(a, cursor1, key1)) < 0) {
+ s.copyElement(a, cursor1--, a, dest--);
+ count1++;
+ count2 = 0;
+ if (--len1 == 0)
+ break outer;
+ } else {
+ s.copyElement(tmp, cursor2--, a, dest--);
+ count2++;
+ count1 = 0;
+ if (--len2 == 1)
+ break outer;
+ }
+ } while ((count1 | count2) < minGallop);
+
+ /*
+ * One run is winning so consistently that galloping may be a
+ * huge win. So try that, and continue galloping until (if ever)
+ * neither run appears to be winning consistently anymore.
+ */
+ do {
+ assert len1 > 0 && len2 > 1;
+ count1 = len1 - gallopRight(s.getKey(tmp, cursor2, key0), a, base1, len1, len1 - 1, c);
+ if (count1 != 0) {
+ dest -= count1;
+ cursor1 -= count1;
+ len1 -= count1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, count1);
+ if (len1 == 0)
+ break outer;
+ }
+ s.copyElement(tmp, cursor2--, a, dest--);
+ if (--len2 == 1)
+ break outer;
+
+ count2 = len2 - gallopLeft(s.getKey(a, cursor1, key0), tmp, 0, len2, len2 - 1, c);
+ if (count2 != 0) {
+ dest -= count2;
+ cursor2 -= count2;
+ len2 -= count2;
+ s.copyRange(tmp, cursor2 + 1, a, dest + 1, count2);
+ if (len2 <= 1) // len2 == 1 || len2 == 0
+ break outer;
+ }
+ s.copyElement(a, cursor1--, a, dest--);
+ if (--len1 == 0)
+ break outer;
+ minGallop--;
+ } while (count1 >= MIN_GALLOP | count2 >= MIN_GALLOP);
+ if (minGallop < 0)
+ minGallop = 0;
+ minGallop += 2; // Penalize for leaving gallop mode
+ } // End of "outer" loop
+ this.minGallop = minGallop < 1 ? 1 : minGallop; // Write back to field
+
+ if (len2 == 1) {
+ assert len1 > 0;
+ dest -= len1;
+ cursor1 -= len1;
+ s.copyRange(a, cursor1 + 1, a, dest + 1, len1);
+ s.copyElement(tmp, cursor2, a, dest); // Move first elt of run2 to front of merge
+ } else if (len2 == 0) {
+ throw new IllegalArgumentException(
+ "Comparison method violates its general contract!");
+ } else {
+ assert len1 == 0;
+ assert len2 > 0;
+ s.copyRange(tmp, 0, a, dest - (len2 - 1), len2);
+ }
+ }
+
+ /**
+ * Ensures that the external array tmp has at least the specified
+ * number of elements, increasing its size if necessary. The size
+ * increases exponentially to ensure amortized linear time complexity.
+ *
+ * @param minCapacity the minimum required capacity of the tmp array
+ * @return tmp, whether or not it grew
+ */
+ private Buffer ensureCapacity(int minCapacity) {
+ if (tmpLength < minCapacity) {
+ // Compute smallest power of 2 > minCapacity
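+        // For example, minCapacity = 100 smears to 0b1111111 (127) and is incremented to 128.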
+ int newSize = minCapacity;
+ newSize |= newSize >> 1;
+ newSize |= newSize >> 2;
+ newSize |= newSize >> 4;
+ newSize |= newSize >> 8;
+ newSize |= newSize >> 16;
+ newSize++;
+
+ if (newSize < 0) // Not bloody likely!
+ newSize = minCapacity;
+ else
+ newSize = Math.min(newSize, aLength >>> 1);
+
+ tmp = s.allocate(newSize);
+ tmpLength = newSize;
+ }
+ return tmp;
+ }
+ }
+}
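
To make the generalization described in the TimSort class comment concrete, here is a small Scala sketch of a sort-data format for a packed key-value array, in the spirit of the AppendOnlyMap use case. The trait below is not Spark's actual SortDataFormat; it merely mirrors the operations the port invokes on `s` (newKey, getKey, swap, copyElement, copyRange, allocate), and the class and method names here are made up for illustration.

```scala
// Illustrative stand-in for the sort-data format abstraction used by the TimSort port above.
trait KVSortFormat[K, Buffer] {
  def newKey(): K
  def getKey(data: Buffer, pos: Int, reuse: K): K
  def swap(data: Buffer, pos0: Int, pos1: Int): Unit
  def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
  def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
  def allocate(length: Int): Buffer
}

// An Array[Long] holding alternating [key, value, key, value] slots. One logical element
// spans two array slots, so every logical index is scaled by 2.
class PackedKVFormat extends KVSortFormat[Long, Array[Long]] {
  override def newKey(): Long = 0L  // keys are primitive longs, so there is nothing to reuse
  override def getKey(data: Array[Long], pos: Int, reuse: Long): Long = data(2 * pos)
  override def swap(data: Array[Long], pos0: Int, pos1: Int): Unit = {
    var i = 0
    while (i < 2) {  // swap both the key slot and the value slot
      val tmp = data(2 * pos0 + i)
      data(2 * pos0 + i) = data(2 * pos1 + i)
      data(2 * pos1 + i) = tmp
      i += 1
    }
  }
  override def copyElement(src: Array[Long], srcPos: Int, dst: Array[Long], dstPos: Int): Unit =
    System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2)
  override def copyRange(src: Array[Long], srcPos: Int, dst: Array[Long],
      dstPos: Int, length: Int): Unit =
    System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
  override def allocate(length: Int): Array[Long] = new Array[Long](2 * length)
}
```

A java.util.Comparator over the keys would then drive the sort, either through the package-private TimSort directly or through the Sorter wrapper mentioned in the class comment.
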
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
new file mode 100644
index 0000000000000..c5936b5038ac9
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Register functions to show/hide columns based on checkboxes. These need
+ * to be registered after the page loads. */
+$(function() {
+ $("span.expand-additional-metrics").click(function(){
+ // Expand the list of additional metrics.
+ var additionalMetricsDiv = $(this).parent().find('.additional-metrics');
+ $(additionalMetricsDiv).toggleClass('collapsed');
+
+ // Switch the class of the arrow from open to closed.
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
+ $(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
+
+ // If clicking caused the metrics to expand, automatically check all options for additional
+ // metrics (don't trigger a click when collapsing metrics, because it leads to weird
+ // toggling behavior).
+ if (!$(additionalMetricsDiv).hasClass('collapsed')) {
+ $(this).parent().find('input:checkbox:not(:checked)').trigger('click');
+ }
+ });
+
+ $("input:checkbox:not(:checked)").each(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).hide();
+ });
+
+ $("input:checkbox").click(function() {
+ var column = "table ." + $(this).attr("name");
+ $(column).toggle();
+ stripeTables();
+ });
+
+ // Trigger a click on the checkbox if a user clicks the label next to it.
+ $("span.additional-metric-title").click(function() {
+ $(this).parent().find('input:checkbox').trigger('click');
+ });
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
new file mode 100644
index 0000000000000..32187ba6e8df0
--- /dev/null
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* Adds background colors to stripe table rows. This is necessary (instead of using css or the
+ * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */
+function stripeTables() {
+ $("table.table-striped-custom").each(function() {
+ $(this).find("tr:not(:hidden)").each(function (index) {
+ if (index % 2 == 1) {
+ $(this).css("background-color", "#f9f9f9");
+ } else {
+ $(this).css("background-color", "#ffffff");
+ }
+ });
+ });
+}
+
+/* Stripe all tables after pages finish loading. */
+$(function() {
+ stripeTables();
+});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 152bde5f6994f..a2220e761ac98 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -120,7 +120,37 @@ pre {
border: none;
}
+span.expand-additional-metrics {
+ cursor: pointer;
+}
+
+span.additional-metric-title {
+ cursor: pointer;
+}
+
+.additional-metrics.collapsed {
+ display: none;
+}
+
.tooltip {
font-weight: normal;
}
+.arrow-open {
+ width: 0;
+ height: 0;
+ border-left: 5px solid transparent;
+ border-right: 5px solid transparent;
+ border-top: 5px solid black;
+ float: left;
+ margin-top: 6px;
+}
+
+.arrow-closed {
+ width: 0;
+ height: 0;
+ border-top: 5px solid transparent;
+ border-bottom: 5px solid transparent;
+ border-left: 5px solid black;
+ display: inline-block;
+}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 12f2fe031cb1d..2301caafb07ff 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -24,6 +24,7 @@ import scala.collection.mutable.Map
import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.util.Utils
/**
* A data type that can be accumulated, ie has an commutative and associative "add" operation,
@@ -126,7 +127,7 @@ class Accumulable[R, T] (
}
// Called by Java when deserializing an object
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
value_ = zero
deserialized = true
diff --git a/core/src/main/scala/org/apache/spark/CacheManager.scala b/core/src/main/scala/org/apache/spark/CacheManager.scala
index d89bb50076c9a..80da62c44edc5 100644
--- a/core/src/main/scala/org/apache/spark/CacheManager.scala
+++ b/core/src/main/scala/org/apache/spark/CacheManager.scala
@@ -61,7 +61,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
val computedValues = rdd.computeOrReadCheckpoint(partition, context)
// If the task is running locally, do not persist the result
- if (context.runningLocally) {
+ if (context.isRunningLocally) {
return computedValues
}
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
new file mode 100644
index 0000000000000..c11f1db0064fd
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.mutable
+
+import org.apache.spark.scheduler._
+
+/**
+ * An agent that dynamically allocates and removes executors based on the workload.
+ *
+ * The add policy depends on whether there are backlogged tasks waiting to be scheduled. If
+ * the scheduler queue is not drained in M seconds, then new executors are added. If the queue
+ * persists for another N seconds, then more executors are added and so on. The number added
+ * in each round increases exponentially from the previous round until an upper bound on the
+ * number of executors has been reached.
+ *
+ * The rationale for the exponential increase is twofold: (1) Executors should be added slowly
+ * in the beginning in case the number of extra executors needed turns out to be small. Otherwise,
+ * we may add more executors than we need just to remove them later. (2) Executors should be added
+ * quickly over time in case the maximum number of executors is very high. Otherwise, it will take
+ * a long time to ramp up under heavy workloads.
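+ * Concretely, the number of executors requested per round grows as 1, 2, 4, 8, ..., doubling
+ * after each round in which the full batch could be requested without hitting the upper bound.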
+ *
+ * The remove policy is simpler: If an executor has been idle for K seconds, meaning it has not
+ * been scheduled to run any tasks, then it is removed.
+ *
+ * There is no retry logic in either case because we make the assumption that the cluster manager
+ * will eventually fulfill all requests it receives asynchronously.
+ *
+ * The relevant Spark properties include the following:
+ *
+ * spark.dynamicAllocation.enabled - Whether this feature is enabled
+ * spark.dynamicAllocation.minExecutors - Lower bound on the number of executors
+ * spark.dynamicAllocation.maxExecutors - Upper bound on the number of executors
+ *
+ * spark.dynamicAllocation.schedulerBacklogTimeout (M) -
+ * If there are backlogged tasks for this duration, add new executors
+ *
+ * spark.dynamicAllocation.sustainedSchedulerBacklogTimeout (N) -
+ * If the backlog is sustained for this duration, add more executors
+ * This is used only after the initial backlog timeout is exceeded
+ *
+ * spark.dynamicAllocation.executorIdleTimeout (K) -
+ * If an executor has been idle for this duration, remove it
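+ *
+ *   A minimal, illustrative configuration (the values below are examples, not defaults):
+ *
+ *     val conf = new SparkConf()
+ *       .set("spark.dynamicAllocation.enabled", "true")
+ *       .set("spark.dynamicAllocation.minExecutors", "2")
+ *       .set("spark.dynamicAllocation.maxExecutors", "20")
+ *       .set("spark.dynamicAllocation.schedulerBacklogTimeout", "60")
+ *       .set("spark.dynamicAllocation.executorIdleTimeout", "600")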
+ */
+private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging {
+ import ExecutorAllocationManager._
+
+ private val conf = sc.conf
+
+ // Lower and upper bounds on the number of executors. These are required.
+ private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
+ private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
+ verifyBounds()
+
+ // How long there must be backlogged tasks for before an addition is triggered
+ private val schedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.schedulerBacklogTimeout", 60)
+
+ // Same as above, but used only after `schedulerBacklogTimeout` is exceeded
+ private val sustainedSchedulerBacklogTimeout = conf.getLong(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
+
+ // How long an executor must be idle for before it is removed
+ private val removeThresholdSeconds = conf.getLong(
+ "spark.dynamicAllocation.executorIdleTimeout", 600)
+
+ // Number of executors to add in the next round
+ private var numExecutorsToAdd = 1
+
+ // Number of executors that have been requested but have not registered yet
+ private var numExecutorsPending = 0
+
+ // Executors that have been requested to be removed but have not been killed yet
+ private val executorsPendingToRemove = new mutable.HashSet[String]
+
+ // All known executors
+ private val executorIds = new mutable.HashSet[String]
+
+ // A timestamp of when an addition should be triggered, or NOT_SET if it is not set
+ // This is set when pending tasks are added but not scheduled yet
+ private var addTime: Long = NOT_SET
+
+ // A timestamp for each executor of when the executor should be removed, indexed by the ID
+ // This is set when an executor is no longer running a task, or when it first registers
+ private val removeTimes = new mutable.HashMap[String, Long]
+
+ // Polling loop interval (ms)
+ private val intervalMillis: Long = 100
+
+ // Whether we are testing this class. This should only be used internally.
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ // Clock used to schedule when executors should be added and removed
+ private var clock: Clock = new RealClock
+
+ /**
+ * Verify that the lower and upper bounds on the number of executors are valid.
+ * If not, throw an appropriate exception.
+ */
+ private def verifyBounds(): Unit = {
+ if (minNumExecutors < 0 || maxNumExecutors < 0) {
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
+ }
+ if (minNumExecutors == 0 || maxNumExecutors == 0) {
+ throw new SparkException("spark.dynamicAllocation.{min/max}Executors cannot be 0!")
+ }
+ if (minNumExecutors > maxNumExecutors) {
+ throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
+ s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
+ }
+ }
+
+ /**
+ * Use a different clock for this allocation manager. This is mainly used for testing.
+ */
+ def setClock(newClock: Clock): Unit = {
+ clock = newClock
+ }
+
+ /**
+ * Register for scheduler callbacks to decide when to add and remove executors.
+ */
+ def start(): Unit = {
+ val listener = new ExecutorAllocationListener(this)
+ sc.addSparkListener(listener)
+ startPolling()
+ }
+
+ /**
+ * Start the main polling thread that keeps track of when to add and remove executors.
+ */
+ private def startPolling(): Unit = {
+ val t = new Thread {
+ override def run(): Unit = {
+ while (true) {
+ try {
+ schedule()
+ } catch {
+ case e: Exception => logError("Exception in dynamic executor allocation thread!", e)
+ }
+ Thread.sleep(intervalMillis)
+ }
+ }
+ }
+ t.setName("spark-dynamic-executor-allocation")
+ t.setDaemon(true)
+ t.start()
+ }
+
+ /**
+ * If the add time has expired, request new executors and refresh the add time.
+ * If the remove time for an existing executor has expired, kill the executor.
+ * This is factored out into its own method for testing.
+ */
+ private def schedule(): Unit = synchronized {
+ val now = clock.getTimeMillis
+ if (addTime != NOT_SET && now >= addTime) {
+ addExecutors()
+ logDebug(s"Starting timer to add more executors (to " +
+ s"expire in $sustainedSchedulerBacklogTimeout seconds)")
+ addTime += sustainedSchedulerBacklogTimeout * 1000
+ }
+
+ removeTimes.foreach { case (executorId, expireTime) =>
+ if (now >= expireTime) {
+ removeExecutor(executorId)
+ removeTimes.remove(executorId)
+ }
+ }
+ }
+
+ /**
+ * Request a number of executors from the cluster manager.
+ * If the cap on the number of executors is reached, give up and reset the
+ * number of executors to add next round instead of continuing to double it.
+ * Return the number actually requested.
+ */
+ private def addExecutors(): Int = synchronized {
+ // Do not request more executors if we have already reached the upper bound
+ val numExistingExecutors = executorIds.size + numExecutorsPending
+ if (numExistingExecutors >= maxNumExecutors) {
+ logDebug(s"Not adding executors because there are already ${executorIds.size} " +
+ s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // Request executors with respect to the upper bound
+ val actualNumExecutorsToAdd =
+ if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) {
+ numExecutorsToAdd
+ } else {
+ maxNumExecutors - numExistingExecutors
+ }
+ val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
+ val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd)
+ if (addRequestAcknowledged) {
+ logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
+ s"tasks are backlogged (new desired total will be $newTotalExecutors)")
+ numExecutorsToAdd =
+ if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
+ numExecutorsPending += actualNumExecutorsToAdd
+ actualNumExecutorsToAdd
+ } else {
+ logWarning(s"Unable to reach the cluster manager " +
+ s"to request $actualNumExecutorsToAdd executors!")
+ 0
+ }
+ }
+
+ /**
+ * Request the cluster manager to remove the given executor.
+ * Return whether the request is received.
+ */
+ private def removeExecutor(executorId: String): Boolean = synchronized {
+ // Do not kill the executor if we are not aware of it (should never happen)
+ if (!executorIds.contains(executorId)) {
+ logWarning(s"Attempted to remove unknown executor $executorId!")
+ return false
+ }
+
+ // Do not kill the executor again if it is already pending to be killed (should never happen)
+ if (executorsPendingToRemove.contains(executorId)) {
+ logWarning(s"Attempted to remove executor $executorId " +
+ s"when it is already pending to be removed!")
+ return false
+ }
+
+ // Do not kill the executor if we have already reached the lower bound
+ val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
+ if (numExistingExecutors - 1 < minNumExecutors) {
+ logInfo(s"Not removing idle executor $executorId because there are only " +
+ s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
+ return false
+ }
+
+ // Send a request to the backend to kill this executor
+ val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
+ if (removeRequestAcknowledged) {
+ logInfo(s"Removing executor $executorId because it has been idle for " +
+ s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})")
+ executorsPendingToRemove.add(executorId)
+ true
+ } else {
+ logWarning(s"Unable to reach the cluster manager to kill executor $executorId!")
+ false
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been added.
+ */
+ private def onExecutorAdded(executorId: String): Unit = synchronized {
+ if (!executorIds.contains(executorId)) {
+ executorIds.add(executorId)
+ executorIds.foreach(onExecutorIdle)
+ logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
+ if (numExecutorsPending > 0) {
+ numExecutorsPending -= 1
+ logDebug(s"Decremented number of pending executors ($numExecutorsPending left)")
+ }
+ } else {
+ logWarning(s"Duplicate executor $executorId has registered")
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor has been removed.
+ */
+ private def onExecutorRemoved(executorId: String): Unit = synchronized {
+ if (executorIds.contains(executorId)) {
+ executorIds.remove(executorId)
+ removeTimes.remove(executorId)
+ logInfo(s"Existing executor $executorId has been removed (new total is ${executorIds.size})")
+ if (executorsPendingToRemove.contains(executorId)) {
+ executorsPendingToRemove.remove(executorId)
+ logDebug(s"Executor $executorId is no longer pending to " +
+ s"be removed (${executorsPendingToRemove.size} left)")
+ }
+ } else {
+ logWarning(s"Unknown executor $executorId has been removed!")
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler receives new pending tasks.
+ * This sets a time in the future that decides when executors should be added
+ * if it is not already set.
+ */
+ private def onSchedulerBacklogged(): Unit = synchronized {
+ if (addTime == NOT_SET) {
+ logDebug(s"Starting timer to add executors because pending tasks " +
+ s"are building up (to expire in $schedulerBacklogTimeout seconds)")
+ addTime = clock.getTimeMillis + schedulerBacklogTimeout * 1000
+ }
+ }
+
+ /**
+ * Callback invoked when the scheduler queue is drained.
+ * This resets all variables used for adding executors.
+ */
+ private def onSchedulerQueueEmpty(): Unit = synchronized {
+ logDebug(s"Clearing timer to add executors because there are no more pending tasks")
+ addTime = NOT_SET
+ numExecutorsToAdd = 1
+ }
+
+ /**
+ * Callback invoked when the specified executor is no longer running any tasks.
+ * This sets a time in the future that decides when this executor should be removed if
+ * the executor is not already marked as idle.
+ */
+ private def onExecutorIdle(executorId: String): Unit = synchronized {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000
+ }
+ }
+
+ /**
+ * Callback invoked when the specified executor is now running a task.
+ * This resets all variables used for removing this executor.
+ */
+ private def onExecutorBusy(executorId: String): Unit = synchronized {
+ logDebug(s"Clearing idle timer for $executorId because it is now running a task")
+ removeTimes.remove(executorId)
+ }
+
+ /**
+ * A listener that notifies the given allocation manager of when to add and remove executors.
+ *
+ * This class is intentionally conservative in its assumptions about the relative ordering
+ * and consistency of events returned by the listener. For simplicity, it does not account
+ * for speculated tasks.
+ */
+ private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager)
+ extends SparkListener {
+
+ private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
+ private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
+ private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
+
+ override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
+ synchronized {
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ stageIdToNumTasks(stageId) = numTasks
+ allocationManager.onSchedulerBacklogged()
+ }
+ }
+
+ override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
+ synchronized {
+ val stageId = stageCompleted.stageInfo.stageId
+ stageIdToNumTasks -= stageId
+ stageIdToTaskIndices -= stageId
+
+ // If this is the last stage with pending tasks, mark the scheduler queue as empty
+ // This is needed in case the stage is aborted for any reason
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+ }
+
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ val stageId = taskStart.stageId
+ val taskId = taskStart.taskInfo.taskId
+ val taskIndex = taskStart.taskInfo.index
+ val executorId = taskStart.taskInfo.executorId
+
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
+ }
+
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
+
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ val executorId = taskEnd.taskInfo.executorId
+ val taskId = taskEnd.taskInfo.taskId
+
+      // If the executor is no longer running any scheduled tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
+ }
+ }
+
+ override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
+ val executorId = blockManagerAdded.blockManagerId.executorId
+ if (executorId != SparkContext.DRIVER_IDENTIFIER) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+ }
+
+ override def onBlockManagerRemoved(
+ blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
+ allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
+ }
+ }
+
+}
+
+private object ExecutorAllocationManager {
+ val NOT_SET = Long.MaxValue
+}
+
+/**
+ * An abstract clock for measuring elapsed time.
+ */
+private trait Clock {
+ def getTimeMillis: Long
+}
+
+/**
+ * A clock backed by a monotonically increasing time source.
+ * The time returned by this clock does not correspond to any notion of wall-clock time.
+ */
+private class RealClock extends Clock {
+ override def getTimeMillis: Long = System.nanoTime / (1000 * 1000)
+}
+
+/**
+ * A clock that allows the caller to customize the time.
+ * This is used mainly for testing.
+ */
+private class TestClock(startTimeMillis: Long) extends Clock {
+ private var time: Long = startTimeMillis
+ override def getTimeMillis: Long = time
+ def tick(ms: Long): Unit = { time += ms }
+}
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index d5c8f9d76c476..e97a7375a267b 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -210,7 +210,11 @@ class ComplexFutureAction[T] extends FutureAction[T] {
} catch {
case e: Exception => p.failure(e)
} finally {
- thread = null
+      // This lock guarantees that `thread` is not set to null while `cancel` is
+      // calling `thread.interrupt()`.
+ ComplexFutureAction.this.synchronized {
+ thread = null
+ }
}
}
this
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4cb0bd4142435..7d96962c4acd7 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -178,6 +178,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses)
}
} else {
+ logError("Missing all output locations for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing all output locations for shuffle " + shuffleId)
}
@@ -348,7 +349,7 @@ private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTr
new ConcurrentHashMap[Int, Array[MapStatus]]
}
-private[spark] object MapOutputTracker {
+private[spark] object MapOutputTracker extends Logging {
// Serialize an array of map output locations into an efficient byte format so that we can send
// it to reduce tasks. We do this by compressing the serialized bytes using GZIP. They will
@@ -381,6 +382,7 @@ private[spark] object MapOutputTracker {
statuses.map {
status =>
if (status == null) {
+ logError("Missing an output location for shuffle " + shuffleId)
throw new MetadataFetchFailedException(
shuffleId, reduceId, "Missing an output location for shuffle " + shuffleId)
} else {
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 37053bb6f37ad..e53a78ead2c0e 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -204,7 +204,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => out.defaultWriteObject()
@@ -222,7 +222,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
case js: JavaSerializer => in.defaultReadObject()
diff --git a/core/src/main/scala/org/apache/spark/SerializableWritable.scala b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
index e50b9ac2291f9..55cb25946c2ad 100644
--- a/core/src/main/scala/org/apache/spark/SerializableWritable.scala
+++ b/core/src/main/scala/org/apache/spark/SerializableWritable.scala
@@ -24,18 +24,19 @@ import org.apache.hadoop.io.ObjectWritable
import org.apache.hadoop.io.Writable
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
@DeveloperApi
class SerializableWritable[T <: Writable](@transient var t: T) extends Serializable {
def value = t
override def toString = t.toString
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.defaultWriteObject()
new ObjectWritable(t).write(out)
}
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
val ow = new ObjectWritable()
ow.setConf(new Configuration())
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 605df0e929faa..ad0a9017afead 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,7 +18,8 @@
package org.apache.spark
import scala.collection.JavaConverters._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, LinkedHashSet}
+import org.apache.spark.serializer.KryoSerializer
/**
* Configuration for a Spark application. Used to set various Spark parameters as key-value pairs.
@@ -140,6 +141,20 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
this
}
+ /**
+ * Use Kryo serialization and register the given set of classes with Kryo.
+ * If called multiple times, this will append the classes from all calls together.
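+   *
+   * For example (the class names here are illustrative):
+   *   conf.registerKryoClasses(Array(classOf[MyClass1], classOf[MyClass2]))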
+ */
+ def registerKryoClasses(classes: Array[Class[_]]): SparkConf = {
+ val allClassNames = new LinkedHashSet[String]()
+ allClassNames ++= get("spark.kryo.classesToRegister", "").split(',').filter(!_.isEmpty)
+ allClassNames ++= classes.map(_.getName)
+
+ set("spark.kryo.classesToRegister", allClassNames.mkString(","))
+ set("spark.serializer", classOf[KryoSerializer].getName)
+ this
+ }
+
/** Remove a parameter from the configuration */
def remove(key: String): SparkConf = {
settings.remove(key)
@@ -229,6 +244,19 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
val executorClasspathKey = "spark.executor.extraClassPath"
val driverOptsKey = "spark.driver.extraJavaOptions"
val driverClassPathKey = "spark.driver.extraClassPath"
+ val driverLibraryPathKey = "spark.driver.extraLibraryPath"
+
+ // Used by Yarn in 1.1 and before
+ sys.props.get("spark.driver.libraryPath").foreach { value =>
+ val warning =
+ s"""
+ |spark.driver.libraryPath was detected (set to '$value').
+ |This is deprecated in Spark 1.2+.
+ |
+ |Please instead use: $driverLibraryPathKey
+ """.stripMargin
+ logWarning(warning)
+ }
// Validate spark.executor.extraJavaOptions
settings.get(executorOptsKey).map { javaOpts =>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index dd3157990ef2d..40444c237b738 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,12 +21,10 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
-import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
@@ -42,7 +40,8 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
-import org.apache.spark.input.WholeTextFileInputFormat
+import org.apache.spark.executor.TriggerThreadDump
+import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
import org.apache.spark.scheduler._
@@ -51,7 +50,8 @@ import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, Me
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
import org.apache.spark.ui.SparkUI
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.ui.jobs.JobProgressListener
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -61,7 +61,7 @@ import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, Metadat
* this config overrides the default configs as well as system properties.
*/
-class SparkContext(config: SparkConf) extends Logging {
+class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
// This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
// etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
@@ -209,16 +209,10 @@ class SparkContext(config: SparkConf) extends Logging {
// An asynchronous listener bus for Spark events
private[spark] val listenerBus = new LiveListenerBus
- // Create the Spark execution environment (cache, map output tracker, etc)
conf.set("spark.executor.id", "driver")
- private[spark] val env = SparkEnv.create(
- conf,
- "",
- conf.get("spark.driver.host"),
- conf.get("spark.driver.port").toInt,
- isDriver = true,
- isLocal = isLocal,
- listenerBus = listenerBus)
+
+ // Create the Spark execution environment (cache, map output tracker, etc)
+ private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -230,15 +224,24 @@ class SparkContext(config: SparkConf) extends Logging {
private[spark] val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf)
- // Initialize the Spark UI, registering all associated listeners
+
+ private[spark] val jobProgressListener = new JobProgressListener(conf)
+ listenerBus.addListener(jobProgressListener)
+
+ // Initialize the Spark UI
private[spark] val ui: Option[SparkUI] =
if (conf.getBoolean("spark.ui.enabled", true)) {
- Some(new SparkUI(this))
+ Some(SparkUI.createLiveUI(this, conf, listenerBus, jobProgressListener,
+ env.securityManager, appName))
} else {
// For tests, do not enable the UI
None
}
+ // Bind the UI before starting the task scheduler to communicate
+ // the bound port to the cluster manager properly
+ ui.foreach(_.bind())
+
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
@@ -291,7 +294,8 @@ class SparkContext(config: SparkConf) extends Logging {
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master)
+ private[spark] var (schedulerBackend, taskScheduler) =
+ SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
@@ -326,6 +330,15 @@ class SparkContext(config: SparkConf) extends Logging {
} else None
}
+ // Optionally scale number of executors dynamically based on workload. Exposed for testing.
+ private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
+ if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
+ Some(new ExecutorAllocationManager(this))
+ } else {
+ None
+ }
+ executorAllocationManager.foreach(_.start())
+
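Dynamic allocation is off by default and gated on a single flag; a configuration sketch (the min/max keys are assumed to be the ones ExecutorAllocationManager reads):

    // Enables workload-based executor scaling (YARN-only at this point).
    val conf = new SparkConf()
      .set("spark.dynamicAllocation.enabled", "true")
      .set("spark.dynamicAllocation.minExecutors", "2")
      .set("spark.dynamicAllocation.maxExecutors", "20")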
// At this point, all relevant SparkListeners have been registered, so begin releasing events
listenerBus.start()
@@ -341,10 +354,6 @@ class SparkContext(config: SparkConf) extends Logging {
postEnvironmentUpdate()
postApplicationStart()
- // Bind the SparkUI after starting the task scheduler
- // because certain pages and listeners depend on it
- ui.foreach(_.bind())
-
private[spark] var checkpointDir: Option[String] = None
// Thread Local variable that can be used by users to pass information down the stack
@@ -352,6 +361,29 @@ class SparkContext(config: SparkConf) extends Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
@@ -524,6 +556,69 @@ class SparkContext(config: SparkConf) extends Logging {
minPartitions).setName(path)
}
+
+ /**
+ * Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
+ * (useful for binary data)
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `val rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ @Experimental
+ def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
+ RDD[(String, PortableDataStream)] = {
+ val job = new NewHadoopJob(hadoopConfiguration)
+ NewFileInputFormat.addInputPath(job, new Path(path))
+ val updateConf = job.getConfiguration
+ new BinaryFileRDD(
+ this,
+ classOf[StreamInputFormat],
+ classOf[String],
+ classOf[PortableDataStream],
+ updateConf,
+ minPartitions).setName(path)
+ }
+
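A usage sketch of binaryFiles, assuming PortableDataStream exposes a toArray helper as used here (sc is an existing SparkContext; the path is illustrative):

    // Each element is (file path, PortableDataStream).
    val streams = sc.binaryFiles("hdfs://a-hdfs-path")
    val fileSizes = streams.map { case (file, stream) =>
      (file, stream.toArray().length)   // materialize the file contents as bytes
    }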
+ /**
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory containing the input data files
+ * @param recordLength The length at which to split the records
+ * @return An RDD of byte arrays, one per fixed-length record
+ */
+ @Experimental
+ def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
+ : RDD[Array[Byte]] = {
+ conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
+ val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
+ classOf[FixedLengthBinaryInputFormat],
+ classOf[LongWritable],
+ classOf[BytesWritable],
+ conf=conf)
+ val data = br.map{ case (k, v) => v.getBytes}
+ data
+ }
+
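And a sketch of binaryRecords (record length and path are illustrative):

    // Every RDD element is an Array[Byte] of exactly recordLength bytes.
    val records = sc.binaryRecords("hdfs://a-hdfs-path/data.bin", recordLength = 512)
    println(records.count())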
/**
* Get an RDD for a Hadoop-readable dataset from a Hadoop JobConf given its InputFormat and other
* necessary info (e.g. file name for a filesystem-based dataset, table name for HyperTable),
@@ -837,11 +932,12 @@ class SparkContext(config: SparkConf) extends Logging {
case "local" => "file:" + uri.getPath
case _ => path
}
- addedFiles(key) = System.currentTimeMillis
+ val timestamp = System.currentTimeMillis
+ addedFiles(key) = timestamp
// Fetch the file locally in case a job is executed using DAGScheduler.runLocally().
Utils.fetchFile(path, new File(SparkFiles.getRootDirectory()), conf, env.securityManager,
- hadoopConfiguration)
+ hadoopConfiguration, timestamp, useCache = false)
logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
postEnvironmentUpdate()
@@ -856,71 +952,48 @@ class SparkContext(config: SparkConf) extends Logging {
listenerBus.addListener(listener)
}
- /** The version of Spark on which this application is running. */
- def version = SPARK_VERSION
-
- /**
- * Return a map from the slave to the max memory available for caching and the remaining
- * memory available for caching.
- */
- def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
- env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
- (blockManagerId.host + ":" + blockManagerId.port, mem)
- }
- }
-
/**
* :: DeveloperApi ::
- * Return information about what RDDs are cached, if they are in mem or on disk, how much space
- * they take, etc.
+ * Request an additional number of executors from the cluster manager.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def getRDDStorageInfo: Array[RDDInfo] = {
- val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
- StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
- rddInfos.filter(_.isCached)
- }
-
- /**
- * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call.
- * Note that this does not necessarily mean the caching or computation was successful.
- */
- def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
-
- /**
- * :: DeveloperApi ::
- * Return information about blocks stored in all of the slaves
- */
- @DeveloperApi
- def getExecutorStorageStatus: Array[StorageStatus] = {
- env.blockManager.master.getStorageStatus
+ def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.requestExecutors(numAdditionalExecutors)
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ false
+ }
}
/**
* :: DeveloperApi ::
- * Return pools for fair scheduler
+ * Request that the cluster manager kill the specified executors.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def getAllPools: Seq[Schedulable] = {
- // TODO(xiajunluan): We should take nested pools into account
- taskScheduler.rootPool.schedulableQueue.toSeq
+ def killExecutors(executorIds: Seq[String]): Boolean = {
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.killExecutors(executorIds)
+ case _ =>
+ logWarning("Killing executors is only supported in coarse-grained mode")
+ false
+ }
}
/**
* :: DeveloperApi ::
- * Return the pool associated with the given name, if one exists
+ * Request that the cluster manager kill the specified executor.
+ * This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def getPoolForName(pool: String): Option[Schedulable] = {
- Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
- }
+ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
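These developer-API hooks are what ExecutorAllocationManager drives internally; called directly they look roughly like this sketch (the executor ID is illustrative):

    // Ask the cluster manager (YARN in this release) for two more executors.
    val granted = sc.requestExecutors(2)
    // Ask it to tear down one executor; both calls report whether the request
    // was received, not whether it has completed.
    val killed = sc.killExecutor("3")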
- /**
- * Return current scheduling mode
- */
- def getSchedulingMode: SchedulingMode.SchedulingMode = {
- taskScheduler.schedulingMode
- }
+ /** The version of Spark on which this application is running. */
+ def version = SPARK_VERSION
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
@@ -1346,6 +1419,8 @@ object SparkContext extends Logging {
private[spark] val SPARK_UNKNOWN_USER = ""
+ private[spark] val DRIVER_IDENTIFIER = ""
+
implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] {
def addInPlace(t1: Double, t2: Double): Double = t1 + t2
def zero(initialValue: Double) = 0.0
@@ -1501,8 +1576,13 @@ object SparkContext extends Logging {
res
}
- /** Creates a task scheduler based on a given master URL. Extracted for testing. */
- private def createTaskScheduler(sc: SparkContext, master: String): TaskScheduler = {
+ /**
+ * Create a task scheduler based on a given master URL.
+ * Return a 2-tuple of the scheduler backend and the task scheduler.
+ */
+ private def createTaskScheduler(
+ sc: SparkContext,
+ master: String): (SchedulerBackend, TaskScheduler) = {
// Regular expression used for local[N] and local[*] master formats
val LOCAL_N_REGEX = """local\[([0-9]+|\*)\]""".r
// Regular expression for local[N, maxRetries], used in tests with failing tasks
@@ -1524,7 +1604,7 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, 1)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_REGEX(threads) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
@@ -1533,7 +1613,7 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
def localCpuCount = Runtime.getRuntime.availableProcessors()
@@ -1543,14 +1623,14 @@ object SparkContext extends Logging {
val scheduler = new TaskSchedulerImpl(sc, maxFailures.toInt, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SPARK_REGEX(sparkUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val masterUrls = sparkUrl.split(",").map("spark://" + _)
val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
// Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
@@ -1570,7 +1650,7 @@ object SparkContext extends Logging {
backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
localCluster.stop()
}
- scheduler
+ (backend, scheduler)
case "yarn-standalone" | "yarn-cluster" =>
if (master == "yarn-standalone") {
@@ -1599,7 +1679,7 @@ object SparkContext extends Logging {
}
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case "yarn-client" =>
val scheduler = try {
@@ -1626,7 +1706,7 @@ object SparkContext extends Logging {
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case mesosUrl @ MESOS_REGEX(_) =>
MesosNativeLibrary.load()
@@ -1639,13 +1719,13 @@ object SparkContext extends Logging {
new MesosSchedulerBackend(scheduler, sc, url)
}
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case SIMR_REGEX(simrUrl) =>
val scheduler = new TaskSchedulerImpl(sc)
val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
scheduler.initialize(backend)
- scheduler
+ (backend, scheduler)
case _ =>
throw new SparkException("Could not parse Master URL: '" + master + "'")
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index aba713cb4267a..e2f13accdfab5 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.python.PythonWorkerFactory
import org.apache.spark.broadcast.BroadcastManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
import org.apache.spark.serializer.Serializer
@@ -68,6 +69,7 @@ class SparkEnv (
val shuffleMemoryManager: ShuffleMemoryManager,
val conf: SparkConf) extends Logging {
+ private[spark] var isStopped = false
private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
// A general, soft-reference map for metadata needed during HadoopRDD split computation
@@ -75,6 +77,7 @@ class SparkEnv (
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
private[spark] def stop() {
+ isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
Option(httpFileServer).foreach(_.stop())
mapOutputTracker.stop()
@@ -142,14 +145,46 @@ object SparkEnv extends Logging {
env
}
- private[spark] def create(
+ /**
+ * Create a SparkEnv for the driver.
+ */
+ private[spark] def createDriverEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
+ assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
+ val hostname = conf.get("spark.driver.host")
+ val port = conf.get("spark.driver.port").toInt
+ create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus)
+ }
+
+ /**
+ * Create a SparkEnv for an executor.
+ * In coarse-grained mode, the executor provides an actor system that is already instantiated.
+ */
+ private[spark] def createExecutorEnv(
+ conf: SparkConf,
+ executorId: String,
+ hostname: String,
+ port: Int,
+ isLocal: Boolean,
+ actorSystem: ActorSystem = null): SparkEnv = {
+ create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem)
+ }
+
+ /**
+ * Helper method to create a SparkEnv for a driver or an executor.
+ */
+ private def create(
conf: SparkConf,
executorId: String,
hostname: String,
port: Int,
isDriver: Boolean,
isLocal: Boolean,
- listenerBus: LiveListenerBus = null): SparkEnv = {
+ listenerBus: LiveListenerBus = null,
+ defaultActorSystem: ActorSystem = null): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -157,9 +192,16 @@ object SparkEnv extends Logging {
}
val securityManager = new SecurityManager(conf)
- val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- actorSystemName, hostname, port, conf, securityManager)
+
+ // If an existing actor system is already provided, use it.
+ // This is the case when an executor is launched in coarse-grained mode.
+ val (actorSystem, boundPort) =
+ Option(defaultActorSystem) match {
+ case Some(as) => (as, port)
+ case None =>
+ val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName
+ AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager)
+ }
// Figure out which port Akka actually bound to in case the original port is 0 or occupied.
// This is so that we tell the executors the correct port to connect to.
@@ -231,7 +273,13 @@ object SparkEnv extends Logging {
val shuffleMemoryManager = new ShuffleMemoryManager(conf)
- val blockTransferService = new NioBlockTransferService(conf, securityManager)
+ val blockTransferService =
+ conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
+ case "netty" =>
+ new NettyBlockTransferService(conf)
+ case "nio" =>
+ new NioBlockTransferService(conf, securityManager)
+ }
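With netty as the new default, reverting to the old NIO transfer path is a one-line config change:

    // "netty" is the default; "nio" restores the pre-1.2 NioBlockTransferService.
    val conf = new SparkConf().set("spark.shuffle.blockTransferService", "nio")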
val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
"BlockManagerMaster",
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
index 65003b6ac6a0a..a954fcc0c31fa 100644
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
@@ -17,7 +17,6 @@
package org.apache.spark
-import java.io.IOException
import javax.security.auth.callback.Callback
import javax.security.auth.callback.CallbackHandler
import javax.security.auth.callback.NameCallback
@@ -31,6 +30,8 @@ import javax.security.sasl.SaslException
import scala.collection.JavaConversions.mapAsJavaMap
+import com.google.common.base.Charsets.UTF_8
+
/**
* Implements SASL Client logic for Spark
*/
@@ -111,10 +112,10 @@ private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logg
CallbackHandler {
private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
private val secretKey = securityMgr.getSecretKey()
private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes("utf-8") else "".getBytes("utf-8"))
+ if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
/**
* Implementation used to respond to SASL request from the server.
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
index f6b0a9132aca4..7c2afb364661f 100644
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
@@ -28,6 +28,8 @@ import javax.security.sasl.Sasl
import javax.security.sasl.SaslException
import javax.security.sasl.SaslServer
import scala.collection.JavaConversions.mapAsJavaMap
+
+import com.google.common.base.Charsets.UTF_8
import org.apache.commons.net.util.Base64
/**
@@ -89,7 +91,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
extends CallbackHandler {
private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes("utf-8"))
+ SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
override def handle(callbacks: Array[Callback]) {
logDebug("In the sasl server callback handler")
@@ -101,7 +103,7 @@ private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Loggi
case pc: PasswordCallback => {
logDebug("handle: SASL server callback: setting userPassword")
val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes("utf-8"))
+ SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
pc.setPassword(password)
}
case rc: RealmCallback => {
@@ -159,7 +161,7 @@ private[spark] object SparkSaslServer {
* @return Base64-encoded string
*/
def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), "utf-8")
+ new String(Base64.encodeBase64(identifier), UTF_8)
}
/**
@@ -168,7 +170,7 @@ private[spark] object SparkSaslServer {
* @return password as a char array.
*/
def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), "utf-8").toCharArray()
+ new String(Base64.encodeBase64(password), UTF_8).toCharArray()
}
}
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala
new file mode 100644
index 0000000000000..1982499c5e1d3
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.collection.Map
+import scala.collection.JavaConversions._
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.{SchedulingMode, Schedulable}
+import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo}
+
+/**
+ * Trait that implements Spark's status APIs. This trait is designed to be mixed into
+ * SparkContext; it allows the status API code to live in its own file.
+ */
+private[spark] trait SparkStatusAPI { this: SparkContext =>
+
+ /**
+ * Return a map from the slave to the max memory available for caching and the remaining
+ * memory available for caching.
+ */
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return information about what RDDs are cached, if they are in mem or on disk, how much space
+ * they take, etc.
+ */
+ @DeveloperApi
+ def getRDDStorageInfo: Array[RDDInfo] = {
+ val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
+ StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
+ rddInfos.filter(_.isCached)
+ }
+
+ /**
+ * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call.
+ * Note that this does not necessarily mean the caching or computation was successful.
+ */
+ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
+
+ /**
+ * :: DeveloperApi ::
+ * Return information about blocks stored in all of the slaves
+ */
+ @DeveloperApi
+ def getExecutorStorageStatus: Array[StorageStatus] = {
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return pools for fair scheduler
+ */
+ @DeveloperApi
+ def getAllPools: Seq[Schedulable] = {
+ // TODO(xiajunluan): We should take nested pools into account
+ taskScheduler.rootPool.schedulableQueue.toSeq
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return the pool associated with the given name, if one exists
+ */
+ @DeveloperApi
+ def getPoolForName(pool: String): Option[Schedulable] = {
+ Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
+ }
+
+ /**
+ * Return current scheduling mode
+ */
+ def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ taskScheduler.schedulingMode
+ }
+
+
+ /**
+ * Return a list of all known jobs in a particular job group. The returned list may contain
+ * running, failed, and completed jobs, and may vary across invocations of this method. This
+ * method does not guarantee the order of the elements in its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = {
+ jobProgressListener.synchronized {
+ val jobData = jobProgressListener.jobIdToData.valuesIterator
+ jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns job information, or `None` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): Option[SparkJobInfo] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.jobIdToData.get(jobId).map { data =>
+ new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status)
+ }
+ }
+ }
+
+ /**
+ * Returns stage information, or `None` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): Option[SparkStageInfo] = {
+ jobProgressListener.synchronized {
+ for (
+ info <- jobProgressListener.stageIdToInfo.get(stageId);
+ data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId))
+ ) yield {
+ new SparkStageInfoImpl(
+ stageId,
+ info.attemptId,
+ info.name,
+ info.numTasks,
+ data.numActiveTasks,
+ data.numCompleteTasks,
+ data.numFailedTasks)
+ }
+ }
+ }
+}
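A sketch of polling the new status API from application code (sc is an existing SparkContext; the job group name is illustrative):

    // Submit work under a named group, then poll its progress from another thread.
    sc.setJobGroup("nightly-etl", "nightly ETL jobs")
    for (jobId <- sc.getJobIdsForGroup("nightly-etl");
         job <- sc.getJobInfo(jobId);
         stageId <- job.stageIds;
         stage <- sc.getStageInfo(stageId)) {
      println(s"job $jobId stage $stageId: " +
        s"${stage.numCompletedTasks}/${stage.numTasks} tasks done")
    }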
diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
new file mode 100644
index 0000000000000..90b47c847fbca
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+private class SparkJobInfoImpl (
+ val jobId: Int,
+ val stageIds: Array[Int],
+ val status: JobExecutionStatus)
+ extends SparkJobInfo
+
+private class SparkStageInfoImpl(
+ val stageId: Int,
+ val currentAttemptId: Int,
+ val name: String,
+ val numTasks: Int,
+ val numActiveTasks: Int,
+ val numCompletedTasks: Int,
+ val numFailedTasks: Int)
+ extends SparkStageInfo
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 8f0c5e78416c2..f45b463fb6f62 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -69,11 +69,13 @@ case class FetchFailed(
bmAddress: BlockManagerId, // Note that bmAddress can be null
shuffleId: Int,
mapId: Int,
- reduceId: Int)
+ reduceId: Int,
+ message: String)
extends TaskFailedReason {
override def toErrorString: String = {
val bmAddressString = if (bmAddress == null) "null" else bmAddress.toString
- s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId)"
+ s"FetchFailed($bmAddressString, shuffleId=$shuffleId, mapId=$mapId, reduceId=$reduceId, " +
+ s"message=\n$message\n)"
}
}
@@ -117,8 +119,8 @@ case object TaskKilled extends TaskFailedReason {
* the task crashed the JVM.
*/
@DeveloperApi
-case object ExecutorLostFailure extends TaskFailedReason {
- override def toErrorString: String = "ExecutorLostFailure (executor lost)"
+case class ExecutorLostFailure(execId: String) extends TaskFailedReason {
+ override def toErrorString: String = s"ExecutorLostFailure (executor ${execId} lost)"
}
/**
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index e72826dc25f41..34078142f5385 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -23,8 +23,8 @@ import java.util.jar.{JarEntry, JarOutputStream}
import scala.collection.JavaConversions._
+import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
-import com.google.common.io.Files
import org.apache.spark.util.Utils
@@ -64,12 +64,7 @@ private[spark] object TestUtils {
jarStream.putNextEntry(jarEntry)
val in = new FileInputStream(file)
- val buffer = new Array[Byte](10240)
- var nRead = 0
- while (nRead <= 0) {
- nRead = in.read(buffer, 0, buffer.length)
- jarStream.write(buffer, 0, nRead)
- }
+ ByteStreams.copy(in, jarStream)
in.close()
}
jarStream.close()
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index a6123bd108c11..8e8f7f6c4fda2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -114,7 +114,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtract(other: JavaDoubleRDD): JavaDoubleRDD =
fromRDD(srdd.subtract(other))
@@ -233,11 +233,11 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja
* to the left except for the last which is closed
* e.g. for the array
* [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
- * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * e.g 1&lt;=x&lt;10 , 10&lt;=x&lt;20, 20&lt;=x&lt;50
* And on the input of 1 and 50 we would have a histogram of 1,0,0
*
* Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
- * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * from an O(log n) insertion to O(1) per element. (where n = # buckets) if you set evenBuckets
* to true.
* buckets must be sorted and not contain any duplicates.
* buckets array must be at least two elements
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index c38b96528d037..e37f3acaf6e30 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -392,7 +392,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the elements from `this` that are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtract(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] =
fromRDD(rdd.subtract(other))
@@ -413,7 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* Return an RDD with the pairs from `this` whose keys are not in `other`.
*
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
- * RDD will be <= us.
+ * RDD will be &lt;= us.
*/
def subtractByKey[W](other: JavaPairRDD[K, W]): JavaPairRDD[K, V] = {
implicit val ctag: ClassTag[W] = fakeClassTag
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 791d853a015a1..e3aeba7e6c39d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,6 +21,11 @@ import java.io.Closeable
import java.util
import java.util.{Map => JMap}
+import java.io.DataInputStream
+
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat}
+
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -32,7 +37,8 @@ import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
-import org.apache.spark.SparkContext.{DoubleAccumulatorParam, IntAccumulatorParam}
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
@@ -132,6 +138,25 @@ class JavaSparkContext(val sc: SparkContext)
/** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions
+
+ /**
+ * Return a list of all known jobs in a particular job group. The returned list may contain
+ * running, failed, and completed jobs, and may vary across invocations of this method. This
+ * method does not guarantee the order of the elements in its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup)
+
+ /**
+ * Returns job information, or `null` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull
+
+ /**
+ * Returns stage information, or `null` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull
+
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
@@ -183,6 +208,8 @@ class JavaSparkContext(val sc: SparkContext)
def textFile(path: String, minPartitions: Int): JavaRDD[String] =
sc.textFile(path, minPartitions)
+
+
/**
* Read a directory of text files from HDFS, a local file system (available on all nodes), or any
* Hadoop-supported file system URI. Each file is read as a single record and returned in a
@@ -196,7 +223,10 @@ class JavaSparkContext(val sc: SparkContext)
* hdfs://a-hdfs-path/part-nnnnn
* }}}
*
- * Do `JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")`,
+ * Do
+ * {{{
+ * JavaPairRDD rdd = sparkContext.wholeTextFiles("hdfs://a-hdfs-path")
+ * }}}
*
* then `rdd` contains
* {{{
@@ -223,6 +253,78 @@ class JavaSparkContext(val sc: SparkContext)
def wholeTextFiles(path: String): JavaPairRDD[String, String] =
new JavaPairRDD(sc.wholeTextFiles(path))
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file
+ * and the value is a PortableDataStream over the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ *
+ * @param minPartitions A suggested minimum number of partitions for the input data.
+ */
+ def binaryFiles(path: String, minPartitions: Int): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, minPartitions))
+
+ /**
+ * Read a directory of binary files from HDFS, a local file system (available on all nodes),
+ * or any Hadoop-supported file system URI. Each file is read as a single
+ * record and returned in a key-value pair, where the key is the path of each file
+ * and the value is a PortableDataStream over the content of each file.
+ *
+ * For example, if you have the following files:
+ * {{{
+ * hdfs://a-hdfs-path/part-00000
+ * hdfs://a-hdfs-path/part-00001
+ * ...
+ * hdfs://a-hdfs-path/part-nnnnn
+ * }}}
+ *
+ * Do
+ * `JavaPairRDD rdd = sparkContext.binaryFiles("hdfs://a-hdfs-path")`,
+ *
+ * then `rdd` contains
+ * {{{
+ * (a-hdfs-path/part-00000, its content)
+ * (a-hdfs-path/part-00001, its content)
+ * ...
+ * (a-hdfs-path/part-nnnnn, its content)
+ * }}}
+ *
+ * @note Small files are preferred; very large files may cause bad performance.
+ */
+ def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
+ new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
+
+ /**
+ * Load data from a flat binary file, assuming the length of each record is constant.
+ *
+ * @param path Directory containing the input data files
+ * @return An RDD of byte arrays, one per fixed-length record
+ */
+ def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
+ new JavaRDD(sc.binaryRecords(path, recordLength))
+ }
+
/** Get an RDD for a Hadoop SequenceFile with given key and value types.
*
* '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 29ca751519abd..e94ccdcd47bb7 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,15 +19,13 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.nio.charset.Charset
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
-import net.razorvine.pickle.{Pickler, Unpickler}
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -75,6 +73,7 @@ private[spark] class PythonRDD(
var complete_cleanly = false
context.addTaskCompletionListener { context =>
writerThread.shutdownOnTaskCompletion()
+ writerThread.join()
if (reuse_worker && complete_cleanly) {
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
} else {
@@ -133,7 +132,7 @@ private[spark] class PythonRDD(
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
- throw new PythonException(new String(obj, "utf-8"),
+ throw new PythonException(new String(obj, UTF_8),
writerThread.exception.getOrElse(null))
case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
@@ -145,7 +144,9 @@ private[spark] class PythonRDD(
stream.readFully(update)
accumulator += Collections.singletonList(update)
}
- complete_cleanly = true
+ if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+ complete_cleanly = true
+ }
null
}
} catch {
@@ -154,6 +155,10 @@ private[spark] class PythonRDD(
logDebug("Exception thrown after task interruption", e)
throw new TaskKilledException
+ case e: Exception if env.isStopped =>
+ logDebug("Exception thrown after context is stopped", e)
+ null // exit silently
+
case e: Exception if writerThread.exception.isDefined =>
logError("Python worker exited unexpectedly (crashed)", e)
logError("This may have been caused by a prior exception:", writerThread.exception.get)
@@ -235,6 +240,7 @@ private[spark] class PythonRDD(
// Data values
PythonRDD.writeIteratorToStream(firstParent.iterator(split, context), dataOut)
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -306,10 +312,10 @@ private object SpecialLengths {
val END_OF_DATA_SECTION = -1
val PYTHON_EXCEPTION_THROWN = -2
val TIMING_DATA = -3
+ val END_OF_STREAM = -4
}
private[spark] object PythonRDD extends Logging {
- val UTF8 = Charset.forName("UTF-8")
// remember the broadcasts sent to each worker
private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
@@ -434,7 +440,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -460,7 +466,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -486,7 +492,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -529,7 +535,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -555,7 +561,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -577,7 +583,7 @@ private[spark] object PythonRDD extends Logging {
}
def writeUTF(str: String, dataOut: DataOutputStream) {
- val bytes = str.getBytes(UTF8)
+ val bytes = str.getBytes(UTF_8)
dataOut.writeInt(bytes.length)
dataOut.write(bytes)
}
@@ -738,109 +744,11 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- SerDeUtil.initialize()
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- SerDeUtil.initialize()
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
- override def call(arr: Array[Byte]) : String = new String(arr, PythonRDD.UTF8)
+ override def call(arr: Array[Byte]) : String = new String(arr, UTF_8)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index ebdc3533e0992..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()
+
+ /**
+ * Convert an RDD of Java objects to Array (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose batch size based on size of objects
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // let 1M < size < 10M
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
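The batching policy above is deliberately simple: grow the batch while pickled output stays under 1 MB, shrink it once output exceeds 10 MB. A standalone restatement of the same policy, decoupled from the Pickler:

    // Mirrors the logic in AutoBatchedPickler.next().
    def nextBatchSize(currentBatch: Int, serializedBytes: Int): Int = {
      if (serializedBytes < 1024 * 1024) {
        currentBatch * 2                                   // output small: batch more
      } else if (serializedBytes > 1024 * 1024 * 10 && currentBatch > 1) {
        currentBatch / 2                                   // output large: back off
      } else {
        currentBatch                                       // inside the 1M..10M band
      }
    }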
+
+ /**
+ * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
+ * PySpark.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+ * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
+
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
@@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- initialize()
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index d11db978b842e..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -18,7 +18,8 @@
package org.apache.spark.api.python
import java.io.{DataOutput, DataInput}
-import java.nio.charset.Charset
+
+import com.google.common.base.Charsets.UTF_8
import org.apache.hadoop.io._
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
@@ -136,7 +137,7 @@ object WriteInputFormatTestDataGenerator {
sc.parallelize(intKeys).saveAsSequenceFile(intPath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
- sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(Charset.forName("UTF-8"))) }
+ sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes(UTF_8)) }
).saveAsSequenceFile(bytesPath)
val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
sc.parallelize(bools).saveAsSequenceFile(boolPath)
@@ -175,11 +176,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 15fd30e65761d..87f5cf944ed85 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -20,6 +20,8 @@ package org.apache.spark.broadcast
import java.io.Serializable
import org.apache.spark.SparkException
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
import scala.reflect.ClassTag
@@ -52,7 +54,7 @@ import scala.reflect.ClassTag
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable with Logging {
/**
* Flag signifying whether the broadcast variable is valid
@@ -60,6 +62,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
*/
@volatile private var _isValid = true
+ private var _destroySite = ""
+
/** Get the broadcasted value. */
def value: T = {
assertValid()
@@ -84,13 +88,26 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
doUnpersist(blocking)
}
+
+ /**
+ * Destroy all data and metadata related to this broadcast variable. Use this with caution;
+ * once a broadcast variable has been destroyed, it cannot be used again.
+ * This method blocks until destroy has completed
+ */
+ def destroy() {
+ destroy(blocking = true)
+ }
+
/**
* Destroy all data and metadata related to this broadcast variable. Use this with caution;
* once a broadcast variable has been destroyed, it cannot be used again.
+ * @param blocking Whether to block until destroy has completed
*/
private[spark] def destroy(blocking: Boolean) {
assertValid()
_isValid = false
+ _destroySite = Utils.getCallSite().shortForm
+ logInfo("Destroying %s (from %s)".format(toString, _destroySite))
doDestroy(blocking)
}
@@ -124,7 +141,8 @@ abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
/** Check if this broadcast is valid. If not valid, exception is thrown. */
protected def assertValid() {
if (!_isValid) {
- throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString))
+ throw new SparkException(
+ "Attempted to use %s after it was destroyed (%s) ".format(toString, _destroySite))
}
}
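A hedged usage sketch of the new public destroy() and the improved error message; sc is an existing SparkContext and the values are illustrative.

val numbers = sc.broadcast(Array(1, 2, 3))
println(numbers.value.sum)   // usable while the broadcast is valid
numbers.destroy()            // blocks until all data and metadata are removed
// Any later access to numbers.value throws a SparkException that now also names
// the call site recorded in _destroySite when destroy() ran.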
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 4cd4f4f96fd16..7dade04273b08 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -72,13 +72,13 @@ private[spark] class HttpBroadcast[T: ClassTag](
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
/** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
HttpBroadcast.synchronized {
SparkEnv.get.blockManager.getSingle(blockId) match {
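The serialization hooks above are now wrapped in Utils.tryOrIOException, whose definition is not part of this diff. A minimal sketch of the idea, assuming the goal is to surface unexpected failures in writeObject/readObject as IOExceptions, which Java's serialization machinery reports cleanly:

// Sketch only; the real Utils.tryOrIOException may differ in details.
import java.io.IOException
import scala.util.control.NonFatal

def tryOrIOException[T](block: => T): T = {
  try {
    block
  } catch {
    case e: IOException => throw e               // already the expected type
    case NonFatal(e) => throw new IOException(e) // wrap everything else
  }
}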
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 42d58682a1e23..94142d33369c7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -26,8 +26,9 @@ import scala.util.Random
import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
-import org.apache.spark.util.ByteBufferInputStream
+import org.apache.spark.util.{ByteBufferInputStream, Utils}
import org.apache.spark.util.io.ByteArrayChunkOutputStream
/**
@@ -46,53 +47,66 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream
* This prevents the driver from being the bottleneck in sending out multiple copies of the
* broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
*
+ * When initialized, TorrentBroadcast objects read SparkEnv.get.conf.
+ *
* @param obj object to broadcast
- * @param isLocal whether Spark is running in local mode (single JVM process).
* @param id A unique identifier for the broadcast variable.
*/
-private[spark] class TorrentBroadcast[T: ClassTag](
- obj : T,
- @transient private val isLocal: Boolean,
- id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](obj: T, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
/**
- * Value of the broadcast object. On driver, this is set directly by the constructor.
- * On executors, this is reconstructed by [[readObject]], which builds this value by reading
- * blocks from the driver and/or other executors.
+ * Value of the broadcast object on executors. This is reconstructed by [[readBroadcastBlock]],
+ * which builds this value by reading blocks from the driver and/or other executors.
+ *
+ * On the driver, if the value is required, it is read lazily from the block manager.
*/
- @transient private var _value: T = obj
+ @transient private lazy val _value: T = readBroadcastBlock()
+
+ /** The compression codec to use, or None if compression is disabled */
+ @transient private var compressionCodec: Option[CompressionCodec] = _
+ /** Size of each block. Default value is 4MB. This value is only read by the broadcaster. */
+ @transient private var blockSize: Int = _
+
+ private def setConf(conf: SparkConf) {
+ compressionCodec = if (conf.getBoolean("spark.broadcast.compress", true)) {
+ Some(CompressionCodec.createCodec(conf))
+ } else {
+ None
+ }
+ blockSize = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
+ }
+ setConf(SparkEnv.get.conf)
private val broadcastId = BroadcastBlockId(id)
/** Total number of blocks this broadcast variable contains. */
- private val numBlocks: Int = writeBlocks()
+ private val numBlocks: Int = writeBlocks(obj)
- override protected def getValue() = _value
+ override protected def getValue() = {
+ _value
+ }
/**
* Divide the object into multiple blocks and put those blocks in the block manager.
- *
+ * @param value the object to divide
* @return number of blocks this broadcast variable is divided into
*/
- private def writeBlocks(): Int = {
- // For local mode, just put the object in the BlockManager so we can find it later.
- SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
- if (!isLocal) {
- val blocks = TorrentBroadcast.blockifyObject(_value)
- blocks.zipWithIndex.foreach { case (block, i) =>
- SparkEnv.get.blockManager.putBytes(
- BroadcastBlockId(id, "piece" + i),
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
- }
- blocks.length
- } else {
- 0
+ private def writeBlocks(value: T): Int = {
+ // Store a copy of the broadcast variable in the driver so that tasks run on the driver
+ // do not create a duplicate copy of the broadcast variable's value.
+ SparkEnv.get.blockManager.putSingle(broadcastId, value, StorageLevel.MEMORY_AND_DISK,
+ tellMaster = false)
+ val blocks =
+ TorrentBroadcast.blockifyObject(value, blockSize, SparkEnv.get.serializer, compressionCodec)
+ blocks.zipWithIndex.foreach { case (block, i) =>
+ SparkEnv.get.blockManager.putBytes(
+ BroadcastBlockId(id, "piece" + i),
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
}
+ blocks.length
}
/** Fetch torrent blocks from the driver and/or other executors. */
@@ -104,29 +118,24 @@ private[spark] class TorrentBroadcast[T: ClassTag](
for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
val pieceId = BroadcastBlockId(id, "piece" + pid)
-
- // First try getLocalBytes because there is a chance that previous attempts to fetch the
+ logDebug(s"Reading piece $pieceId of $broadcastId")
+ // First try getLocalBytes because there is a chance that previous attempts to fetch the
// broadcast blocks have already fetched some of the blocks. In that case, some blocks
// would be available locally (on this executor).
- var blockOpt = bm.getLocalBytes(pieceId)
- if (!blockOpt.isDefined) {
- blockOpt = bm.getRemoteBytes(pieceId)
- blockOpt match {
- case Some(block) =>
- // If we found the block from remote executors/driver's BlockManager, put the block
- // in this executor's BlockManager.
- SparkEnv.get.blockManager.putBytes(
- pieceId,
- block,
- StorageLevel.MEMORY_AND_DISK_SER,
- tellMaster = true)
-
- case None =>
- throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
- }
+ def getLocal: Option[ByteBuffer] = bm.getLocalBytes(pieceId)
+ def getRemote: Option[ByteBuffer] = bm.getRemoteBytes(pieceId).map { block =>
+ // If we found the block from remote executors/driver's BlockManager, put the block
+ // in this executor's BlockManager.
+ SparkEnv.get.blockManager.putBytes(
+ pieceId,
+ block,
+ StorageLevel.MEMORY_AND_DISK_SER,
+ tellMaster = true)
+ block
}
- // If we get here, the option is defined.
- blocks(pid) = blockOpt.get
+ val block: ByteBuffer = getLocal.orElse(getRemote).getOrElse(
+ throw new SparkException(s"Failed to get $pieceId of $broadcastId"))
+ blocks(pid) = block
}
blocks
}
@@ -147,75 +156,62 @@ private[spark] class TorrentBroadcast[T: ClassTag](
}
/** Used by the JVM when serializing this object. */
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
assertValid()
out.defaultWriteObject()
}
- /** Used by the JVM when deserializing this object. */
- private def readObject(in: ObjectInputStream) {
- in.defaultReadObject()
+ private def readBroadcastBlock(): T = Utils.tryOrIOException {
TorrentBroadcast.synchronized {
+ setConf(SparkEnv.get.conf)
SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
case Some(x) =>
- _value = x.asInstanceOf[T]
+ x.asInstanceOf[T]
case None =>
logInfo("Started reading broadcast variable " + id)
- val start = System.nanoTime()
+ val startTimeMs = System.currentTimeMillis()
val blocks = readBlocks()
- val time = (System.nanoTime() - start) / 1e9
- logInfo("Reading broadcast variable " + id + " took " + time + " s")
+ logInfo("Reading broadcast variable " + id + " took" + Utils.getUsedTimeMs(startTimeMs))
- _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+ val obj = TorrentBroadcast.unBlockifyObject[T](
+ blocks, SparkEnv.get.serializer, compressionCodec)
// Store the merged copy in BlockManager so other tasks on this executor don't
// need to re-fetch it.
SparkEnv.get.blockManager.putSingle(
- broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ broadcastId, obj, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+ obj
}
}
}
+
}
private object TorrentBroadcast extends Logging {
- /** Size of each block. Default value is 4MB. */
- private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
- private var initialized = false
- private var conf: SparkConf = null
- private var compress: Boolean = false
- private var compressionCodec: CompressionCodec = null
-
- def initialize(_isDriver: Boolean, conf: SparkConf) {
- TorrentBroadcast.conf = conf // TODO: we might have to fix it in tests
- synchronized {
- if (!initialized) {
- compress = conf.getBoolean("spark.broadcast.compress", true)
- compressionCodec = CompressionCodec.createCodec(conf)
- initialized = true
- }
- }
- }
- def stop() {
- initialized = false
- }
-
- def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
- val bos = new ByteArrayChunkOutputStream(BLOCK_SIZE)
- val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
- val ser = SparkEnv.get.serializer.newInstance()
+ def blockifyObject[T: ClassTag](
+ obj: T,
+ blockSize: Int,
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): Array[ByteBuffer] = {
+ val bos = new ByteArrayChunkOutputStream(blockSize)
+ val out: OutputStream = compressionCodec.map(c => c.compressedOutputStream(bos)).getOrElse(bos)
+ val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
serOut.writeObject[T](obj).close()
bos.toArrays.map(ByteBuffer.wrap)
}
- def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
+ def unBlockifyObject[T: ClassTag](
+ blocks: Array[ByteBuffer],
+ serializer: Serializer,
+ compressionCodec: Option[CompressionCodec]): T = {
+ require(blocks.nonEmpty, "Cannot unblockify an empty array of blocks")
val is = new SequenceInputStream(
asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
- val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
-
- val ser = SparkEnv.get.serializer.newInstance()
+ val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
+ val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
val obj = serIn.readObject[T]()
serIn.close()
@@ -227,6 +223,7 @@ private object TorrentBroadcast extends Logging {
* If removeFromDriver is true, also remove these persisted blocks on the driver.
*/
def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean) = {
+ logDebug(s"Unpersisting TorrentBroadcast $id")
SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
}
}
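A hedged round-trip sketch of the refactored helpers, which now take the block size, serializer, and optional codec as explicit arguments instead of reading shared state. These helpers are private to Spark, so this mirrors internal use; JavaSerializer and the 4 MB block size are illustrative.

import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer

val serializer = new JavaSerializer(new SparkConf())
val payload = (1 to 100000).toArray
val blocks = TorrentBroadcast.blockifyObject(
  payload, 4 * 1024 * 1024, serializer, compressionCodec = None)
val restored = TorrentBroadcast.unBlockifyObject[Array[Int]](
  blocks, serializer, compressionCodec = None)
assert(restored.sameElements(payload))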
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index ad0f701d7a98f..fb024c12094f2 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -28,14 +28,13 @@ import org.apache.spark.{SecurityManager, SparkConf}
*/
class TorrentBroadcastFactory extends BroadcastFactory {
- override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) {
- TorrentBroadcast.initialize(isDriver, conf)
- }
+ override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { }
- override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
- new TorrentBroadcast[T](value_, isLocal, id)
+ override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) = {
+ new TorrentBroadcast[T](value_, id)
+ }
- override def stop() { TorrentBroadcast.stop() }
+ override def stop() { }
/**
* Remove all persisted state associated with the torrent broadcast with the given ID.
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 39150deab863c..4e802e02c4149 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -17,6 +17,8 @@
package org.apache.spark.deploy
+import java.net.{URI, URISyntaxException}
+
import scala.collection.mutable.ListBuffer
import org.apache.log4j.Level
@@ -114,5 +116,12 @@ private[spark] class ClientArguments(args: Array[String]) {
}
object ClientArguments {
- def isValidJarUrl(s: String): Boolean = s.matches("(.+):(.+)jar")
+ def isValidJarUrl(s: String): Boolean = {
+ try {
+ val uri = new URI(s)
+ uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar")
+ } catch {
+ case _: URISyntaxException => false
+ }
+ }
}
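A hedged sketch of how the URI-based check behaves (outputs inferred from the code above): a valid jar URL now needs a scheme, an authority, and a name ending in jar, rather than merely matching the old (.+):(.+)jar regex.

ClientArguments.isValidJarUrl("hdfs://namenode:8020/apps/app.jar")  // true
ClientArguments.isValidJarUrl("http://repo.example.com/app.jar")    // true
ClientArguments.isValidJarUrl("/local/path/app.jar")                // false: no scheme or authority
ClientArguments.isValidJarUrl("hdfs://namenode:8020/apps/app.zip")  // false: does not end in "jar"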
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index a7368f9f3dfbe..b9dd8557ee904 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -71,6 +71,8 @@ private[deploy] object DeployMessages {
case class RegisterWorkerFailed(message: String) extends DeployMessage
+ case class ReconnectWorker(masterUrl: String) extends DeployMessage
+
case class KillExecutor(masterUrl: String, appId: String, execId: Int) extends DeployMessage
case class LaunchExecutor(
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index fe0ad9ebbca12..e28eaad8a5180 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -20,12 +20,15 @@ package org.apache.spark.deploy
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
import org.apache.spark.{Logging, SparkContext, SparkConf, SparkException}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
import scala.collection.JavaConversions._
@@ -121,6 +124,33 @@ class SparkHadoopUtil extends Logging {
UserGroupInformation.loginUserFromKeytab(principalName, keytabFilename)
}
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes read. If
+ * getFSBytesReadOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes read on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
+ val scheme = qualifiedPath.toUri().getScheme()
+ val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ try {
+ val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead")
+ val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesRead = f()
+ Some(() => f() - baselineBytesRead)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e)
+ None
+ }
+ }
+ }
}
object SparkHadoopUtil {
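A hedged usage sketch of the new callback: create it on the thread that will do the reading, consume the input, then invoke it to get the delta in bytes read. It returns None on Hadoop versions older than 2.5; path, hadoopConf, and the SparkHadoopUtil.get accessor are assumed here for illustration.

val bytesReadCallback: Option[() => Long] =
  SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(path, hadoopConf)

// ... read the split's records on this same thread ...

val bytesRead: Long = bytesReadCallback.map(f => f()).getOrElse(0L)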
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index f97bf67fa5a3b..b43e68e40f791 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -158,8 +158,9 @@ object SparkSubmit {
args.files = mergeFileLists(args.files, args.primaryResource)
}
args.files = mergeFileLists(args.files, args.pyFiles)
- // Format python file paths properly before adding them to the PYTHONPATH
- sysProps("spark.submit.pyFiles") = PythonRunner.formatPaths(args.pyFiles).mkString(",")
+ if (args.pyFiles != null) {
+ sysProps("spark.submit.pyFiles") = args.pyFiles
+ }
}
// Special flag to avoid deprecation warnings at the client
@@ -273,15 +274,32 @@ object SparkSubmit {
}
}
- // Properties given with --conf are superceded by other options, but take precedence over
- // properties in the defaults file.
+ // Load any properties specified through --conf and the default properties file
for ((k, v) <- args.sparkProperties) {
sysProps.getOrElseUpdate(k, v)
}
- // Read from default spark properties, if any
- for ((k, v) <- args.defaultSparkProperties) {
- sysProps.getOrElseUpdate(k, v)
+ // Resolve paths in certain spark properties
+ val pathConfigs = Seq(
+ "spark.jars",
+ "spark.files",
+ "spark.yarn.jar",
+ "spark.yarn.dist.files",
+ "spark.yarn.dist.archives")
+ pathConfigs.foreach { config =>
+ // Replace old URIs with resolved URIs, if they exist
+ sysProps.get(config).foreach { oldValue =>
+ sysProps(config) = Utils.resolveURIs(oldValue)
+ }
+ }
+
+ // Resolve and format python file paths properly before adding them to the PYTHONPATH.
+ // The resolving part is redundant in the case of --py-files, but necessary if the user
+ // explicitly sets `spark.submit.pyFiles` in his/her default properties file.
+ sysProps.get("spark.submit.pyFiles").foreach { pyFiles =>
+ val resolvedPyFiles = Utils.resolveURIs(pyFiles)
+ val formattedPyFiles = PythonRunner.formatPaths(resolvedPyFiles).mkString(",")
+ sysProps("spark.submit.pyFiles") = formattedPyFiles
}
(childArgs, childClasspath, sysProps, childMainClass)
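A hedged illustration of the path resolution added above: each listed config is passed through Utils.resolveURIs, so relative or scheme-less entries become fully qualified URIs while entries that already carry a scheme pass through essentially unchanged. The resolved file: URI below is illustrative and depends on the working directory.

// Before: spark.jars = "lib/app.jar,hdfs:///shared/dep.jar"
// After : spark.jars = "file:/home/user/job/lib/app.jar,hdfs:///shared/dep.jar"
sysProps.get("spark.jars").foreach { oldValue =>
  sysProps("spark.jars") = Utils.resolveURIs(oldValue)
}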
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 72a452e0aefb5..f0e9ee67f6a67 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -19,7 +19,6 @@ package org.apache.spark.deploy
import java.util.jar.JarFile
-import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import org.apache.spark.util.Utils
@@ -72,39 +71,54 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
defaultProperties
}
- // Respect SPARK_*_MEMORY for cluster mode
- driverMemory = sys.env.get("SPARK_DRIVER_MEMORY").orNull
- executorMemory = sys.env.get("SPARK_EXECUTOR_MEMORY").orNull
-
+ // Set parameters from command line arguments
parseOpts(args.toList)
- mergeSparkProperties()
+ // Populate `sparkProperties` map from properties file
+ mergeDefaultSparkProperties()
+ // Use `sparkProperties` map along with env vars to fill in any missing parameters
+ loadEnvironmentArguments()
+
checkRequiredArguments()
/**
- * Fill in any undefined values based on the default properties file or options passed in through
- * the '--conf' flag.
+ * Merge values from the default properties file with those specified through --conf.
+ * When this is called, `sparkProperties` is already filled with configs from the latter.
*/
- private def mergeSparkProperties(): Unit = {
+ private def mergeDefaultSparkProperties(): Unit = {
// Use common defaults file, if not specified by user
propertiesFile = Option(propertiesFile).getOrElse(Utils.getDefaultPropertiesFile(env))
+ // Honor --conf before the defaults file
+ defaultSparkProperties.foreach { case (k, v) =>
+ if (!sparkProperties.contains(k)) {
+ sparkProperties(k) = v
+ }
+ }
+ }
- val properties = HashMap[String, String]()
- properties.putAll(defaultSparkProperties)
- properties.putAll(sparkProperties)
-
- // Use properties file as fallback for values which have a direct analog to
- // arguments in this script.
- master = Option(master).orElse(properties.get("spark.master")).orNull
- executorMemory = Option(executorMemory).orElse(properties.get("spark.executor.memory")).orNull
- executorCores = Option(executorCores).orElse(properties.get("spark.executor.cores")).orNull
+ /**
+ * Load arguments from environment variables, Spark properties etc.
+ */
+ private def loadEnvironmentArguments(): Unit = {
+ master = Option(master)
+ .orElse(sparkProperties.get("spark.master"))
+ .orElse(env.get("MASTER"))
+ .orNull
+ driverMemory = Option(driverMemory)
+ .orElse(sparkProperties.get("spark.driver.memory"))
+ .orElse(env.get("SPARK_DRIVER_MEMORY"))
+ .orNull
+ executorMemory = Option(executorMemory)
+ .orElse(sparkProperties.get("spark.executor.memory"))
+ .orElse(env.get("SPARK_EXECUTOR_MEMORY"))
+ .orNull
+ executorCores = Option(executorCores)
+ .orElse(sparkProperties.get("spark.executor.cores"))
+ .orNull
totalExecutorCores = Option(totalExecutorCores)
- .orElse(properties.get("spark.cores.max"))
+ .orElse(sparkProperties.get("spark.cores.max"))
.orNull
- name = Option(name).orElse(properties.get("spark.app.name")).orNull
- jars = Option(jars).orElse(properties.get("spark.jars")).orNull
-
- // This supports env vars in older versions of Spark
- master = Option(master).orElse(env.get("MASTER")).orNull
+ name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
+ jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
// Try to set main class from JAR if no --class argument is given
@@ -131,7 +145,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
/** Ensure that required fields exist. Call this only once all defaults are loaded. */
- private def checkRequiredArguments() = {
+ private def checkRequiredArguments(): Unit = {
if (args.length == 0) {
printUsageAndExit(-1)
}
@@ -166,7 +180,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- override def toString = {
+ override def toString = {
s"""Parsed arguments:
| master $master
| deployMode $deployMode
@@ -174,7 +188,6 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| executorCores $executorCores
| totalExecutorCores $totalExecutorCores
| propertiesFile $propertiesFile
- | extraSparkProperties $sparkProperties
| driverMemory $driverMemory
| driverCores $driverCores
| driverExtraClassPath $driverExtraClassPath
@@ -193,8 +206,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
| jars $jars
| verbose $verbose
|
- |Default properties from $propertiesFile:
- |${defaultSparkProperties.mkString(" ", "\n ", "\n")}
+ |Spark properties used, including those specified through
+ | --conf and those from the properties file $propertiesFile:
+ |${sparkProperties.mkString(" ", "\n ", "\n")}
""".stripMargin
}
@@ -327,7 +341,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
}
- private def printUsageAndExit(exitCode: Int, unknownParam: Any = null) {
+ private def printUsageAndExit(exitCode: Int, unknownParam: Any = null): Unit = {
val outStream = SparkSubmit.printStream
if (unknownParam != null) {
outStream.println("Unknown/unsupported param " + unknownParam)
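A hedged sketch distilling the resolution order now used by loadEnvironmentArguments(): an explicit command-line value wins, then a Spark property (set through --conf or the properties file), then a legacy environment variable. The resolve helper is illustrative and not part of the class.

def resolve(cliValue: String, propertyKey: String, envKey: String): String = {
  Option(cliValue)                            // 1. explicit flag on the command line
    .orElse(sparkProperties.get(propertyKey)) // 2. --conf or the properties file
    .orElse(env.get(envKey))                  // 3. legacy environment variable
    .orNull
}

// e.g. driverMemory = resolve(driverMemory, "spark.driver.memory", "SPARK_DRIVER_MEMORY")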
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 0125330589da5..2b894a796c8c6 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -82,17 +82,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
.orElse(confDriverMemory)
.getOrElse(defaultDriverMemory)
- val newLibraryPath =
- if (submitLibraryPath.isDefined) {
- // SPARK_SUBMIT_LIBRARY_PATH is already captured in JAVA_OPTS
- ""
- } else {
- confLibraryPath.map("-Djava.library.path=" + _).getOrElse("")
- }
-
val newClasspath =
if (submitClasspath.isDefined) {
- // SPARK_SUBMIT_CLASSPATH is already captured in CLASSPATH
classpath
} else {
classpath + confClasspath.map(sys.props("path.separator") + _).getOrElse("")
@@ -114,7 +105,6 @@ private[spark] object SparkSubmitDriverBootstrapper {
val command: Seq[String] =
Seq(runner) ++
Seq("-cp", newClasspath) ++
- Seq(newLibraryPath) ++
filteredJavaOpts ++
Seq(s"-Xms$newDriverMemory", s"-Xmx$newDriverMemory") ++
Seq("org.apache.spark.deploy.SparkSubmit") ++
@@ -130,6 +120,13 @@ private[spark] object SparkSubmitDriverBootstrapper {
// Start the driver JVM
val filteredCommand = command.filter(_.nonEmpty)
val builder = new ProcessBuilder(filteredCommand)
+ val env = builder.environment()
+
+ if (submitLibraryPath.isEmpty && confLibraryPath.nonEmpty) {
+ val libraryPaths = confLibraryPath ++ sys.env.get(Utils.libraryPathEnvName)
+ env.put(Utils.libraryPathEnvName, libraryPaths.mkString(sys.props("path.separator")))
+ }
+
val process = builder.start()
// Redirect stdout and stderr from the child JVM
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 481f6c93c6a8d..2d1609b973607 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -112,7 +112,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
val ui = {
val conf = this.conf.clone()
val appSecManager = new SecurityManager(conf)
- new SparkUI(conf, appSecManager, replayBus, appId,
+ SparkUI.createHistoryUI(conf, replayBus, appSecManager, appId,
s"${HistoryServer.UI_PATH_PREFIX}/$appId")
// Do not call ui.bind() to avoid creating a new server for each application
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index d25c29113d6da..0e249e51a77d8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -84,11 +84,11 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index c3ca43f8d0734..6ba395be1cc2c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import akka.actor.ActorRef
import org.apache.spark.deploy.ApplicationDescription
+import org.apache.spark.util.Utils
private[spark] class ApplicationInfo(
val startTime: Long,
@@ -46,7 +47,7 @@ private[spark] class ApplicationInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
index 80b570a44af18..2ac21186881fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/DriverInfo.scala
@@ -20,6 +20,7 @@ package org.apache.spark.deploy.master
import java.util.Date
import org.apache.spark.deploy.DriverDescription
+import org.apache.spark.util.Utils
private[spark] class DriverInfo(
val startTime: Long,
@@ -36,7 +37,7 @@ private[spark] class DriverInfo(
init()
- private def readObject(in: java.io.ObjectInputStream): Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index f98b531316a3d..2f81d472d7b78 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -341,7 +341,14 @@ private[spark] class Master(
case Some(workerInfo) =>
workerInfo.lastHeartbeat = System.currentTimeMillis()
case None =>
- logWarning("Got heartbeat from unregistered worker " + workerId)
+ if (workers.map(_.id).contains(workerId)) {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " Asking it to re-register.")
+ sender ! ReconnectWorker(masterUrl)
+ } else {
+ logWarning(s"Got heartbeat from unregistered worker $workerId." +
+ " This worker was never registered, so ignoring the heartbeat.")
+ }
}
}
@@ -714,8 +721,8 @@ private[spark] class Master(
try {
val replayBus = new ReplayListenerBus(eventLogPaths, fileSystem, compressionCodec)
- val ui = new SparkUI(new SparkConf, replayBus, appName + " (completed)",
- HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
+ val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
+ appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
replayBus.replay()
appIdToUI(app.id) = ui
webUi.attachSparkUI(ui)
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
index c5fa9cf7d7c2d..d221b0f6cc86b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala
@@ -50,7 +50,7 @@ private[spark] class WorkerInfo(
def coresFree: Int = cores - coresUsed
def memoryFree: Int = memory - memoryUsed
- private def readObject(in: java.io.ObjectInputStream) : Unit = {
+ private def readObject(in: java.io.ObjectInputStream): Unit = Utils.tryOrIOException {
in.defaultReadObject()
init()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
index 2e9be2a180c68..28e9662db5da9 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala
@@ -20,6 +20,8 @@ package org.apache.spark.deploy.worker
import java.io.{File, FileOutputStream, InputStream, IOException}
import java.lang.System._
+import scala.collection.Map
+
import org.apache.spark.Logging
import org.apache.spark.deploy.Command
import org.apache.spark.util.Utils
@@ -29,7 +31,29 @@ import org.apache.spark.util.Utils
*/
private[spark]
object CommandUtils extends Logging {
- def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+
+ /**
+ * Build a ProcessBuilder based on the given parameters.
+ * The `env` argument is exposed for testing.
+ */
+ def buildProcessBuilder(
+ command: Command,
+ memory: Int,
+ sparkHome: String,
+ substituteArguments: String => String,
+ classPaths: Seq[String] = Seq[String](),
+ env: Map[String, String] = sys.env): ProcessBuilder = {
+ val localCommand = buildLocalCommand(command, substituteArguments, classPaths, env)
+ val commandSeq = buildCommandSeq(localCommand, memory, sparkHome)
+ val builder = new ProcessBuilder(commandSeq: _*)
+ val environment = builder.environment()
+ for ((key, value) <- localCommand.environment) {
+ environment.put(key, value)
+ }
+ builder
+ }
+
+ private def buildCommandSeq(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val runner = sys.env.get("JAVA_HOME").map(_ + "/bin/java").getOrElse("java")
// SPARK-698: do not call the run.cmd script, as process.destroy()
@@ -38,11 +62,41 @@ object CommandUtils extends Logging {
command.arguments
}
+ /**
+ * Build a command based on the given one, taking into account the local environment
+ * of where this command is expected to run, substitute any placeholders, and append
+ * any extra class paths.
+ */
+ private def buildLocalCommand(
+ command: Command,
+ substituteArguments: String => String,
+ classPath: Seq[String] = Seq[String](),
+ env: Map[String, String]): Command = {
+ val libraryPathName = Utils.libraryPathEnvName
+ val libraryPathEntries = command.libraryPathEntries
+ val cmdLibraryPath = command.environment.get(libraryPathName)
+
+ val newEnvironment = if (libraryPathEntries.nonEmpty && libraryPathName.nonEmpty) {
+ val libraryPaths = libraryPathEntries ++ cmdLibraryPath ++ env.get(libraryPathName)
+ command.environment + ((libraryPathName, libraryPaths.mkString(File.pathSeparator)))
+ } else {
+ command.environment
+ }
+
+ Command(
+ command.mainClass,
+ command.arguments.map(substituteArguments),
+ newEnvironment,
+ command.classPathEntries ++ classPath,
+ Seq[String](), // library path already captured in environment variable
+ command.javaOpts)
+ }
+
/**
* Attention: this must always be aligned with the environment variables in the run scripts and
* the way the JAVA_OPTS are assembled there.
*/
- def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
+ private def buildJavaOpts(command: Command, memory: Int, sparkHome: String): Seq[String] = {
val memoryOpts = Seq(s"-Xms${memory}M", s"-Xmx${memory}M")
// Exists for backwards compatibility with older Spark versions
@@ -53,14 +107,6 @@ object CommandUtils extends Logging {
logWarning("Set SPARK_LOCAL_DIRS for node-specific storage locations.")
}
- val libraryOpts =
- if (command.libraryPathEntries.size > 0) {
- val joined = command.libraryPathEntries.mkString(File.pathSeparator)
- Seq(s"-Djava.library.path=$joined")
- } else {
- Seq()
- }
-
// Figure out our classpath with the external compute-classpath script
val ext = if (System.getProperty("os.name").startsWith("Windows")) ".cmd" else ".sh"
val classPath = Utils.executeAndGetOutput(
@@ -71,7 +117,7 @@ object CommandUtils extends Logging {
val javaVersion = System.getProperty("java.version")
val permGenOpt = if (!javaVersion.startsWith("1.8")) Some("-XX:MaxPermSize=128m") else None
Seq("-cp", userClassPath.filterNot(_.isEmpty).mkString(File.pathSeparator)) ++
- permGenOpt ++ libraryOpts ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
+ permGenOpt ++ workerLocalOpts ++ command.javaOpts ++ memoryOpts
}
/** Spawn a thread that will redirect a given stream to a file */
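A hedged usage sketch of the new buildProcessBuilder helper; the Command fields and paths are illustrative. The returned ProcessBuilder already carries the command's environment, including the composed native library path, so callers only set the working directory and start the process.

val cmd = Command(
  "com.example.Main",                        // mainClass
  Seq("{{WORKER_URL}}", "--port", "7078"),   // arguments, placeholders substituted below
  Map("SPARK_LOG_DIR" -> "/var/log/spark"),  // environment
  Seq.empty,                                 // classPathEntries
  Seq("/opt/native"),                        // libraryPathEntries
  Seq("-Dspark.example.opt=1"))              // javaOpts

val builder = CommandUtils.buildProcessBuilder(
  cmd, memory = 1024, sparkHome = "/opt/spark",
  substituteArguments = s => s.replace("{{WORKER_URL}}", "spark://worker:7078"))
builder.directory(new java.io.File("/tmp/driver-work"))
val process = builder.start()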
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
index 9f9911762505a..28cab36c7b9e2 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala
@@ -23,7 +23,7 @@ import scala.collection.JavaConversions._
import scala.collection.Map
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileUtil, Path}
@@ -76,17 +76,9 @@ private[spark] class DriverRunner(
// Make sure user application jar is on the classpath
// TODO: If we add ability to submit multiple jars they should also be added here
- val classPath = driverDesc.command.classPathEntries ++ Seq(s"$localJarFilename")
- val newCommand = Command(
- driverDesc.command.mainClass,
- driverDesc.command.arguments.map(substituteVariables),
- driverDesc.command.environment,
- classPath,
- driverDesc.command.libraryPathEntries,
- driverDesc.command.javaOpts)
- val command = CommandUtils.buildCommandSeq(newCommand, driverDesc.mem,
- sparkHome.getAbsolutePath)
- launchDriver(command, driverDesc.command.environment, driverDir, driverDesc.supervise)
+ val builder = CommandUtils.buildProcessBuilder(driverDesc.command, driverDesc.mem,
+ sparkHome.getAbsolutePath, substituteVariables, Seq(localJarFilename))
+ launchDriver(builder, driverDir, driverDesc.supervise)
}
catch {
case e: Exception => finalException = Some(e)
@@ -165,11 +157,8 @@ private[spark] class DriverRunner(
localJarFilename
}
- private def launchDriver(command: Seq[String], envVars: Map[String, String], baseDir: File,
- supervise: Boolean) {
- val builder = new ProcessBuilder(command: _*).directory(baseDir)
- envVars.map{ case(k,v) => builder.environment().put(k, v) }
-
+ private def launchDriver(builder: ProcessBuilder, baseDir: File, supervise: Boolean) {
+ builder.directory(baseDir)
def initialize(process: Process) = {
// Redirect stdout and stderr to files
val stdout = new File(baseDir, "stdout")
@@ -177,8 +166,8 @@ private[spark] class DriverRunner(
val stderr = new File(baseDir, "stderr")
val header = "Launch Command: %s\n%s\n\n".format(
- command.mkString("\"", "\" \"", "\""), "=" * 40)
- Files.append(header, stderr, Charsets.UTF_8)
+ builder.command.mkString("\"", "\" \"", "\""), "=" * 40)
+ Files.append(header, stderr, UTF_8)
CommandUtils.redirectStream(process.getErrorStream, stderr)
}
runCommandWithRetry(ProcessBuilderLike(builder), initialize, supervise)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 71d7385b08eb9..8ba6a01bbcb97 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -19,8 +19,10 @@ package org.apache.spark.deploy.worker
import java.io._
+import scala.collection.JavaConversions._
+
import akka.actor.ActorRef
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{SparkConf, Logging}
@@ -115,33 +117,21 @@ private[spark] class ExecutorRunner(
case other => other
}
- def getCommandSeq = {
- val command = Command(
- appDesc.command.mainClass,
- appDesc.command.arguments.map(substituteVariables),
- appDesc.command.environment,
- appDesc.command.classPathEntries,
- appDesc.command.libraryPathEntries,
- appDesc.command.javaOpts)
- CommandUtils.buildCommandSeq(command, memory, sparkHome.getAbsolutePath)
- }
-
/**
* Download and run the executor described in our ApplicationDescription
*/
def fetchAndRunExecutor() {
try {
// Launch the process
- val command = getCommandSeq
+ val builder = CommandUtils.buildProcessBuilder(appDesc.command, memory,
+ sparkHome.getAbsolutePath, substituteVariables)
+ val command = builder.command()
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
- val builder = new ProcessBuilder(command: _*).directory(executorDir)
- val env = builder.environment()
- for ((key, value) <- appDesc.command.environment) {
- env.put(key, value)
- }
+
+ builder.directory(executorDir)
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
- env.put("SPARK_LAUNCH_WITH_SCALA", "0")
+ builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
process = builder.start()
val header = "Spark Executor Command: %s\n%s\n\n".format(
command.mkString("\"", "\" \"", "\""), "=" * 40)
@@ -151,7 +141,7 @@ private[spark] class ExecutorRunner(
stdoutAppender = FileAppender(process.getInputStream, stdout, conf)
val stderr = new File(executorDir, "stderr")
- Files.write(header, stderr, Charsets.UTF_8)
+ Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
state = ExecutorState.RUNNING
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 9b52cb06fb6fa..f1f66d0903f1c 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -20,12 +20,14 @@ package org.apache.spark.deploy.worker
import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
-import java.util.Date
+import java.util.{UUID, Date}
+import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
import scala.collection.mutable.HashMap
import scala.concurrent.duration._
import scala.language.postfixOps
+import scala.util.Random
import akka.actor._
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
@@ -64,8 +66,22 @@ private[spark] class Worker(
// Send a heartbeat every (heartbeat timeout) / 4 milliseconds
val HEARTBEAT_MILLIS = conf.getLong("spark.worker.timeout", 60) * 1000 / 4
- val REGISTRATION_TIMEOUT = 20.seconds
- val REGISTRATION_RETRIES = 3
+ // Retries to connect to the master are modeled after Hadoop's approach.
+ // The first six reconnection attempts happen at shorter intervals (between 5 and 15 seconds);
+ // the next 10 attempts happen at intervals between 30 and 90 seconds.
+ // A bit of randomness is introduced so that not all of the workers attempt to reconnect at
+ // the same time.
+ val INITIAL_REGISTRATION_RETRIES = 6
+ val TOTAL_REGISTRATION_RETRIES = INITIAL_REGISTRATION_RETRIES + 10
+ val FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND = 0.500
+ val REGISTRATION_RETRY_FUZZ_MULTIPLIER = {
+ val randomNumberGenerator = new Random(UUID.randomUUID.getMostSignificantBits)
+ randomNumberGenerator.nextDouble + FUZZ_MULTIPLIER_INTERVAL_LOWER_BOUND
+ }
+ val INITIAL_REGISTRATION_RETRY_INTERVAL = (math.round(10 *
+ REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
+ val PROLONGED_REGISTRATION_RETRY_INTERVAL = (math.round(60
+ * REGISTRATION_RETRY_FUZZ_MULTIPLIER)).seconds
val CLEANUP_ENABLED = conf.getBoolean("spark.worker.cleanup.enabled", false)
// How often worker will clean up old app folders
@@ -103,6 +119,7 @@ private[spark] class Worker(
var coresUsed = 0
var memoryUsed = 0
+ var connectionAttemptCount = 0
val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr)
val workerSource = new WorkerSource(this)
@@ -158,7 +175,7 @@ private[spark] class Worker(
connected = true
}
- def tryRegisterAllMasters() {
+ private def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
@@ -166,26 +183,47 @@ private[spark] class Worker(
}
}
- def registerWithMaster() {
- tryRegisterAllMasters()
- var retries = 0
- registrationRetryTimer = Some {
- context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
- Utils.tryOrExit {
- retries += 1
- if (registered) {
- registrationRetryTimer.foreach(_.cancel())
- } else if (retries >= REGISTRATION_RETRIES) {
- logError("All masters are unresponsive! Giving up.")
- System.exit(1)
- } else {
- tryRegisterAllMasters()
+ private def retryConnectToMaster() {
+ Utils.tryOrExit {
+ connectionAttemptCount += 1
+ if (registered) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
+ } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
+ logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
+ tryRegisterAllMasters()
+ if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
+ PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
}
}
+ } else {
+ logError("All masters are unresponsive! Giving up.")
+ System.exit(1)
}
}
}
+ def registerWithMaster() {
+ // DisassociatedEvent may be triggered multiple times, so don't attempt registration
+ // if there are outstanding registration attempts scheduled.
+ registrationRetryTimer match {
+ case None =>
+ registered = false
+ tryRegisterAllMasters()
+ connectionAttemptCount = 0
+ registrationRetryTimer = Some {
+ context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
+ INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ }
+ case Some(_) =>
+ logInfo("Not spawning another attempt to register with the master, since there is an" +
+ " attempt scheduled already.")
+ }
+ }
+
override def receiveWithLogging = {
case RegisteredWorker(masterUrl, masterWebUiUrl) =>
logInfo("Successfully registered with master " + masterUrl)
@@ -243,6 +281,10 @@ private[spark] class Worker(
System.exit(1)
}
+ case ReconnectWorker(masterUrl) =>
+ logInfo(s"Master with url $masterUrl requested this worker to reconnect.")
+ registerWithMaster()
+
case LaunchExecutor(masterUrl, appId, execId, appDesc, cores_, memory_) =>
if (masterUrl != activeMasterUrl) {
logWarning("Invalid Master (" + masterUrl + ") attempted to launch executor.")
@@ -362,9 +404,10 @@ private[spark] class Worker(
}
}
- def masterDisconnected() {
+ private def masterDisconnected() {
logError("Connection to master failed! Waiting for master to reconnect...")
connected = false
+ registerWithMaster()
}
def generateWorkerId(): String = {
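A small sketch of the interval arithmetic defined above: with a fuzz multiplier uniform in [0.5, 1.5), the first six retries fire every 5 to 15 seconds and the following ten every 30 to 90 seconds, so workers do not all hammer the master at the same moment.

import scala.concurrent.duration._
import scala.util.Random

val fuzz = new Random().nextDouble + 0.5               // uniform in [0.5, 1.5)
val initialInterval = math.round(10 * fuzz).seconds    // roughly 5 to 15 seconds
val prolongedInterval = math.round(60 * fuzz).seconds  // roughly 30 to 90 seconds
println(s"first 6 retries every $initialInterval, next 10 every $prolongedInterval")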
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index c40a3e16675ad..3711824a40cfc 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import scala.concurrent.Await
-import akka.actor.{Actor, ActorSelection, Props}
+import akka.actor.{Actor, ActorSelection, ActorSystem, Props}
import akka.pattern.Patterns
import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent}
@@ -38,7 +38,8 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
- sparkProperties: Seq[(String, String)])
+ sparkProperties: Seq[(String, String)],
+ actorSystem: ActorSystem)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
Utils.checkHostPort(hostPort, "Expected hostport")
@@ -57,8 +58,8 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
// Make this host instead of hostPort ?
- executor = new Executor(executorId, Utils.parseHostPort(hostPort)._1, sparkProperties,
- false)
+ val (hostname, _) = Utils.parseHostPort(hostPort)
+ executor = new Executor(executorId, hostname, sparkProperties, isLocal = false, actorSystem)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -130,12 +131,13 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, props),
+ driverUrl, executorId, sparkHostPort, cores, props, actorSystem),
name = "Executor")
workerUrl.foreach { url =>
actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 616c7e6a46368..8b095e23f32ff 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,21 +26,25 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import akka.actor.{Props, ActorSystem}
+
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
+ * In coarse-grained mode, an existing actor system is provided.
*/
private[spark] class Executor(
executorId: String,
slaveHostname: String,
properties: Seq[(String, String)],
- isLocal: Boolean = false)
+ isLocal: Boolean = false,
+ actorSystem: ActorSystem = null)
extends Logging
{
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
@@ -68,17 +72,18 @@ private[spark] class Executor(
// Setup an uncaught exception handler for non-local mode.
// Make any thread terminations due to uncaught exceptions kill the entire
// executor process to avoid surprising stalls.
- Thread.setDefaultUncaughtExceptionHandler(ExecutorUncaughtExceptionHandler)
+ Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
}
val executorSource = new ExecutorSource(this, executorId)
// Initialize Spark environment (using system properties read above)
- conf.set("spark.executor.id", "executor." + executorId)
+ conf.set("spark.executor.id", executorId)
private val env = {
if (!isLocal) {
- val _env = SparkEnv.create(conf, executorId, slaveHostname, 0,
- isDriver = false, isLocal = false)
+ val port = conf.getInt("spark.executor.port", 0)
+ val _env = SparkEnv.createExecutorEnv(
+ conf, executorId, slaveHostname, port, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
_env
@@ -87,6 +92,10 @@ private[spark] class Executor(
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -99,6 +108,9 @@ private[spark] class Executor(
// to send the result back.
private val akkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf)
+ // Limit of bytes for total size of results (default is 1GB)
+ private val maxResultSize = Utils.getMaxResultSize(conf)
+
// Start worker thread pool
val threadPool = Utils.newDaemonCachedThreadPool("Executor task launch worker")
@@ -123,6 +135,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -205,25 +218,27 @@ private[spark] class Executor(
val resultSize = serializedDirectResult.limit
// directSend = sending directly back to the driver
- val (serializedResult, directSend) = {
- if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
+ val serializedResult = {
+ if (resultSize > maxResultSize) {
+ logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
+ s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
+ s"dropping it.")
+ ser.serialize(new IndirectTaskResult[Any](TaskResultBlockId(taskId), resultSize))
+ } else if (resultSize >= akkaFrameSize - AkkaUtils.reservedSizeBytes) {
val blockId = TaskResultBlockId(taskId)
env.blockManager.putBytes(
blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
- (ser.serialize(new IndirectTaskResult[Any](blockId)), false)
+ logInfo(
+ s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
+ ser.serialize(new IndirectTaskResult[Any](blockId, resultSize))
} else {
- (serializedDirectResult, true)
+ logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
+ serializedDirectResult
}
}
execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
- if (directSend) {
- logInfo(s"Finished $taskName (TID $taskId). $resultSize bytes result sent to driver")
- } else {
- logInfo(
- s"Finished $taskName (TID $taskId). $resultSize bytes result sent via BlockManager)")
- }
} catch {
case ffe: FetchFailedException => {
val reason = ffe.toTaskEndReason
@@ -253,7 +268,7 @@ private[spark] class Executor(
// Don't forcibly exit unless the exception was inherently fatal, to avoid
// stopping other tasks unnecessarily.
if (Utils.isFatalError(t)) {
- ExecutorUncaughtExceptionHandler.uncaughtException(t)
+ SparkUncaughtExceptionHandler.uncaughtException(t)
}
}
} finally {
@@ -322,14 +337,16 @@ private[spark] class Executor(
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+ // Fetch the file with caching enabled; the cache is disabled in local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
logInfo("Fetching " + name + " with timestamp " + timestamp)
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf, env.securityManager,
- hadoopConf)
+      // Fetch the file with caching enabled; the cache is disabled in local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentJars(name) = timestamp
// Add it to our class loader
val localName = name.split("/").last
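The result-routing change above now distinguishes three cases: results larger than maxResultSize are dropped and only their metadata travels back, results that fit under maxResultSize but exceed the Akka frame size go through the BlockManager, and small results ride directly on the status update. A minimal sketch of that decision, with the thresholds as plain parameters (the function name is illustrative, not part of the patch):

def routeResult(resultSize: Long, maxResultSize: Long, akkaFrameSize: Long, reservedBytes: Long): String = {
  if (resultSize > maxResultSize) {
    "drop the payload; serialize an IndirectTaskResult carrying only the block id and size"
  } else if (resultSize >= akkaFrameSize - reservedBytes) {
    "store the bytes in the BlockManager; serialize an IndirectTaskResult pointing at them"
  } else {
    "send the serialized result directly in the status update"
  }
}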
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
new file mode 100644
index 0000000000000..41925f7e97e84
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import akka.actor.Actor
+import org.apache.spark.Logging
+
+import org.apache.spark.util.{Utils, ActorLogReceive}
+
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
+ }
+
+}
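A hedged driver-side sketch of how this actor could be queried. The actor-system name, remote path, and timeout are assumptions for illustration, and since TriggerThreadDump is private[spark] the snippet assumes code living inside the org.apache.spark package:

import scala.concurrent.Await
import scala.concurrent.duration._

import akka.actor.ActorSystem
import akka.pattern.ask
import akka.util.Timeout

import org.apache.spark.executor.TriggerThreadDump

// Ask a remote ExecutorActor for a thread dump and block for the reply,
// which is whatever Utils.getThreadDump() produced on the executor.
def requestThreadDump(system: ActorSystem, host: String, port: Int): Any = {
  implicit val timeout: Timeout = Timeout(30.seconds)
  // "sparkExecutor" is an assumed actor-system name; "ExecutorActor" matches the name
  // used when the actor is registered in Executor above.
  val selection = system.actorSelection(s"akka.tcp://sparkExecutor@$host:$port/user/ExecutorActor")
  Await.result(selection ? TriggerThreadDump, timeout.duration)
}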
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
index 38be2c58b333f..52862ae0ca5e4 100644
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorExitCode.scala
@@ -17,6 +17,8 @@
package org.apache.spark.executor
+import org.apache.spark.util.SparkExitCode._
+
/**
* These are exit codes that executors should use to provide the master with information about
* executor failures assuming that cluster management framework can capture the exit codes (but
@@ -27,16 +29,6 @@ package org.apache.spark.executor
*/
private[spark]
object ExecutorExitCode {
- /** The default uncaught exception handler was reached. */
- val UNCAUGHT_EXCEPTION = 50
-
- /** The default uncaught exception handler was called and an exception was encountered while
- logging the exception. */
- val UNCAUGHT_EXCEPTION_TWICE = 51
-
- /** The default uncaught exception handler was reached, and the uncaught exception was an
- OutOfMemoryError. */
- val OOM = 52
/** DiskStore failed to create a local temporary directory after many attempts. */
val DISK_STORE_FAILED_TO_CREATE_DIR = 53
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
deleted file mode 100644
index b0e984c03964c..0000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorUncaughtExceptionHandler.scala
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import org.apache.spark.Logging
-import org.apache.spark.util.Utils
-
-/**
- * The default uncaught exception handler for Executors terminates the whole process, to avoid
- * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better
- * to fail fast when things go wrong.
- */
-private[spark] object ExecutorUncaughtExceptionHandler
- extends Thread.UncaughtExceptionHandler with Logging {
-
- override def uncaughtException(thread: Thread, exception: Throwable) {
- try {
- logError("Uncaught exception in thread " + thread, exception)
-
- // We may have been called from a shutdown hook. If so, we must not call System.exit().
- // (If we do, we will deadlock.)
- if (!Utils.inShutdown()) {
- if (exception.isInstanceOf[OutOfMemoryError]) {
- System.exit(ExecutorExitCode.OOM)
- } else {
- System.exit(ExecutorExitCode.UNCAUGHT_EXCEPTION)
- }
- }
- } catch {
- case oom: OutOfMemoryError => Runtime.getRuntime.halt(ExecutorExitCode.OOM)
- case t: Throwable => Runtime.getRuntime.halt(ExecutorExitCode.UNCAUGHT_EXCEPTION_TWICE)
- }
- }
-
- def uncaughtException(exception: Throwable) {
- uncaughtException(Thread.currentThread(), exception)
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 3e49b6235aff3..57bc2b40cec44 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -169,7 +169,6 @@ case class InputMetrics(readMethod: DataReadMethod.Value) {
var bytesRead: Long = 0L
}
-
/**
* :: DeveloperApi ::
* Metrics pertaining to shuffle data read in a given task.
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
new file mode 100644
index 0000000000000..89b29af2000c8
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+
+/**
+ * Custom Input Format for reading and splitting flat binary files that contain records,
+ * each of which is a fixed size in bytes. The fixed record size is specified through
+ * a parameter recordLength in the Hadoop configuration.
+ */
+private[spark] object FixedLengthBinaryInputFormat {
+ /** Property name to set in Hadoop JobConfs for record length */
+ val RECORD_LENGTH_PROPERTY = "org.apache.spark.input.FixedLengthBinaryInputFormat.recordLength"
+
+ /** Retrieves the record length property from a Hadoop configuration */
+ def getRecordLength(context: JobContext): Int = {
+ context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ }
+}
+
+private[spark] class FixedLengthBinaryInputFormat
+ extends FileInputFormat[LongWritable, BytesWritable] {
+
+ private var recordLength = -1
+
+ /**
+ * Override of isSplitable to ensure initial computation of the record length
+ */
+ override def isSplitable(context: JobContext, filename: Path): Boolean = {
+ if (recordLength == -1) {
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ }
+ if (recordLength <= 0) {
+      println("record length is less than or equal to 0, file cannot be split")
+ false
+ } else {
+ true
+ }
+ }
+
+ /**
+ * This input format overrides computeSplitSize() to make sure that each split
+ * only contains full records. Each InputSplit passed to FixedLengthBinaryRecordReader
+   * will start at the first byte of a record, and the last byte will be the last byte of a record.
+ */
+ override def computeSplitSize(blockSize: Long, minSize: Long, maxSize: Long): Long = {
+ val defaultSize = super.computeSplitSize(blockSize, minSize, maxSize)
+    // If the default size is less than the length of a record, use the record length instead.
+    // Otherwise, make the split size as close as possible to the default size while still
+    // containing a complete set of records, with the first record starting at the first byte
+    // in the split and the last record ending with the last byte of the split
+ if (defaultSize < recordLength) {
+ recordLength.toLong
+ } else {
+ (Math.floor(defaultSize / recordLength) * recordLength).toLong
+ }
+ }
+
+ /**
+ * Create a FixedLengthBinaryRecordReader
+ */
+ override def createRecordReader(split: InputSplit, context: TaskAttemptContext)
+ : RecordReader[LongWritable, BytesWritable] = {
+ new FixedLengthBinaryRecordReader
+ }
+}
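A usage sketch for the format above. Because the class is private[spark], this assumes code inside the org.apache.spark package; the path and SparkContext are placeholders:

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.{BytesWritable, LongWritable}

import org.apache.spark.SparkContext
import org.apache.spark.input.FixedLengthBinaryInputFormat

// Read fixed-length binary records through the new Hadoop API; each element
// comes back as (record index, record bytes).
def fixedLengthRecords(sc: SparkContext, path: String, recordLength: Int) = {
  val conf = new Configuration(sc.hadoopConfiguration)
  // The input format reads its record length from this Hadoop property.
  conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
  sc.newAPIHadoopFile(path, classOf[FixedLengthBinaryInputFormat],
    classOf[LongWritable], classOf[BytesWritable], conf)
}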
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
new file mode 100644
index 0000000000000..5164a74bec4e9
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.IOException
+
+import org.apache.hadoop.fs.FSDataInputStream
+import org.apache.hadoop.io.compress.CompressionCodecFactory
+import org.apache.hadoop.io.{BytesWritable, LongWritable}
+import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
+
+/**
+ * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
+ * It uses the record length set in FixedLengthBinaryInputFormat to
+ * read one record at a time from the given InputSplit.
+ *
+ * Each call to nextKeyValue() updates the LongWritable key and BytesWritable value.
+ *
+ * key = record index (Long)
+ * value = the record itself (BytesWritable)
+ */
+private[spark] class FixedLengthBinaryRecordReader
+ extends RecordReader[LongWritable, BytesWritable] {
+
+ private var splitStart: Long = 0L
+ private var splitEnd: Long = 0L
+ private var currentPosition: Long = 0L
+ private var recordLength: Int = 0
+ private var fileInputStream: FSDataInputStream = null
+ private var recordKey: LongWritable = null
+ private var recordValue: BytesWritable = null
+
+ override def close() {
+ if (fileInputStream != null) {
+ fileInputStream.close()
+ }
+ }
+
+ override def getCurrentKey: LongWritable = {
+ recordKey
+ }
+
+ override def getCurrentValue: BytesWritable = {
+ recordValue
+ }
+
+ override def getProgress: Float = {
+ splitStart match {
+      case x if x == splitEnd => 0.0f
+      case _ => Math.min(
+        (currentPosition - splitStart).toFloat / (splitEnd - splitStart), 1.0f
+      )
+ }
+ }
+
+ override def initialize(inputSplit: InputSplit, context: TaskAttemptContext) {
+ // the file input
+ val fileSplit = inputSplit.asInstanceOf[FileSplit]
+
+ // the byte position this fileSplit starts at
+ splitStart = fileSplit.getStart
+
+    // the byte position this fileSplit ends at
+ splitEnd = splitStart + fileSplit.getLength
+
+ // the actual file we will be reading from
+ val file = fileSplit.getPath
+ // job configuration
+ val job = context.getConfiguration
+ // check compression
+ val codec = new CompressionCodecFactory(job).getCodec(file)
+ if (codec != null) {
+ throw new IOException("FixedLengthRecordReader does not support reading compressed files")
+ }
+ // get the record length
+ recordLength = FixedLengthBinaryInputFormat.getRecordLength(context)
+ // get the filesystem
+ val fs = file.getFileSystem(job)
+ // open the File
+ fileInputStream = fs.open(file)
+ // seek to the splitStart position
+ fileInputStream.seek(splitStart)
+ // set our current position
+ currentPosition = splitStart
+ }
+
+ override def nextKeyValue(): Boolean = {
+ if (recordKey == null) {
+ recordKey = new LongWritable()
+ }
+ // the key is a linear index of the record, given by the
+ // position the record starts divided by the record length
+ recordKey.set(currentPosition / recordLength)
+ // the recordValue to place the bytes into
+ if (recordValue == null) {
+ recordValue = new BytesWritable(new Array[Byte](recordLength))
+ }
+ // read a record if the currentPosition is less than the split end
+ if (currentPosition < splitEnd) {
+ // setup a buffer to store the record
+ val buffer = recordValue.getBytes
+      fileInputStream.readFully(buffer, 0, recordLength)
+ // update our current position
+ currentPosition = currentPosition + recordLength
+ // return true
+ return true
+ }
+ false
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
new file mode 100644
index 0000000000000..457472547fcbb
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.input
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+
+import scala.collection.JavaConversions._
+
+import com.google.common.io.ByteStreams
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * A general input format for reading whole files in as streams, byte arrays,
+ * or other representations that may be added later
+ */
+private[spark] abstract class StreamFileInputFormat[T]
+ extends CombineFileInputFormat[String, T]
+{
+ override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+
+ /**
+   * Allow minPartitions set by the end-user in order to keep compatibility with the old
+   * Hadoop API, which is applied through setMaxSplitSize
+ */
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
+ val files = listStatus(context)
+ val totalLen = files.map { file =>
+ if (file.isDir) 0L else file.getLen
+ }.sum
+
+ val maxSplitSize = Math.ceil(totalLen * 1.0 / files.length).toLong
+ super.setMaxSplitSize(maxSplitSize)
+ }
+
+ def createRecordReader(split: InputSplit, taContext: TaskAttemptContext): RecordReader[String, T]
+
+}
+
+/**
+ * An abstract class of [[org.apache.hadoop.mapreduce.RecordReader RecordReader]]
+ * for reading files out as streams
+ */
+private[spark] abstract class StreamBasedRecordReader[T](
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends RecordReader[String, T] {
+
+  // True means the current file has been processed and should be skipped.
+ private var processed = false
+
+ private var key = ""
+ private var value: T = null.asInstanceOf[T]
+
+ override def initialize(split: InputSplit, context: TaskAttemptContext) = {}
+ override def close() = {}
+
+ override def getProgress = if (processed) 1.0f else 0.0f
+
+ override def getCurrentKey = key
+
+ override def getCurrentValue = value
+
+ override def nextKeyValue = {
+ if (!processed) {
+ val fileIn = new PortableDataStream(split, context, index)
+ value = parseStream(fileIn)
+ fileIn.close() // if it has not been open yet, close does nothing
+ key = fileIn.getPath
+ processed = true
+ true
+ } else {
+ false
+ }
+ }
+
+ /**
+   * Parse the stream (and close it afterwards) and return the value as type T
+   * @param inStream the stream to be read in
+   * @return the data parsed into type T
+ */
+ def parseStream(inStream: PortableDataStream): T
+}
+
+/**
+ * Reads the record in directly as a stream for other objects to manipulate and handle
+ */
+private[spark] class StreamRecordReader(
+ split: CombineFileSplit,
+ context: TaskAttemptContext,
+ index: Integer)
+ extends StreamBasedRecordReader[PortableDataStream](split, context, index) {
+
+ def parseStream(inStream: PortableDataStream): PortableDataStream = inStream
+}
+
+/**
+ * The format for the PortableDataStream files
+ */
+private[spark] class StreamInputFormat extends StreamFileInputFormat[PortableDataStream] {
+ override def createRecordReader(split: InputSplit, taContext: TaskAttemptContext) = {
+ new CombineFileRecordReader[String, PortableDataStream](
+ split.asInstanceOf[CombineFileSplit], taContext, classOf[StreamRecordReader])
+ }
+}
+
+/**
+ * A class that allows DataStreams to be serialized and moved around by not creating them
+ * until they need to be read
+ * @note TaskAttemptContext is not serializable resulting in the confBytes construct
+ * @note CombineFileSplit is not serializable resulting in the splitBytes construct
+ */
+@Experimental
+class PortableDataStream(
+ @transient isplit: CombineFileSplit,
+ @transient context: TaskAttemptContext,
+ index: Integer)
+ extends Serializable {
+
+  // transient forces the file to be reopened after serialization;
+  // it is also marked on fields whose types are not serializable
+
+ @transient private var fileIn: DataInputStream = null
+ @transient private var isOpen = false
+
+ private val confBytes = {
+ val baos = new ByteArrayOutputStream()
+ context.getConfiguration.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ private val splitBytes = {
+ val baos = new ByteArrayOutputStream()
+ isplit.write(new DataOutputStream(baos))
+ baos.toByteArray
+ }
+
+ @transient private lazy val split = {
+ val bais = new ByteArrayInputStream(splitBytes)
+ val nsplit = new CombineFileSplit()
+ nsplit.readFields(new DataInputStream(bais))
+ nsplit
+ }
+
+ @transient private lazy val conf = {
+ val bais = new ByteArrayInputStream(confBytes)
+ val nconf = new Configuration()
+ nconf.readFields(new DataInputStream(bais))
+ nconf
+ }
+ /**
+ * Calculate the path name independently of opening the file
+ */
+ @transient private lazy val path = {
+ val pathp = split.getPath(index)
+ pathp.toString
+ }
+
+ /**
+ * Create a new DataInputStream from the split and context
+ */
+ def open(): DataInputStream = {
+ if (!isOpen) {
+ val pathp = split.getPath(index)
+ val fs = pathp.getFileSystem(conf)
+ fileIn = fs.open(pathp)
+ isOpen = true
+ }
+ fileIn
+ }
+
+ /**
+ * Read the file as a byte array
+ */
+ def toArray(): Array[Byte] = {
+ open()
+ val innerBuffer = ByteStreams.toByteArray(fileIn)
+ close()
+ innerBuffer
+ }
+
+ /**
+ * Close the file (if it is currently open)
+ */
+ def close() = {
+ if (isOpen) {
+ try {
+ fileIn.close()
+ isOpen = false
+ } catch {
+ case ioe: java.io.IOException => // do nothing
+ }
+ }
+ }
+
+ def getPath(): String = path
+}
+
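A small usage sketch for PortableDataStream. It assumes an existing SparkContext `sc` and a binaryFiles-style helper that yields (file path, PortableDataStream) pairs; neither is defined in this file, so treat both as assumptions:

// Grab the first 16 bytes of each file. The underlying stream is only opened on the
// executor when toArray() is called, which is what makes the handle cheap to ship around.
val headers = sc.binaryFiles("hdfs:///data/blobs")
  .map { case (path, stream) => (path, stream.toArray().take(16)) }
  .collect()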
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 4cb450577796a..183bce3d8d8d3 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -48,9 +48,10 @@ private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[Str
}
/**
- * Allow minPartitions set by end-user in order to keep compatibility with old Hadoop API.
+   * Allow minPartitions set by the end-user in order to keep compatibility with the old
+   * Hadoop API, which is applied through setMaxSplitSize
*/
- def setMaxSplitSize(context: JobContext, minPartitions: Int) {
+ def setMinPartitions(context: JobContext, minPartitions: Int) {
val files = listStatus(context)
val totalLen = files.map { file =>
if (file.isDir) 0L else file.getLen
diff --git a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
index e0e91724271c8..1745d52c81923 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockDataManager.scala
@@ -17,20 +17,20 @@
package org.apache.spark.network
-import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+private[spark]
trait BlockDataManager {
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- def getBlockData(blockId: String): Option[ManagedBuffer]
+ def getBlockData(blockId: BlockId): ManagedBuffer
/**
* Put the block locally, using the given storage level.
*/
- def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit
+ def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit
}
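To make the revised contract concrete, here is a toy in-memory sketch of the trait (illustrative only; in Spark the real implementation is the BlockManager, and since the trait is private[spark] this assumes code inside the org.apache.spark package):

import scala.collection.mutable

import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.{BlockId, StorageLevel}

class InMemoryBlockDataManager extends BlockDataManager {
  private val blocks = mutable.Map.empty[BlockId, ManagedBuffer]

  // Per the new contract, throw instead of returning an Option when the block is missing.
  override def getBlockData(blockId: BlockId): ManagedBuffer =
    blocks.getOrElse(blockId, throw new IllegalArgumentException(s"Block $blockId not found"))

  override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
    // The storage level is ignored in this toy sketch.
    blocks(blockId) = data
  }
}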
diff --git a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala b/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
deleted file mode 100644
index 34acaa563ca58..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/BlockFetchingListener.scala
+++ /dev/null
@@ -1,37 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.util.EventListener
-
-
-/**
- * Listener callback interface for [[BlockTransferService.fetchBlocks]].
- */
-trait BlockFetchingListener extends EventListener {
-
- /**
- * Called once per successfully fetched block.
- */
- def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit
-
- /**
- * Called upon failures. For each failure, this is called only once (i.e. not once per block).
- */
- def onBlockFetchFailure(exception: Throwable): Unit
-}
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 84d991fa6808c..210a581db466e 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -17,13 +17,19 @@
package org.apache.spark.network
-import scala.concurrent.{Await, Future}
-import scala.concurrent.duration.Duration
+import java.io.Closeable
+import java.nio.ByteBuffer
-import org.apache.spark.storage.StorageLevel
+import scala.concurrent.{Promise, Await, Future}
+import scala.concurrent.duration.Duration
+import org.apache.spark.Logging
+import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener}
+import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel}
-abstract class BlockTransferService {
+private[spark]
+abstract class BlockTransferService extends ShuffleClient with Closeable with Logging {
/**
* Initialize the transfer service by giving it the BlockDataManager that can be used to fetch
@@ -34,7 +40,7 @@ abstract class BlockTransferService {
/**
* Tear down the transfer service.
*/
- def stop(): Unit
+ def close(): Unit
/**
* Port number the service is listening on, available only after [[init]] is invoked.
@@ -50,17 +56,15 @@ abstract class BlockTransferService {
* Fetch a sequence of blocks from a remote node asynchronously,
* available only after [[init]] is invoked.
*
- * Note that [[BlockFetchingListener.onBlockFetchSuccess]] is called once per block,
- * while [[BlockFetchingListener.onBlockFetchFailure]] is called once per failure (not per block).
- *
* Note that this API takes a sequence so the implementation can batch requests, and does not
* return a future so the underlying implementation can invoke onBlockFetchSuccess as soon as
* the data of a block is fetched, rather than waiting for all blocks to be fetched.
*/
- def fetchBlocks(
- hostName: String,
+ override def fetchBlocks(
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit
/**
@@ -69,7 +73,7 @@ abstract class BlockTransferService {
def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -78,40 +82,23 @@ abstract class BlockTransferService {
*
* It is also only available after [[init]] is invoked.
*/
- def fetchBlockSync(hostName: String, port: Int, blockId: String): ManagedBuffer = {
+ def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = {
// A monitor for the thread to wait on.
- val lock = new Object
- @volatile var result: Either[ManagedBuffer, Throwable] = null
- fetchBlocks(hostName, port, Seq(blockId), new BlockFetchingListener {
- override def onBlockFetchFailure(exception: Throwable): Unit = {
- lock.synchronized {
- result = Right(exception)
- lock.notify()
+ val result = Promise[ManagedBuffer]()
+ fetchBlocks(host, port, execId, Array(blockId),
+ new BlockFetchingListener {
+ override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
+ result.failure(exception)
}
- }
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- lock.synchronized {
- result = Left(data)
- lock.notify()
- }
- }
- })
-
- // Sleep until result is no longer null
- lock.synchronized {
- while (result == null) {
- try {
- lock.wait()
- } catch {
- case e: InterruptedException =>
+ override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
+ val ret = ByteBuffer.allocate(data.size.toInt)
+ ret.put(data.nioByteBuffer())
+ ret.flip()
+ result.success(new NioManagedBuffer(ret))
}
- }
- }
+ })
- result match {
- case Left(data) => data
- case Right(e) => throw e
- }
+ Await.result(result.future, Duration.Inf)
}
/**
@@ -123,7 +110,7 @@ abstract class BlockTransferService {
def uploadBlockSync(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
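The rewritten fetchBlockSync swaps the hand-rolled monitor for a Promise that the listener completes and a blocking Await on its Future. A self-contained sketch of the same pattern, with a generic callback standing in for BlockFetchingListener:

import scala.concurrent.{Await, Promise}
import scala.concurrent.duration.Duration

trait Callback[T] {
  def onSuccess(value: T): Unit
  def onFailure(error: Throwable): Unit
}

// Turn a callback-style asynchronous API into a blocking call: the callback completes
// the Promise, and the caller blocks on the corresponding Future.
def awaitResult[T](startAsync: Callback[T] => Unit): T = {
  val result = Promise[T]()
  startAsync(new Callback[T] {
    override def onSuccess(value: T): Unit = result.success(value)
    override def onFailure(error: Throwable): Unit = result.failure(error)
  })
  // Rethrows the failure if onFailure was called.
  Await.result(result.future, Duration.Inf)
}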
diff --git a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala b/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
deleted file mode 100644
index 4c9ca97a2a6b7..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/ManagedBuffer.scala
+++ /dev/null
@@ -1,160 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network
-
-import java.io._
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-import java.nio.channels.FileChannel.MapMode
-
-import scala.util.Try
-
-import com.google.common.io.ByteStreams
-import io.netty.buffer.{ByteBufInputStream, ByteBuf}
-
-import org.apache.spark.util.{ByteBufferInputStream, Utils}
-
-
-/**
- * This interface provides an immutable view for data in the form of bytes. The implementation
- * should specify how the data is provided:
- *
- * - FileSegmentManagedBuffer: data backed by part of a file
- * - NioByteBufferManagedBuffer: data backed by a NIO ByteBuffer
- * - NettyByteBufManagedBuffer: data backed by a Netty ByteBuf
- */
-sealed abstract class ManagedBuffer {
- // Note that all the methods are defined with parenthesis because their implementations can
- // have side effects (io operations).
-
- /** Number of bytes of the data. */
- def size: Long
-
- /**
- * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
- * returned ByteBuffer should not affect the content of this buffer.
- */
- def nioByteBuffer(): ByteBuffer
-
- /**
- * Exposes this buffer's data as an InputStream. The underlying implementation does not
- * necessarily check for the length of bytes read, so the caller is responsible for making sure
- * it does not go over the limit.
- */
- def inputStream(): InputStream
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a segment in a file
- */
-final class FileSegmentManagedBuffer(val file: File, val offset: Long, val length: Long)
- extends ManagedBuffer {
-
- /**
- * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
- * Avoid unless there's a good reason not to.
- */
- private val MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
-
- override def size: Long = length
-
- override def nioByteBuffer(): ByteBuffer = {
- var channel: FileChannel = null
- try {
- channel = new RandomAccessFile(file, "r").getChannel
- // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
- if (length < MIN_MEMORY_MAP_BYTES) {
- val buf = ByteBuffer.allocate(length.toInt)
- channel.read(buf, offset)
- buf.flip()
- buf
- } else {
- channel.map(MapMode.READ_ONLY, offset, length)
- }
- } catch {
- case e: IOException =>
- Try(channel.size).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- } finally {
- if (channel != null) {
- Utils.tryLog(channel.close())
- }
- }
- }
-
- override def inputStream(): InputStream = {
- var is: FileInputStream = null
- try {
- is = new FileInputStream(file)
- is.skip(offset)
- ByteStreams.limit(is, length)
- } catch {
- case e: IOException =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- Try(file.length).toOption match {
- case Some(fileLen) =>
- throw new IOException(s"Error in reading $this (actual file length $fileLen)", e)
- case None =>
- throw new IOException(s"Error in opening $this", e)
- }
- case e: Throwable =>
- if (is != null) {
- Utils.tryLog(is.close())
- }
- throw e
- }
- }
-
- override def toString: String = s"${getClass.getName}($file, $offset, $length)"
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by [[java.nio.ByteBuffer]].
- */
-final class NioByteBufferManagedBuffer(buf: ByteBuffer) extends ManagedBuffer {
-
- override def size: Long = buf.remaining()
-
- override def nioByteBuffer() = buf.duplicate()
-
- override def inputStream() = new ByteBufferInputStream(buf)
-}
-
-
-/**
- * A [[ManagedBuffer]] backed by a Netty [[ByteBuf]].
- */
-final class NettyByteBufManagedBuffer(buf: ByteBuf) extends ManagedBuffer {
-
- override def size: Long = buf.readableBytes()
-
- override def nioByteBuffer() = buf.nioBuffer()
-
- override def inputStream() = new ByteBufInputStream(buf)
-
- // TODO(rxin): Promote this to top level ManagedBuffer interface and add documentation for it.
- def release(): Unit = buf.release()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
new file mode 100644
index 0000000000000..1950e7bd634ee
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import java.nio.ByteBuffer
+
+import scala.collection.JavaConversions._
+
+import org.apache.spark.Logging
+import org.apache.spark.network.BlockDataManager
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
+import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
+import org.apache.spark.network.shuffle.ShuffleStreamHandle
+import org.apache.spark.serializer.Serializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+
+object NettyMessages {
+ /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
+ case class OpenBlocks(blockIds: Seq[BlockId])
+
+ /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
+ case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
+}
+
+/**
+ * Serves requests to open blocks by simply registering one chunk per block requested.
+ * Handles opening and uploading arbitrary BlockManager blocks.
+ *
+ * Opened blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
+ * is equivalent to one Spark-level shuffle block.
+ */
+class NettyBlockRpcServer(
+ serializer: Serializer,
+ blockManager: BlockDataManager)
+ extends RpcHandler with Logging {
+
+ import NettyMessages._
+
+ private val streamManager = new OneForOneStreamManager()
+
+ override def receive(
+ client: TransportClient,
+ messageBytes: Array[Byte],
+ responseContext: RpcResponseCallback): Unit = {
+ val ser = serializer.newInstance()
+ val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ logTrace(s"Received request: $message")
+
+ message match {
+ case OpenBlocks(blockIds) =>
+ val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ val streamId = streamManager.registerStream(blocks.iterator)
+ logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
+ responseContext.onSuccess(
+ ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+
+ case UploadBlock(blockId, blockData, level) =>
+ blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+ responseContext.onSuccess(new Array[Byte](0))
+ }
+ }
+
+ override def getStreamManager(): StreamManager = streamManager
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
new file mode 100644
index 0000000000000..1c4327cf13b51
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import scala.concurrent.{Future, Promise}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.server._
+import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage.{BlockId, StorageLevel}
+import org.apache.spark.util.Utils
+
+/**
+ * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
+ */
+class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+ // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
+ val serializer = new JavaSerializer(conf)
+
+ private[this] var transportContext: TransportContext = _
+ private[this] var server: TransportServer = _
+ private[this] var clientFactory: TransportClientFactory = _
+
+ override def init(blockDataManager: BlockDataManager): Unit = {
+ val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
+ clientFactory = transportContext.createClientFactory()
+ server = transportContext.createServer()
+ logInfo("Server created on " + server.getPort)
+ }
+
+ override def fetchBlocks(
+ host: String,
+ port: Int,
+ execId: String,
+ blockIds: Array[String],
+ listener: BlockFetchingListener): Unit = {
+ logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
+ try {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+ .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ } catch {
+ case e: Exception =>
+ logError("Exception while beginning fetchBlocks", e)
+ blockIds.foreach(listener.onBlockFetchFailure(_, e))
+ }
+ }
+
+ override def hostName: String = Utils.localHostName()
+
+ override def port: Int = server.getPort
+
+ override def uploadBlock(
+ hostname: String,
+ port: Int,
+ blockId: BlockId,
+ blockData: ManagedBuffer,
+ level: StorageLevel): Future[Unit] = {
+ val result = Promise[Unit]()
+ val client = clientFactory.createClient(hostname, port)
+
+ // Convert or copy nio buffer into array in order to serialize it.
+ val nioBuffer = blockData.nioByteBuffer()
+ val array = if (nioBuffer.hasArray) {
+ nioBuffer.array()
+ } else {
+ val data = new Array[Byte](nioBuffer.remaining())
+ nioBuffer.get(data)
+ data
+ }
+
+ val ser = serializer.newInstance()
+ client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+ new RpcResponseCallback {
+ override def onSuccess(response: Array[Byte]): Unit = {
+ logTrace(s"Successfully uploaded block $blockId")
+ result.success()
+ }
+ override def onFailure(e: Throwable): Unit = {
+ logError(s"Error while uploading block $blockId", e)
+ result.failure(e)
+ }
+ })
+
+ result.future
+ }
+
+ override def close(): Unit = {
+ server.close()
+ clientFactory.close()
+ }
+}
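A hedged sketch of wiring the new service up by hand. In a full build the transfer service is normally constructed inside SparkEnv (that wiring is not shown in this patch), and because BlockDataManager is private[spark] this assumes code inside the org.apache.spark package; it is illustrative only:

import org.apache.spark.SparkConf
import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.netty.NettyBlockTransferService

// Start a Netty-based transfer service and point it at an existing BlockDataManager
// (for example, a node's BlockManager).
def startTransferService(conf: SparkConf, blockDataManager: BlockDataManager): NettyBlockTransferService = {
  val service = new NettyBlockTransferService(conf)
  service.init(blockDataManager)  // creates the TransportServer and the client factory
  service
}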
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
deleted file mode 100644
index b5870152c5a64..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyConfig.scala
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.SparkConf
-
-/**
- * A central location that tracks all the settings we exposed to users.
- */
-private[spark]
-class NettyConfig(conf: SparkConf) {
-
- /** Port the server listens on. Default to a random port. */
- private[netty] val serverPort = conf.getInt("spark.shuffle.io.port", 0)
-
- /** IO mode: nio, oio, epoll, or auto (try epoll first and then nio). */
- private[netty] val ioMode = conf.get("spark.shuffle.io.mode", "nio").toLowerCase
-
- /** Connect timeout in secs. Default 60 secs. */
- private[netty] val connectTimeoutMs = conf.getInt("spark.shuffle.io.connectionTimeout", 60) * 1000
-
- /**
- * Percentage of the desired amount of time spent for I/O in the child event loops.
- * Only applicable in nio and epoll.
- */
- private[netty] val ioRatio = conf.getInt("spark.shuffle.io.netty.ioRatio", 80)
-
- /** Requested maximum length of the queue of incoming connections. */
- private[netty] val backLog: Option[Int] = conf.getOption("spark.shuffle.io.backLog").map(_.toInt)
-
- /**
- * Receive buffer size (SO_RCVBUF).
- * Note: the optimal size for receive buffer and send buffer should be
- * latency * network_bandwidth.
- * Assuming latency = 1ms, network_bandwidth = 10Gbps
- * buffer size should be ~ 1.25MB
- */
- private[netty] val receiveBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-
- /** Send buffer size (SO_SNDBUF). */
- private[netty] val sendBuf: Option[Int] =
- conf.getOption("spark.shuffle.io.sendBuffer").map(_.toInt)
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala b/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
deleted file mode 100644
index 0d7695072a7b1..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/PathResolver.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import org.apache.spark.storage.{BlockId, FileSegment}
-
-trait PathResolver {
- /** Get the file segment in which the given block resides. */
- def getBlockLocation(blockId: BlockId): FileSegment
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
new file mode 100644
index 0000000000000..9fa4fa77b8817
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.netty
+
+import org.apache.spark.SparkConf
+import org.apache.spark.network.util.{TransportConf, ConfigProvider}
+
+/**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ */
+object SparkTransportConf {
+ def fromSparkConf(conf: SparkConf): TransportConf = {
+ new TransportConf(new ConfigProvider {
+ override def get(name: String): String = conf.get(name)
+ })
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
deleted file mode 100644
index e28219dd7745b..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockClientListener.scala
+++ /dev/null
@@ -1,29 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.EventListener
-
-
-trait BlockClientListener extends EventListener {
-
- def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit
-
- def onFetchFailure(blockId: String, errorMsg: String): Unit
-
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
deleted file mode 100644
index 5aea7ba2f3673..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClient.scala
+++ /dev/null
@@ -1,132 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.util.concurrent.TimeoutException
-
-import io.netty.bootstrap.Bootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.{ChannelFutureListener, ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.handler.codec.LengthFieldBasedFrameDecoder
-import io.netty.handler.codec.string.StringEncoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.Logging
-
-/**
- * Client for fetching data blocks from [[org.apache.spark.network.netty.server.BlockServer]].
- * Use [[BlockFetchingClientFactory]] to instantiate this client.
- *
- * The constructor blocks until a connection is successfully established.
- *
- * See [[org.apache.spark.network.netty.server.BlockServer]] for client/server protocol.
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-@throws[TimeoutException]
-private[spark]
-class BlockFetchingClient(factory: BlockFetchingClientFactory, hostname: String, port: Int)
- extends Logging {
-
- private val handler = new BlockFetchingClientHandler
-
- /** Netty Bootstrap for creating the TCP connection. */
- private val bootstrap: Bootstrap = {
- val b = new Bootstrap
- b.group(factory.workerGroup)
- .channel(factory.socketChannelClass)
- // Use pooled buffers to reduce temporary buffer allocation
- .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- // Disable Nagle's Algorithm since we don't want packets to wait
- .option(ChannelOption.TCP_NODELAY, java.lang.Boolean.TRUE)
- .option(ChannelOption.SO_KEEPALIVE, java.lang.Boolean.TRUE)
- .option[Integer](ChannelOption.CONNECT_TIMEOUT_MILLIS, factory.conf.connectTimeoutMs)
-
- b.handler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("encoder", new StringEncoder(CharsetUtil.UTF_8))
- // maxFrameLength = 2G, lengthFieldOffset = 0, lengthFieldLength = 4
- .addLast("framedLengthDecoder", new LengthFieldBasedFrameDecoder(Int.MaxValue, 0, 4))
- .addLast("handler", handler)
- }
- })
- b
- }
-
- /** Netty ChannelFuture for the connection. */
- private val cf: ChannelFuture = bootstrap.connect(hostname, port)
- if (!cf.awaitUninterruptibly(factory.conf.connectTimeoutMs)) {
- throw new TimeoutException(
- s"Connecting to $hostname:$port timed out (${factory.conf.connectTimeoutMs} ms)")
- }
-
- /**
- * Ask the remote server for a sequence of blocks, and execute the callback.
- *
- * Note that this is asynchronous and returns immediately. Upstream caller should throttle the
- * rate of fetching; otherwise we could run out of memory.
- *
- * @param blockIds sequence of block ids to fetch.
- * @param listener callback to fire on fetch success / failure.
- */
- def fetchBlocks(blockIds: Seq[String], listener: BlockClientListener): Unit = {
- // It's best to limit the number of "write" calls since it needs to traverse the whole pipeline.
- // It's also best to limit the number of "flush" calls since it requires system calls.
- // Let's concatenate the string and then call writeAndFlush once.
- // This is also why this implementation might be more efficient than multiple, separate
- // fetch block calls.
- var startTime: Long = 0
- logTrace {
- startTime = System.nanoTime
- s"Sending request $blockIds to $hostname:$port"
- }
-
- blockIds.foreach { blockId =>
- handler.addRequest(blockId, listener)
- }
-
- val writeFuture = cf.channel().writeAndFlush(blockIds.mkString("\n") + "\n")
- writeFuture.addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture): Unit = {
- if (future.isSuccess) {
- logTrace {
- val timeTaken = (System.nanoTime - startTime).toDouble / 1000000
- s"Sending request $blockIds to $hostname:$port took $timeTaken ms"
- }
- } else {
- // Fail all blocks.
- val errorMsg =
- s"Failed to send request $blockIds to $hostname:$port: ${future.cause.getMessage}"
- logError(errorMsg, future.cause)
- blockIds.foreach { blockId =>
- listener.onFetchFailure(blockId, errorMsg)
- handler.removeRequest(blockId)
- }
- }
- }
- })
- }
-
- def waitForClose(): Unit = {
- cf.channel().closeFuture().sync()
- }
-
- def close(): Unit = cf.channel().close()
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
deleted file mode 100644
index 2b28402c52b49..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientFactory.scala
+++ /dev/null
@@ -1,99 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.nio.NioSocketChannel
-import io.netty.channel.socket.oio.OioSocketChannel
-import io.netty.channel.{EventLoopGroup, Channel}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.util.Utils
-
-/**
- * Factory for creating [[BlockFetchingClient]] by using createClient. This factory reuses
- * the worker thread pool for Netty.
- *
- * Concurrency: createClient is safe to call from multiple threads concurrently.
- */
-private[spark]
-class BlockFetchingClientFactory(val conf: NettyConfig) {
-
- def this(sparkConf: SparkConf) = this(new NettyConfig(sparkConf))
-
- /** A thread factory so the threads are named (for debugging). */
- val threadFactory = Utils.namedThreadFactory("spark-shuffle-client")
-
- /** The following two fields are instantiated by the [[init]] method, depending on ioMode. */
- var socketChannelClass: Class[_ <: Channel] = _
- var workerGroup: EventLoopGroup = _
-
- init()
-
- /** Initialize [[socketChannelClass]] and [[workerGroup]] based on ioMode. */
- private def init(): Unit = {
- def initOio(): Unit = {
- socketChannelClass = classOf[OioSocketChannel]
- workerGroup = new OioEventLoopGroup(0, threadFactory)
- }
- def initNio(): Unit = {
- socketChannelClass = classOf[NioSocketChannel]
- workerGroup = new NioEventLoopGroup(0, threadFactory)
- }
- def initEpoll(): Unit = {
- socketChannelClass = classOf[EpollSocketChannel]
- workerGroup = new EpollEventLoopGroup(0, threadFactory)
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happens on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
- }
-
- /**
- * Create a new BlockFetchingClient connecting to the given remote host / port.
- *
- * This blocks until a connection is successfully established.
- *
- * Concurrency: This method is safe to call from multiple threads.
- */
- def createClient(remoteHost: String, remotePort: Int): BlockFetchingClient = {
- new BlockFetchingClient(this, remoteHost, remotePort)
- }
-
- def stop(): Unit = {
- if (workerGroup != null) {
- workerGroup.shutdownGracefully()
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
deleted file mode 100644
index 83265b164299d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandler.scala
+++ /dev/null
@@ -1,103 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
-
-import org.apache.spark.Logging
-
-
-/**
- * Handler that processes server responses. It uses the protocol documented in
- * [[org.apache.spark.network.netty.server.BlockServer]].
- *
- * Concurrency: thread safe and can be called from multiple threads.
- */
-private[client]
-class BlockFetchingClientHandler extends SimpleChannelInboundHandler[ByteBuf] with Logging {
-
- /** Tracks the list of outstanding requests and their listeners on success/failure. */
- private val outstandingRequests = java.util.Collections.synchronizedMap {
- new java.util.HashMap[String, BlockClientListener]
- }
-
- def addRequest(blockId: String, listener: BlockClientListener): Unit = {
- outstandingRequests.put(blockId, listener)
- }
-
- def removeRequest(blockId: String): Unit = {
- outstandingRequests.remove(blockId)
- }
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- val errorMsg = s"Exception in connection from ${ctx.channel.remoteAddress}: ${cause.getMessage}"
- logError(errorMsg, cause)
-
- // Fire the failure callback for all outstanding blocks
- outstandingRequests.synchronized {
- val iter = outstandingRequests.entrySet().iterator()
- while (iter.hasNext) {
- val entry = iter.next()
- entry.getValue.onFetchFailure(entry.getKey, errorMsg)
- }
- outstandingRequests.clear()
- }
-
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, in: ByteBuf) {
- val totalLen = in.readInt()
- val blockIdLen = in.readInt()
- val blockIdBytes = new Array[Byte](math.abs(blockIdLen))
- in.readBytes(blockIdBytes)
- val blockId = new String(blockIdBytes)
- val blockSize = totalLen - math.abs(blockIdLen) - 4
-
- def server = ctx.channel.remoteAddress.toString
-
- // blockIdLen is negative when it is an error message.
- if (blockIdLen < 0) {
- val errorMessageBytes = new Array[Byte](blockSize)
- in.readBytes(errorMessageBytes)
- val errorMsg = new String(errorMessageBytes)
- logTrace(s"Received block $blockId ($blockSize B) with error $errorMsg from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchFailure(blockId, errorMsg)
- }
- } else {
- logTrace(s"Received block $blockId ($blockSize B) from $server")
-
- val listener = outstandingRequests.get(blockId)
- if (listener == null) {
- // Ignore callback
- logWarning(s"Got a response for block $blockId but it is not in our outstanding requests")
- } else {
- outstandingRequests.remove(blockId)
- listener.onFetchSuccess(blockId, new ReferenceCountedBuffer(in))
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala b/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
deleted file mode 100644
index 9740ee64d1f2d..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/LazyInitIterator.scala
+++ /dev/null
@@ -1,44 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-/**
- * A simple iterator that lazily initializes the underlying iterator.
- *
- * The use case is that sometimes we might have many iterators open at the same time, and each
- * iterator might initialize its own buffer (e.g. a decompression or deserialization buffer).
- * This could lead to too many open buffers. Wrapping the iterators in this class defers those
- * allocations until each iterator is first consumed.
- */
-private[spark]
-class LazyInitIterator(createIterator: => Iterator[Any]) extends Iterator[Any] {
-
- lazy val proxy = createIterator
-
- override def hasNext: Boolean = {
- val gotNext = proxy.hasNext
- if (!gotNext) {
- close()
- }
- gotNext
- }
-
- override def next(): Any = proxy.next()
-
- def close(): Unit = { }
-}
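
As an aside on the lazy-initialization idea documented in the removed LazyInitIterator above, here is a minimal, self-contained sketch; the class and names are illustrative, not Spark APIs.

```scala
// Minimal sketch of lazily-initialized iteration: the expensive resource (here a large
// buffer) is only allocated when the iterator is first consumed, so many such iterators
// can be open at once without holding many buffers.
class LazyBufferedIterator[T](open: => Iterator[T]) extends Iterator[T] {
  private lazy val underlying = open              // created on the first hasNext/next call
  override def hasNext: Boolean = underlying.hasNext
  override def next(): T = underlying.next()
}

object LazyBufferedIteratorExample extends App {
  val iterators = (1 to 1000).map { i =>
    new LazyBufferedIterator({
      val buffer = new Array[Byte](1 << 20)       // allocated only if this iterator is used
      Iterator.single(buffer.length + i)
    })
  }
  // Only the three iterators we actually consume allocate their 1 MB buffers.
  println(iterators.take(3).map(_.next()).sum)
}
```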
diff --git a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala b/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
deleted file mode 100644
index ea1abf5eccc26..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/client/ReferenceCountedBuffer.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.io.InputStream
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{ByteBuf, ByteBufInputStream}
-
-
-/**
- * A buffer abstraction based on Netty's ByteBuf so we don't expose Netty.
- * This is a Scala value class.
- *
- * The buffer's life cycle is NOT managed by the JVM, so references must be managed explicitly
- * through the retain and release methods.
- */
-private[spark]
-class ReferenceCountedBuffer(val underlying: ByteBuf) extends AnyVal {
-
- /** Return the nio ByteBuffer view of the underlying buffer. */
- def byteBuffer(): ByteBuffer = underlying.nioBuffer
-
- /** Creates a new input stream that starts from the current position of the buffer. */
- def inputStream(): InputStream = new ByteBufInputStream(underlying)
-
- /** Increment the reference counter by one. */
- def retain(): Unit = underlying.retain()
-
- /** Decrement the reference counter by one and release the buffer if the ref count is 0. */
- def release(): Unit = underlying.release()
-}
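
A short sketch of the retain/release contract on a raw Netty ByteBuf (assuming Netty 4.x on the classpath), which is the life cycle the removed wrapper above exposes:

```scala
import io.netty.buffer.{ByteBuf, Unpooled}

object RefCountSketch extends App {
  // A wrapped buffer starts with a reference count of 1.
  val buf: ByteBuf = Unpooled.wrappedBuffer("block data".getBytes("UTF-8"))
  println(buf.refCnt())   // 1

  // A consumer that hands the buffer to another component should retain it first...
  buf.retain()
  println(buf.refCnt())   // 2

  // ...and every retain, plus the original reference, must eventually be released.
  buf.release()
  buf.release()
  println(buf.refCnt())   // 0 -- deallocated; the buffer must no longer be accessed
}
```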
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
deleted file mode 100644
index 162e9cc6828d4..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeader.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-/**
- * Header describing a block. This is used only in the server pipeline.
- *
- * [[BlockServerHandler]] creates this, and [[BlockHeaderEncoder]] encodes it.
- *
- * @param blockSize length of the block content, excluding the length itself.
- * If positive, the header is followed by the block content (the content is not part of the header).
- * If negative, the header carries an error message instead of block content.
- * @param blockId block id
- * @param error some error message from reading the block
- */
-private[server]
-class BlockHeader(val blockSize: Int, val blockId: String, val error: Option[String] = None)
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
deleted file mode 100644
index 8e4dda4ef8595..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockHeaderEncoder.scala
+++ /dev/null
@@ -1,47 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.ChannelHandlerContext
-import io.netty.handler.codec.MessageToByteEncoder
-
-/**
- * A simple encoder for BlockHeader. See [[BlockServer]] for the server to client protocol.
- */
-private[server]
-class BlockHeaderEncoder extends MessageToByteEncoder[BlockHeader] {
- override def encode(ctx: ChannelHandlerContext, msg: BlockHeader, out: ByteBuf): Unit = {
- // message = message length (4 bytes) + block id length (4 bytes) + block id + block data
- // message length = block id length (4 bytes) + size of block id + size of block data
- val blockIdBytes = msg.blockId.getBytes
- msg.error match {
- case Some(errorMsg) =>
- val errorBytes = errorMsg.getBytes
- out.writeInt(4 + blockIdBytes.length + errorBytes.size)
- out.writeInt(-blockIdBytes.length) // use negative block id length to represent errors
- out.writeBytes(blockIdBytes) // next is blockId itself
- out.writeBytes(errorBytes) // error message
- case None =>
- out.writeInt(4 + blockIdBytes.length + msg.blockSize)
- out.writeInt(blockIdBytes.length) // First 4 bytes is blockId length
- out.writeBytes(blockIdBytes) // next is blockId itself
- // msg of size blockSize will be written by ServerHandler
- }
- }
-}
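
To make the framing above concrete, here is a hypothetical standalone sketch using plain java.nio rather than the Netty pipeline; it builds the two frame shapes the encoder comments describe (all names are illustrative).

```scala
import java.nio.ByteBuffer
import java.nio.charset.StandardCharsets.UTF_8

// Illustrative only: frames a block header and an error frame following the layout above
// (frame-length, block-id-length, block-id, then payload; a negative block-id-length
// marks an error frame). Not Spark code.
object BlockFrameSketch extends App {
  def blockHeader(blockId: String, blockSize: Int): ByteBuffer = {
    val id = blockId.getBytes(UTF_8)
    val buf = ByteBuffer.allocate(8 + id.length)
    buf.putInt(4 + id.length + blockSize)   // frame-length, excluding this field itself
    buf.putInt(id.length)                   // positive length => block data follows
    buf.put(id)
    buf.flip()
    buf
  }

  def errorFrame(blockId: String, error: String): ByteBuffer = {
    val id = blockId.getBytes(UTF_8)
    val err = error.getBytes(UTF_8)
    val buf = ByteBuffer.allocate(8 + id.length + err.length)
    buf.putInt(4 + id.length + err.length)  // frame-length
    buf.putInt(-id.length)                  // negative length => error message follows
    buf.put(id)
    buf.put(err)
    buf.flip()
    buf
  }

  println(blockHeader("shuffle_0_1_2", blockSize = 1024).remaining())  // 21
  println(errorFrame("shuffle_0_1_2", "block not found").remaining())  // 36
}
```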
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
deleted file mode 100644
index 7b2f9a8d4dfd0..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServer.scala
+++ /dev/null
@@ -1,162 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.net.InetSocketAddress
-
-import io.netty.bootstrap.ServerBootstrap
-import io.netty.buffer.PooledByteBufAllocator
-import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption}
-import io.netty.channel.epoll.{EpollEventLoopGroup, EpollServerSocketChannel}
-import io.netty.channel.nio.NioEventLoopGroup
-import io.netty.channel.oio.OioEventLoopGroup
-import io.netty.channel.socket.SocketChannel
-import io.netty.channel.socket.nio.NioServerSocketChannel
-import io.netty.channel.socket.oio.OioServerSocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-
-import org.apache.spark.{Logging, SparkConf}
-import org.apache.spark.network.netty.NettyConfig
-import org.apache.spark.storage.BlockDataProvider
-import org.apache.spark.util.Utils
-
-
-/**
- * Server for serving Spark data blocks.
- * This should be used together with [[org.apache.spark.network.netty.client.BlockFetchingClient]].
- *
- * Protocol for requesting blocks (client to server):
- * One block id per line, e.g. to request 3 blocks: "block1\nblock2\nblock3\n"
- *
- * Protocol for sending blocks (server to client):
- * frame-length (4 bytes), block-id-length (4 bytes), block-id, block-data.
- *
- * frame-length does not include the 4 bytes of the frame-length field itself.
- * If block-id-length is negative, the frame carries an error message rather than block data, and
- * the real block-id length is the absolute value of block-id-length.
- *
- */
-private[spark]
-class BlockServer(conf: NettyConfig, dataProvider: BlockDataProvider) extends Logging {
-
- def this(sparkConf: SparkConf, dataProvider: BlockDataProvider) = {
- this(new NettyConfig(sparkConf), dataProvider)
- }
-
- def port: Int = _port
-
- def hostName: String = _hostName
-
- private var _port: Int = conf.serverPort
- private var _hostName: String = ""
- private var bootstrap: ServerBootstrap = _
- private var channelFuture: ChannelFuture = _
-
- init()
-
- /** Initialize the server. */
- private def init(): Unit = {
- bootstrap = new ServerBootstrap
- val bossThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-boss")
- val workerThreadFactory = Utils.namedThreadFactory("spark-shuffle-server-worker")
-
- // Use only one thread to accept connections, and 2 * num_cores threads for the workers.
- def initNio(): Unit = {
- val bossGroup = new NioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new NioEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[NioServerSocketChannel])
- }
- def initOio(): Unit = {
- val bossGroup = new OioEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new OioEventLoopGroup(0, workerThreadFactory)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[OioServerSocketChannel])
- }
- def initEpoll(): Unit = {
- val bossGroup = new EpollEventLoopGroup(1, bossThreadFactory)
- val workerGroup = new EpollEventLoopGroup(0, workerThreadFactory)
- workerGroup.setIoRatio(conf.ioRatio)
- bootstrap.group(bossGroup, workerGroup).channel(classOf[EpollServerSocketChannel])
- }
-
- conf.ioMode match {
- case "nio" => initNio()
- case "oio" => initOio()
- case "epoll" => initEpoll()
- case "auto" =>
- // For auto mode, first try epoll (only available on Linux), then nio.
- try {
- initEpoll()
- } catch {
- // TODO: Should we log the throwable? But that always happens on non-Linux systems.
- // Perhaps the right thing to do is to check whether the system is Linux, and then only
- // call initEpoll on Linux.
- case e: Throwable => initNio()
- }
- }
-
- // Use pooled buffers to reduce temporary buffer allocation
- bootstrap.option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
- bootstrap.childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-
- // Various (advanced) user-configured settings.
- conf.backLog.foreach { backLog =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_BACKLOG, backLog)
- }
- conf.receiveBuf.foreach { receiveBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_RCVBUF, receiveBuf)
- }
- conf.sendBuf.foreach { sendBuf =>
- bootstrap.option[java.lang.Integer](ChannelOption.SO_SNDBUF, sendBuf)
- }
-
- bootstrap.childHandler(new ChannelInitializer[SocketChannel] {
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
- })
-
- channelFuture = bootstrap.bind(new InetSocketAddress(_port))
- channelFuture.sync()
-
- val addr = channelFuture.channel.localAddress.asInstanceOf[InetSocketAddress]
- _port = addr.getPort
- _hostName = addr.getHostName
- }
-
- /** Shutdown the server. */
- def stop(): Unit = {
- if (channelFuture != null) {
- channelFuture.channel().close().awaitUninterruptibly()
- channelFuture = null
- }
- if (bootstrap != null && bootstrap.group() != null) {
- bootstrap.group().shutdownGracefully()
- }
- if (bootstrap != null && bootstrap.childGroup() != null) {
- bootstrap.childGroup().shutdownGracefully()
- }
- bootstrap = null
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
deleted file mode 100644
index cc70bd0c5c477..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerChannelInitializer.scala
+++ /dev/null
@@ -1,40 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.channel.ChannelInitializer
-import io.netty.channel.socket.SocketChannel
-import io.netty.handler.codec.LineBasedFrameDecoder
-import io.netty.handler.codec.string.StringDecoder
-import io.netty.util.CharsetUtil
-import org.apache.spark.storage.BlockDataProvider
-
-
-/** Channel initializer that sets up the pipeline for the BlockServer. */
-private[netty]
-class BlockServerChannelInitializer(dataProvider: BlockDataProvider)
- extends ChannelInitializer[SocketChannel] {
-
- override def initChannel(ch: SocketChannel): Unit = {
- ch.pipeline
- .addLast("frameDecoder", new LineBasedFrameDecoder(1024)) // max block id length 1024
- .addLast("stringDecoder", new StringDecoder(CharsetUtil.UTF_8))
- .addLast("blockHeaderEncoder", new BlockHeaderEncoder)
- .addLast("handler", new BlockServerHandler(dataProvider))
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala b/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
deleted file mode 100644
index 40dd5e5d1a2ac..0000000000000
--- a/core/src/main/scala/org/apache/spark/network/netty/server/BlockServerHandler.scala
+++ /dev/null
@@ -1,140 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.FileInputStream
-import java.nio.ByteBuffer
-import java.nio.channels.FileChannel
-
-import io.netty.buffer.Unpooled
-import io.netty.channel._
-
-import org.apache.spark.Logging
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * A handler that processes requests from clients and writes block data back.
- *
- * The messages should have been processed by a LineBasedFrameDecoder and a StringDecoder first
- * so channelRead0 is called once per line (i.e. per block id).
- */
-private[server]
-class BlockServerHandler(dataProvider: BlockDataProvider)
- extends SimpleChannelInboundHandler[String] with Logging {
-
- override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
- logError(s"Exception in connection from ${ctx.channel.remoteAddress}", cause)
- ctx.close()
- }
-
- override def channelRead0(ctx: ChannelHandlerContext, blockId: String): Unit = {
- def client = ctx.channel.remoteAddress.toString
-
- // A helper function to send an error message back to the client.
- def respondWithError(error: String): Unit = {
- ctx.writeAndFlush(new BlockHeader(-1, blockId, Some(error))).addListener(
- new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (!future.isSuccess) {
- // TODO: Maybe log the success case as well.
- logError(s"Error sending error back to $client", future.cause)
- ctx.close()
- }
- }
- }
- )
- }
-
- def writeFileSegment(segment: FileSegment): Unit = {
- // Send an error message back if the block is too large. Even though we are capable of sending
- // large (2G+) blocks, the receiving end cannot handle them, so fail fast.
- // Once the receiving end is able to process large blocks, this check should be removed.
- // Also make sure we update BlockHeaderEncoder to support length > 2G.
-
- // See [[BlockHeaderEncoder]] for the way length is encoded.
- if (segment.length + blockId.length + 4 > Int.MaxValue) {
- respondWithError(s"Block $blockId size ($segment.length) greater than 2G")
- return
- }
-
- var fileChannel: FileChannel = null
- try {
- fileChannel = new FileInputStream(segment.file).getChannel
- } catch {
- case e: Exception =>
- logError(
- s"Error opening channel for $blockId in ${segment.file} for request from $client", e)
- respondWithError(e.getMessage)
- }
-
- // Found the block. Send it back.
- if (fileChannel != null) {
- // Write the header and block data. In the case of failures, the listener on the block data
- // write should close the connection.
- ctx.write(new BlockHeader(segment.length.toInt, blockId))
-
- val region = new DefaultFileRegion(fileChannel, segment.offset, segment.length)
- ctx.writeAndFlush(region).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${segment.length} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
- }
-
- def writeByteBuffer(buf: ByteBuffer): Unit = {
- ctx.write(new BlockHeader(buf.remaining, blockId))
- ctx.writeAndFlush(Unpooled.wrappedBuffer(buf)).addListener(new ChannelFutureListener {
- override def operationComplete(future: ChannelFuture) {
- if (future.isSuccess) {
- logTrace(s"Sent block $blockId (${buf.remaining} B) back to $client")
- } else {
- logError(s"Error sending block $blockId to $client; closing connection", future.cause)
- ctx.close()
- }
- }
- })
- }
-
- logTrace(s"Received request from $client to fetch block $blockId")
-
- var blockData: Either[FileSegment, ByteBuffer] = null
-
- // First make sure we can find the block. If not, send error back to the user.
- try {
- blockData = dataProvider.getBlockData(blockId)
- } catch {
- case e: Exception =>
- logError(s"Error opening block $blockId for request from $client", e)
- respondWithError(e.getMessage)
- return
- }
-
- blockData match {
- case Left(segment) => writeFileSegment(segment)
- case Right(buf) => writeByteBuffer(buf)
- }
-
- } // end of channelRead0
-}
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index bda4bf50932c3..8408b75bb4d65 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -31,6 +31,8 @@ import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
+import com.google.common.base.Charsets.UTF_8
+
import org.apache.spark._
import org.apache.spark.util.Utils
@@ -923,7 +925,7 @@ private[nio] class ConnectionManager(
val errorMsgByteBuf = ackMessage.asInstanceOf[BufferMessage].buffers.head
val errorMsgBytes = new Array[Byte](errorMsgByteBuf.limit())
errorMsgByteBuf.get(errorMsgBytes)
- val errorMsg = new String(errorMsgBytes, "utf-8")
+ val errorMsg = new String(errorMsgBytes, UTF_8)
val e = new IOException(
s"sendMessageReliably failed with ACK that signalled a remote error: $errorMsg")
if (!promise.tryFailure(e)) {
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Message.scala b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
index 3ad04591da658..fb4a979b824c3 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Message.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Message.scala
@@ -22,6 +22,8 @@ import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
+import com.google.common.base.Charsets.UTF_8
+
import org.apache.spark.util.Utils
private[nio] abstract class Message(val typ: Long, val id: Int) {
@@ -92,7 +94,7 @@ private[nio] object Message {
*/
def createErrorMessage(exception: Exception, ackId: Int): BufferMessage = {
val exceptionString = Utils.exceptionString(exception)
- val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes("utf-8"))
+ val serializedExceptionString = ByteBuffer.wrap(exceptionString.getBytes(UTF_8))
val errorMessage = createBufferMessage(serializedExceptionString, ackId)
errorMessage.hasError = true
errorMessage
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index 5add4fc433fb3..f56d165daba55 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -19,12 +19,14 @@ package org.apache.spark.network.nio
import java.nio.ByteBuffer
-import scala.concurrent.Future
-
-import org.apache.spark.{SparkException, Logging, SecurityManager, SparkConf}
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.shuffle.BlockFetchingListener
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
+
+import scala.concurrent.Future
/**
@@ -71,20 +73,21 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
/**
* Tear down the transfer service.
*/
- override def stop(): Unit = {
+ override def close(): Unit = {
if (cm != null) {
cm.stop()
}
}
override def fetchBlocks(
- hostName: String,
+ host: String,
port: Int,
- blockIds: Seq[String],
+ execId: String,
+ blockIds: Array[String],
listener: BlockFetchingListener): Unit = {
checkInit()
- val cmId = new ConnectionManagerId(hostName, port)
+ val cmId = new ConnectionManagerId(host, port)
val blockMessageArray = new BlockMessageArray(blockIds.map { blockId =>
BlockMessage.fromGetBlock(GetBlock(BlockId(blockId)))
})
@@ -96,21 +99,33 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val bufferMessage = message.asInstanceOf[BufferMessage]
val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
- for (blockMessage <- blockMessageArray) {
- if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
- listener.onBlockFetchFailure(
- new SparkException(s"Unexpected message ${blockMessage.getType} received from $cmId"))
- } else {
- val blockId = blockMessage.getId
- val networkSize = blockMessage.getData.limit()
- listener.onBlockFetchSuccess(
- blockId.toString, new NioByteBufferManagedBuffer(blockMessage.getData))
+ // SPARK-4064: In some cases (e.g. the remote block was removed), blockMessageArray may be empty.
+ if (blockMessageArray.isEmpty) {
+ blockIds.foreach { id =>
+ listener.onBlockFetchFailure(id, new SparkException(s"Received empty message from $cmId"))
+ }
+ } else {
+ for (blockMessage: BlockMessage <- blockMessageArray) {
+ val msgType = blockMessage.getType
+ if (msgType != BlockMessage.TYPE_GOT_BLOCK) {
+ if (blockMessage.getId != null) {
+ listener.onBlockFetchFailure(blockMessage.getId.toString,
+ new SparkException(s"Unexpected message $msgType received from $cmId"))
+ }
+ } else {
+ val blockId = blockMessage.getId
+ val networkSize = blockMessage.getData.limit()
+ listener.onBlockFetchSuccess(
+ blockId.toString, new NioManagedBuffer(blockMessage.getData))
+ }
}
}
}(cm.futureExecContext)
future.onFailure { case exception =>
- listener.onBlockFetchFailure(exception)
+ blockIds.foreach { blockId =>
+ listener.onBlockFetchFailure(blockId, exception)
+ }
}(cm.futureExecContext)
}
@@ -122,12 +137,12 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
- blockId: String,
+ blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
: Future[Unit] = {
checkInit()
- val msg = PutBlock(BlockId(blockId), blockData.nioByteBuffer(), level)
+ val msg = PutBlock(blockId, blockData.nioByteBuffer(), level)
val blockMessageArray = new BlockMessageArray(BlockMessage.fromPutBlock(msg))
val remoteCmId = new ConnectionManagerId(hostName, port)
val reply = cm.sendMessageReliably(remoteCmId, blockMessageArray.toBufferMessage)
@@ -149,10 +164,9 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
val responseMessages = blockMessages.map(processBlockMessage).filter(_ != None).map(_.get)
Some(new BlockMessageArray(responseMessages).toBufferMessage)
} catch {
- case e: Exception => {
+ case e: Exception =>
logError("Exception handling buffer message", e)
Some(Message.createErrorMessage(e, msg.id))
- }
}
case otherMessage: Any =>
@@ -167,13 +181,13 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
case BlockMessage.TYPE_PUT_BLOCK =>
val msg = PutBlock(blockMessage.getId, blockMessage.getData, blockMessage.getLevel)
logDebug("Received [" + msg + "]")
- putBlock(msg.id.toString, msg.data, msg.level)
+ putBlock(msg.id, msg.data, msg.level)
None
case BlockMessage.TYPE_GET_BLOCK =>
val msg = new GetBlock(blockMessage.getId)
logDebug("Received [" + msg + "]")
- val buffer = getBlock(msg.id.toString)
+ val buffer = getBlock(msg.id)
if (buffer == null) {
return None
}
@@ -183,20 +197,20 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
}
}
- private def putBlock(blockId: String, bytes: ByteBuffer, level: StorageLevel) {
+ private def putBlock(blockId: BlockId, bytes: ByteBuffer, level: StorageLevel) {
val startTimeMs = System.currentTimeMillis()
logDebug("PutBlock " + blockId + " started from " + startTimeMs + " with data: " + bytes)
- blockDataManager.putBlockData(blockId, new NioByteBufferManagedBuffer(bytes), level)
+ blockDataManager.putBlockData(blockId, new NioManagedBuffer(bytes), level)
logDebug("PutBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " with data size: " + bytes.limit)
}
- private def getBlock(blockId: String): ByteBuffer = {
+ private def getBlock(blockId: BlockId): ByteBuffer = {
val startTimeMs = System.currentTimeMillis()
logDebug("GetBlock " + blockId + " started from " + startTimeMs)
- val buffer = blockDataManager.getBlockData(blockId).orNull
+ val buffer = blockDataManager.getBlockData(blockId)
logDebug("GetBlock " + blockId + " used " + Utils.getUsedTimeMs(startTimeMs)
+ " and got buffer " + buffer)
- if (buffer == null) null else buffer.nioByteBuffer()
+ buffer.nioByteBuffer()
}
}
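
For reference, a minimal sketch of a listener that could be handed to the new fetchBlocks signature above; it only logs outcomes, whereas real callers pass the buffer on for deserialization. It assumes the BlockFetchingListener and ManagedBuffer types imported in the diff above.

```scala
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.network.shuffle.BlockFetchingListener

// Logs per-block results; the failure callback matches the new contract in which failures
// are reported for each requested block id rather than once per request.
class LoggingBlockFetchingListener extends BlockFetchingListener {
  override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
    println(s"fetched block $blockId")
  }
  override def onBlockFetchFailure(blockId: String, exception: Throwable): Unit = {
    println(s"failed to fetch block $blockId: ${exception.getMessage}")
  }
}
```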
diff --git a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
index 3155dfe165664..637492a97551b 100644
--- a/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/CountEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
/**
* An ApproximateEvaluator for counts.
@@ -46,7 +46,8 @@ private[spark] class CountEvaluator(totalOutputs: Int, confidence: Double)
val mean = (sum + 1 - p) / p
val variance = (sum + 1) * (1 - p) / (p * p)
val stdev = math.sqrt(variance)
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val low = mean - confFactor * stdev
val high = mean + confFactor * stdev
new BoundedDouble(mean, confidence, low, high)
diff --git a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
index 8bb78123e3c9c..3ef3cc219dec6 100644
--- a/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/GroupedCountEvaluator.scala
@@ -24,7 +24,7 @@ import scala.collection.Map
import scala.collection.mutable.HashMap
import scala.reflect.ClassTag
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.NormalDistribution
import org.apache.spark.util.collection.OpenHashMap
@@ -55,7 +55,8 @@ private[spark] class GroupedCountEvaluator[T : ClassTag](totalOutputs: Int, conf
new HashMap[T, BoundedDouble]
} else {
val p = outputsMerged.toDouble / totalOutputs
- val confFactor = Probability.normalInverse(1 - (1 - confidence) / 2)
+ val confFactor = new NormalDistribution().
+ inverseCumulativeProbability(1 - (1 - confidence) / 2)
val result = new JHashMap[T, BoundedDouble](sums.size)
sums.foreach { case (key, sum) =>
val mean = (sum + 1 - p) / p
diff --git a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
index d24959cba8727..787a21a61fdcf 100644
--- a/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/MeanEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}
import org.apache.spark.util.StatCounter
@@ -45,9 +45,10 @@ private[spark] class MeanEvaluator(totalOutputs: Int, confidence: Double)
val stdev = math.sqrt(counter.sampleVariance / counter.count)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = mean - confFactor * stdev
diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
index 92915ee66d29f..828bf96c2c0bd 100644
--- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
+++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
/**
* A utility class for caching Student's T distribution values for a given confidence level
@@ -25,8 +25,10 @@ import cern.jet.stat.Probability
* confidence intervals for many keys.
*/
private[spark] class StudentTCacher(confidence: Double) {
+
val NORMAL_APPROX_SAMPLE_SIZE = 100 // For samples bigger than this, use Gaussian approximation
- val normalApprox = Probability.normalInverse(1 - (1 - confidence) / 2)
+
+ val normalApprox = new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
val cache = Array.fill[Double](NORMAL_APPROX_SAMPLE_SIZE)(-1.0)
def get(sampleSize: Long): Double = {
@@ -35,7 +37,8 @@ private[spark] class StudentTCacher(confidence: Double) {
} else {
val size = sampleSize.toInt
if (cache(size) < 0) {
- cache(size) = Probability.studentTInverse(1 - confidence, size - 1)
+ val tDist = new TDistribution(size - 1)
+ cache(size) = tDist.inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
cache(size)
}
diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
index d5336284571d2..1753c2561b678 100644
--- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
+++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala
@@ -17,7 +17,7 @@
package org.apache.spark.partial
-import cern.jet.stat.Probability
+import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution}
import org.apache.spark.util.StatCounter
@@ -55,9 +55,10 @@ private[spark] class SumEvaluator(totalOutputs: Int, confidence: Double)
val sumStdev = math.sqrt(sumVar)
val confFactor = {
if (counter.count > 100) {
- Probability.normalInverse(1 - (1 - confidence) / 2)
+ new NormalDistribution().inverseCumulativeProbability(1 - (1 - confidence) / 2)
} else {
- Probability.studentTInverse(1 - confidence, (counter.count - 1).toInt)
+ val degreesOfFreedom = (counter.count - 1).toInt
+ new TDistribution(degreesOfFreedom).inverseCumulativeProbability(1 - (1 - confidence) / 2)
}
}
val low = sumEstimate - confFactor * sumStdev
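
As a standalone illustration of the commons-math3 calls that replace colt's Probability in these evaluators (the numbers below are examples only):

```scala
import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution}

// A two-sided interval at `confidence` uses the quantile at 1 - (1 - confidence) / 2,
// taken from the normal distribution for large counts and from Student's t distribution
// (count - 1 degrees of freedom) for small counts.
object ConfFactorSketch extends App {
  val confidence = 0.95
  val p = 1 - (1 - confidence) / 2                                     // 0.975

  val zFactor = new NormalDistribution().inverseCumulativeProbability(p)
  val tFactor = new TDistribution(9).inverseCumulativeProbability(p)   // e.g. count = 10

  println(f"normal: $zFactor%.3f")   // ~1.960
  println(f"t(9):   $tFactor%.3f")   // ~2.262

  // Interval around a mean estimate with standard error `stdev`:
  val (mean, stdev) = (100.0, 4.0)
  println(s"[${mean - zFactor * stdev}, ${mean + zFactor * stdev}]")
}
```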
diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
new file mode 100644
index 0000000000000..6e66ddbdef788
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.apache.hadoop.conf.{Configurable, Configuration}
+import org.apache.hadoop.io.Writable
+import org.apache.hadoop.mapreduce._
+import org.apache.spark.input.StreamFileInputFormat
+import org.apache.spark.{Partition, SparkContext}
+
+private[spark] class BinaryFileRDD[T](
+ sc: SparkContext,
+ inputFormatClass: Class[_ <: StreamFileInputFormat[T]],
+ keyClass: Class[String],
+ valueClass: Class[T],
+ @transient conf: Configuration,
+ minPartitions: Int)
+ extends NewHadoopRDD[String, T](sc, inputFormatClass, keyClass, valueClass, conf) {
+
+ override def getPartitions: Array[Partition] = {
+ val inputFormat = inputFormatClass.newInstance
+ inputFormat match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val jobContext = newJobContext(conf, jobId)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
+ val rawSplits = inputFormat.getSplits(jobContext).toArray
+ val result = new Array[Partition](rawSplits.size)
+ for (i <- 0 until rawSplits.size) {
+ result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable])
+ }
+ result
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 2673ec22509e9..fffa1911f5bc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -84,5 +84,9 @@ class BlockRDD[T: ClassTag](@transient sc: SparkContext, @transient val blockIds
"Attempted to use %s after its blocks have been removed!".format(toString))
}
}
+
+ protected def getBlockIdLocations(): Map[BlockId, Seq[String]] = {
+ locations_
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 4908711d17db7..1cbd684224b7c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
private[spark]
class CartesianPartition(
@@ -36,7 +37,7 @@ class CartesianPartition(
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
s1 = rdd1.partitions(s1Index)
s2 = rdd2.partitions(s2Index)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index fabb882cdd4b3..ffc0a8a6d67eb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -27,6 +27,7 @@ import org.apache.spark.{InterruptibleIterator, Partition, Partitioner, SparkEnv
import org.apache.spark.{Dependency, OneToOneDependency, ShuffleDependency}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{ExternalAppendOnlyMap, AppendOnlyMap, CompactBuffer}
+import org.apache.spark.util.Utils
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleHandle
@@ -39,7 +40,7 @@ private[spark] case class NarrowCoGroupSplitDep(
) extends CoGroupSplitDep {
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
split = rdd.partitions(splitIndex)
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 11ebafbf6d457..9fab1d78abb04 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -25,6 +25,7 @@ import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
+import org.apache.spark.util.Utils
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -42,7 +43,7 @@ private[spark] case class CoalescedRDDPartition(
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = parentsIndices.map(rdd.partitions(_))
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 775141775e06c..a157e36e2286e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -46,7 +46,6 @@ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.{NextIterator, Utils}
import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation}
-
/**
* A Spark split class that wraps around a Hadoop InputSplit.
*/
@@ -212,8 +211,22 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
+
+ val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
+ split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
+ } else {
+ None
+ }
+ if (bytesReadCallback.isDefined) {
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ }
+
+ var reader: RecordReader[K, V] = null
val inputFormat = getInputFormat(jobConf)
HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
@@ -224,18 +237,7 @@ class HadoopRDD[K, V](
val key: K = reader.createKey()
val value: V = reader.createValue()
- // Set the task input metrics.
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.inputSplit.value.getLength()
- } catch {
- case e: java.io.IOException =>
- logWarning("Unable to get input size to set InputMetrics for task", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
+ var recordsSinceMetricsUpdate = 0
override def getNext() = {
try {
@@ -244,12 +246,36 @@ class HadoopRDD[K, V](
case eof: EOFException =>
finished = true
}
+
+ // Update bytes read metric every few records
+ if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
+ && bytesReadCallback.isDefined) {
+ recordsSinceMetricsUpdate = 0
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else {
+ recordsSinceMetricsUpdate += 1
+ }
(key, value)
}
override def close() {
try {
reader.close()
+ if (bytesReadCallback.isDefined) {
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else if (split.inputSplit.value.isInstanceOf[FileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.bytesRead = split.inputSplit.value.getLength
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -302,6 +328,9 @@ private[spark] object HadoopRDD extends Logging {
*/
val CONFIGURATION_INSTANTIATION_LOCK = new Object()
+ /** Update the input bytes read metric each time this number of records has been read */
+ val RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES = 256
+
/**
* The three methods below are helpers for accessing the local map, a property of the SparkEnv of
* the local process.
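
A self-contained sketch of the sampling pattern introduced above (names are illustrative, not Spark APIs): the bytes-read callback is polled only every RECORDS_BETWEEN_UPDATES records, with a final update when the reader closes.

```scala
object PeriodicMetricSketch extends App {
  val RECORDS_BETWEEN_UPDATES = 256

  var bytesReadSoFar = 0L
  val bytesReadCallback: () => Long = () => bytesReadSoFar   // stands in for the FS statistics

  var reportedBytesRead = 0L
  var recordsSinceUpdate = 0

  def readOneRecord(recordSize: Int): Unit = {
    bytesReadSoFar += recordSize
    // Refresh the metric only every few records to keep the callback off the hot path.
    if (recordsSinceUpdate == RECORDS_BETWEEN_UPDATES) {
      recordsSinceUpdate = 0
      reportedBytesRead = bytesReadCallback()
    } else {
      recordsSinceUpdate += 1
    }
  }

  (1 to 1000).foreach(_ => readOneRecord(128))
  reportedBytesRead = bytesReadCallback()   // final update on close
  println(reportedBytesRead)                // 128000
}
```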
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 0cccdefc5ee09..351e145f96f9a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -25,6 +25,7 @@ import scala.reflect.ClassTag
import org.apache.hadoop.conf.{Configurable, Configuration}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce._
+import org.apache.hadoop.mapreduce.lib.input.FileSplit
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.input.WholeTextFileInputFormat
@@ -36,6 +37,7 @@ import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
+import org.apache.spark.deploy.SparkHadoopUtil
private[spark] class NewHadoopPartition(
rddId: Int,
@@ -105,6 +107,20 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
+
+ val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
+ val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
+ split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
+ } else {
+ None
+ }
+ if (bytesReadCallback.isDefined) {
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ }
+
val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
val format = inputFormatClass.newInstance
@@ -117,22 +133,11 @@ class NewHadoopRDD[K, V](
split.serializableHadoopSplit.value, hadoopAttemptContext)
reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
- val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- try {
- /* bytesRead may not exactly equal the bytes read by a task: split boundaries aren't
- * always at record boundaries, so tasks may need to read into other splits to complete
- * a record. */
- inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength()
- } catch {
- case e: Exception =>
- logWarning("Unable to get input split size in order to set task input bytes", e)
- }
- context.taskMetrics.inputMetrics = Some(inputMetrics)
-
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
var finished = false
+ var recordsSinceMetricsUpdate = 0
override def hasNext: Boolean = {
if (!finished && !havePair) {
@@ -147,12 +152,39 @@ class NewHadoopRDD[K, V](
throw new java.util.NoSuchElementException("End of stream")
}
havePair = false
+
+ // Update bytes read metric every few records
+ if (recordsSinceMetricsUpdate == HadoopRDD.RECORDS_BETWEEN_BYTES_READ_METRIC_UPDATES
+ && bytesReadCallback.isDefined) {
+ recordsSinceMetricsUpdate = 0
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else {
+ recordsSinceMetricsUpdate += 1
+ }
+
(reader.getCurrentKey, reader.getCurrentValue)
}
private def close() {
try {
reader.close()
+
+ // Update metrics with final amount
+ if (bytesReadCallback.isDefined) {
+ val bytesReadFn = bytesReadCallback.get
+ inputMetrics.bytesRead = bytesReadFn()
+ } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
+ // If we can't get the bytes read from the FS stats, fall back to the split size,
+ // which may be inaccurate.
+ try {
+ inputMetrics.bytesRead = split.serializableHadoopSplit.value.getLength
+ context.taskMetrics.inputMetrics = Some(inputMetrics)
+ } catch {
+ case e: java.io.IOException =>
+ logWarning("Unable to get input size to set InputMetrics for task", e)
+ }
+ }
} catch {
case e: Exception => {
if (!Utils.inShutdown()) {
@@ -233,7 +265,7 @@ private[spark] class WholeTextFileRDD(
case _ =>
}
val jobContext = newJobContext(conf, jobId)
- inputFormat.setMaxSplitSize(jobContext, minPartitions)
+ inputFormat.setMinPartitions(jobContext, minPartitions)
val rawSplits = inputFormat.getSplits(jobContext).toArray
val result = new Array[Partition](rawSplits.size)
for (i <- 0 until rawSplits.size) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index ac96de86dd6d4..da89f634abaea 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -315,8 +315,15 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
@deprecated("Use reduceByKeyLocally", "1.0.0")
def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = reduceByKeyLocally(func)
- /** Count the number of elements for each key, and return the result to the master as a Map. */
- def countByKey(): Map[K, Long] = self.map(_._1).countByValue()
+ /**
+ * Count the number of elements for each key, collecting the results to a local Map.
+ *
+ * Note that this method should only be used if the resulting map is expected to be small, as
+ * the whole thing is loaded into the driver's memory.
+ * To handle very large results, consider using rdd.mapValues(_ => 1L).reduceByKey(_ + _), which
+ * returns an RDD[(K, Long)] instead of a map.
+ */
+ def countByKey(): Map[K, Long] = self.mapValues(_ => 1L).reduceByKey(_ + _).collect().toMap
/**
* :: Experimental ::
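The new doc comment steers callers away from collecting a potentially huge map onto the driver. A hedged spark-shell style sketch of the distributed alternative it recommends (data and names are illustrative):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._   // brings in PairRDDFunctions on the 1.x API
import org.apache.spark.rdd.RDD

// Hypothetical local setup, for illustration only.
val sc = new SparkContext(new SparkConf().setAppName("countByKey-sketch").setMaster("local[*]"))
val pairs: RDD[(String, Int)] = sc.parallelize(Seq("a" -> 1, "b" -> 2, "a" -> 3))

// Keep the per-key counts distributed instead of collecting them to the driver:
val counts: RDD[(String, Long)] = pairs.mapValues(_ => 1L).reduceByKey(_ + _)

// Collect only when the result is known to be small:
val localCounts: Map[String, Long] = counts.collect().toMap   // Map(a -> 2, b -> 1)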
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 66c71bf7e8bb5..87b22de6ae697 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -48,7 +48,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
override def index: Int = slice
@throws(classOf[IOException])
- private def writeObject(out: ObjectOutputStream): Unit = {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
@@ -67,7 +67,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
}
@throws(classOf[IOException])
- private def readObject(in: ObjectInputStream): Unit = {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val sfactory = SparkEnv.get.serializer
sfactory match {
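This patch wraps a number of custom writeObject/readObject and writeExternal/readExternal methods in Utils.tryOrIOException, so that unexpected failures during (de)serialization surface as IOExceptions that the Java serialization machinery reports, rather than being lost. A rough sketch of what such a helper might look like; the real implementation lives in org.apache.spark.util.Utils and may differ:

import java.io.IOException

// Hedged sketch: pass IOExceptions through, wrap anything else in an IOException
// so that readObject/writeObject callers see the failure.
def tryOrIOException(block: => Unit): Unit = {
  try {
    block
  } catch {
    case e: IOException => throw e
    case t: Throwable   => throw new IOException(t)
  }
}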
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
index 0c2cd7a24783b..92b0641d0fb6e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
/**
* Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of
@@ -38,7 +39,7 @@ class PartitionerAwareUnionRDDPartition(
override def hashCode(): Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent partition at the time of task serialization
parents = rdds.map(_.partitions(index)).toArray
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 71cabf61d4ee0..c169b2d3fe97f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -43,7 +43,8 @@ import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.{BoundedPriorityQueue, Utils, CallSite}
import org.apache.spark.util.collection.OpenHashMap
-import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
+import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, BernoulliCellSampler,
+ SamplingUtils}
/**
* A Resilient Distributed Dataset (RDD), the basic abstraction in Spark. Represents an immutable,
@@ -375,7 +376,8 @@ abstract class RDD[T: ClassTag](
val sum = weights.sum
val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
normalizedCumWeights.sliding(2).map { x =>
- new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
+ new PartitionwiseSampledRDD[T, T](
+ this, new BernoulliCellSampler[T](x(0), x(1)), true, seed)
}.toArray
}
@@ -927,32 +929,15 @@ abstract class RDD[T: ClassTag](
}
/**
- * Return the count of each unique value in this RDD as a map of (value, count) pairs. The final
- * combine step happens locally on the master, equivalent to running a single reduce task.
+ * Return the count of each unique value in this RDD as a local map of (value, count) pairs.
+ *
+ * Note that this method should only be used if the resulting map is expected to be small, as
+ * the whole thing is loaded into the driver's memory.
+ * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which
+ * returns an RDD[(T, Long)] instead of a map.
*/
def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = {
- if (elementClassTag.runtimeClass.isArray) {
- throw new SparkException("countByValue() does not support arrays")
- }
- // TODO: This should perhaps be distributed by default.
- val countPartition = (iter: Iterator[T]) => {
- val map = new OpenHashMap[T,Long]
- iter.foreach {
- t => map.changeValue(t, 1L, _ + 1L)
- }
- Iterator(map)
- }: Iterator[OpenHashMap[T,Long]]
- val mergeMaps = (m1: OpenHashMap[T,Long], m2: OpenHashMap[T,Long]) => {
- m2.foreach { case (key, value) =>
- m1.changeValue(key, value, _ + value)
- }
- m1
- }: OpenHashMap[T,Long]
- val myResult = mapPartitions(countPartition).reduce(mergeMaps)
- // Convert to a Scala mutable map
- val mutableResult = scala.collection.mutable.Map[T,Long]()
- myResult.foreach { case (k, v) => mutableResult.put(k, v) }
- mutableResult
+ map(value => (value, null)).countByKey()
}
/**
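countByValue is now expressed through countByKey, so both actions share the same collect-to-the-driver caveat. A brief sketch of the distributed form recommended by the new doc comment (assumes a `words: RDD[String]` and the same imports as the earlier countByKey sketch):

// Stays an RDD[(String, Long)] instead of materializing a local Map:
val counts = words.map(x => (x, 1L)).reduceByKey(_ + _)
// words.countByValue() is equivalent to collecting `counts` locally, so it should
// only be used when the number of distinct values is known to be small.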
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index b097c30f8c231..9e8cee5331cf8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -21,8 +21,7 @@ import java.util.Random
import scala.reflect.ClassTag
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.{Partition, TaskContext}
@@ -53,9 +52,11 @@ private[spark] class SampledRDD[T: ClassTag](
if (withReplacement) {
// For large datasets, the expected number of occurrences of each element in a sample with
// replacement is Poisson(frac). We use that to get a count for each element.
- val poisson = new Poisson(frac, new DRand(split.seed))
+ val poisson = new PoissonDistribution(frac)
+ poisson.reseedRandomGenerator(split.seed)
+
firstParent[T].iterator(split.prev, context).flatMap { element =>
- val count = poisson.nextInt()
+ val count = poisson.sample()
if (count == 0) {
Iterator.empty // Avoid object allocation when we return 0 items, which is quite often
} else {
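With the cern.jet (Colt) classes swapped out, the sampler now draws its per-element counts from commons-math3. A small self-contained sketch of the replacement API, with illustrative values:

import org.apache.commons.math3.distribution.PoissonDistribution

// Sampling with replacement: each element appears Poisson(frac) times on average.
val frac = 0.3
val poisson = new PoissonDistribution(frac)
poisson.reseedRandomGenerator(42L)            // seeded per split in the RDD code above

val counts: Seq[Int] = Seq.fill(5)(poisson.sample())   // occurrences for five elements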
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 0c97eb0aaa51f..aece683ff3199 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -24,6 +24,7 @@ import scala.reflect.ClassTag
import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* Partition for UnionRDD.
@@ -48,7 +49,7 @@ private[spark] class UnionPartition[T: ClassTag](
override val index: Int = idx
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
parentPartition = rdd.partitions(parentRddPartitionIndex)
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index f3d30f6c9b32f..996f2cd3f34a3 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -22,6 +22,7 @@ import java.io.{IOException, ObjectOutputStream}
import scala.reflect.ClassTag
import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
+import org.apache.spark.util.Utils
private[spark] class ZippedPartitionsPartition(
idx: Int,
@@ -34,7 +35,7 @@ private[spark] class ZippedPartitionsPartition(
def partitions = partitionValues
@throws(classOf[IOException])
- private def writeObject(oos: ObjectOutputStream) {
+ private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException {
// Update the reference to parent split at the time of task serialization
partitionValues = rdds.map(rdd => rdd.partitions(idx))
oos.defaultWriteObject()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index f81fa6d8089fc..96114c0423a9e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -124,6 +124,9 @@ class DAGScheduler(
/** If enabled, we may run certain actions like take() and first() locally. */
private val localExecutionEnabled = sc.getConf.getBoolean("spark.localExecution.enabled", false)
+ /** If enabled, FetchFailed will not cause stage retry, in order to surface the problem. */
+ private val disallowStageRetryForTest = sc.getConf.getBoolean("spark.test.noStageRetry", false)
+
private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
@@ -1050,7 +1053,7 @@ class DAGScheduler(
logInfo("Resubmitted " + task + ", so marking it as still running")
stage.pendingTasks += task
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, failureMessage) =>
val failedStage = stageIdToStage(task.stageId)
val mapStage = shuffleToMapStage(shuffleId)
@@ -1060,11 +1063,13 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure"))
+ markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage))
runningStages -= failedStage
}
- if (failedStages.isEmpty && eventProcessActor != null) {
+ if (disallowStageRetryForTest) {
+ abortStage(failedStage, "Fetch failure will not retry stage due to testing config")
+ } else if (failedStages.isEmpty && eventProcessActor != null) {
// Don't schedule an event to resubmit failed stages if failed isn't empty, because
// in that case the event will already have been scheduled. eventProcessActor may be
// null during unit tests.
@@ -1086,7 +1091,7 @@ class DAGScheduler(
// TODO: mark the executor as failed only if there were lots of fetch failures on it
if (bmAddress != null) {
- handleExecutorLost(bmAddress.executorId, Some(task.epoch))
+ handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
case ExceptionFailure(className, description, stackTrace, metrics) =>
@@ -1106,25 +1111,35 @@ class DAGScheduler(
* Responds to an executor being lost. This is called inside the event loop, so it assumes it can
* modify the scheduler's internal state. Use executorLost() to post a loss event from outside.
*
+ * We will also assume that we've lost all shuffle blocks associated with the executor if the
+ * executor serves its own blocks (i.e., we're not using external shuffle) OR a FetchFailed
+ * occurred, in which case we presume all shuffle data related to this executor to be lost.
+ *
* Optionally the epoch during which the failure was caught can be passed to avoid allowing
* stray fetch failures from possibly retriggering the detection of a node as lost.
*/
- private[scheduler] def handleExecutorLost(execId: String, maybeEpoch: Option[Long] = None) {
+ private[scheduler] def handleExecutorLost(
+ execId: String,
+ fetchFailed: Boolean,
+ maybeEpoch: Option[Long] = None) {
val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
if (!failedEpoch.contains(execId) || failedEpoch(execId) < currentEpoch) {
failedEpoch(execId) = currentEpoch
logInfo("Executor lost: %s (epoch %d)".format(execId, currentEpoch))
blockManagerMaster.removeExecutor(execId)
- // TODO: This will be really slow if we keep accumulating shuffle map stages
- for ((shuffleId, stage) <- shuffleToMapStage) {
- stage.removeOutputsOnExecutor(execId)
- val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
- mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
- }
- if (shuffleToMapStage.isEmpty) {
- mapOutputTracker.incrementEpoch()
+
+ if (!env.blockManager.externalShuffleServiceEnabled || fetchFailed) {
+ // TODO: This will be really slow if we keep accumulating shuffle map stages
+ for ((shuffleId, stage) <- shuffleToMapStage) {
+ stage.removeOutputsOnExecutor(execId)
+ val locs = stage.outputLocs.map(list => if (list.isEmpty) null else list.head).toArray
+ mapOutputTracker.registerMapOutputs(shuffleId, locs, changeEpoch = true)
+ }
+ if (shuffleToMapStage.isEmpty) {
+ mapOutputTracker.incrementEpoch()
+ }
+ clearCacheLocs()
}
- clearCacheLocs()
} else {
logDebug("Additional executor lost message for " + execId +
"(epoch " + currentEpoch + ")")
@@ -1382,7 +1397,7 @@ private[scheduler] class DAGSchedulerEventProcessActor(dagScheduler: DAGSchedule
dagScheduler.handleExecutorAdded(execId, host)
case ExecutorLost(execId) =>
- dagScheduler.handleExecutorLost(execId)
+ dagScheduler.handleExecutorLost(execId, fetchFailed = false)
case BeginEvent(task, taskInfo) =>
dagScheduler.handleBeginEvent(task, taskInfo)
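The new spark.test.noStageRetry flag turns a FetchFailed into an immediate stage abort so shuffle problems surface in tests instead of being hidden behind stage retries, and executor loss now only wipes map outputs when the executor served its own shuffle blocks or a fetch actually failed. A hedged sketch of enabling the test flag (setup is illustrative):

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical test setup: fail fast on fetch failures rather than resubmitting the stage.
val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("no-stage-retry-sketch")
  .set("spark.test.noStageRetry", "true")
val sc = new SparkContext(conf)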
diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
index 100c9ba9b7809..597dbc884913c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala
@@ -142,7 +142,7 @@ private[spark] object EventLoggingListener extends Logging {
val SPARK_VERSION_PREFIX = "SPARK_VERSION_"
val COMPRESSION_CODEC_PREFIX = "COMPRESSION_CODEC_"
val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
- val LOG_FILE_PERMISSIONS = FsPermission.createImmutable(Integer.parseInt("770", 8).toShort)
+ val LOG_FILE_PERMISSIONS = new FsPermission(Integer.parseInt("770", 8).toShort)
// A cache for compression codecs to avoid creating the same codec many times
private val codecMap = new mutable.HashMap[String, CompressionCodec]
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 54904bffdf10b..4e3d9de540783 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -215,7 +215,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId +
" STAGE_ID=" + taskEnd.stageId
stageLogInfo(taskEnd.stageId, taskStatus)
- case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
+ case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) =>
taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" +
taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" +
mapId + " REDUCE_ID=" + reduceId
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index e25096ea92d70..01d5943d777f3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -19,7 +19,10 @@ package org.apache.spark.scheduler
import java.io.{Externalizable, ObjectInput, ObjectOutput}
+import org.roaringbitmap.RoaringBitmap
+
import org.apache.spark.storage.BlockManagerId
+import org.apache.spark.util.Utils
/**
* Result returned by a ShuffleMapTask to a scheduler. Includes the block manager address that the
@@ -29,7 +32,12 @@ private[spark] sealed trait MapStatus {
/** Location where this task was run. */
def location: BlockManagerId
- /** Estimated size for the reduce block, in bytes. */
+ /**
+ * Estimated size for the reduce block, in bytes.
+ *
+ * If a block is non-empty, then this method MUST return a non-zero size. This invariant is
+ * necessary for correctness, since block fetchers are allowed to skip zero-size blocks.
+ */
def getSizeForBlock(reduceId: Int): Long
}
@@ -38,7 +46,7 @@ private[spark] object MapStatus {
def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): MapStatus = {
if (uncompressedSizes.length > 2000) {
- new HighlyCompressedMapStatus(loc, uncompressedSizes)
+ HighlyCompressedMapStatus(loc, uncompressedSizes)
} else {
new CompressedMapStatus(loc, uncompressedSizes)
}
@@ -98,13 +106,13 @@ private[spark] class CompressedMapStatus(
MapStatus.decompressSize(compressedSizes(reduceId))
}
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
out.writeInt(compressedSizes.length)
out.write(compressedSizes)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
val len = in.readInt()
compressedSizes = new Array[Byte](len)
@@ -112,35 +120,80 @@ private[spark] class CompressedMapStatus(
}
}
-
/**
- * A [[MapStatus]] implementation that only stores the average size of the blocks.
+ * A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
+ * plus a bitmap for tracking which blocks are non-empty. During serialization, this bitmap
+ * is compressed.
*
- * @param loc location where the task is being executed.
- * @param avgSize average size of all the blocks
+ * @param loc location where the task is being executed
+ * @param numNonEmptyBlocks the number of non-empty blocks
+ * @param emptyBlocks a bitmap tracking which blocks are empty
+ * @param avgSize average size of the non-empty blocks
*/
-private[spark] class HighlyCompressedMapStatus(
+private[spark] class HighlyCompressedMapStatus private (
private[this] var loc: BlockManagerId,
+ private[this] var numNonEmptyBlocks: Int,
+ private[this] var emptyBlocks: RoaringBitmap,
private[this] var avgSize: Long)
extends MapStatus with Externalizable {
- def this(loc: BlockManagerId, uncompressedSizes: Array[Long]) {
- this(loc, uncompressedSizes.sum / uncompressedSizes.length)
- }
+ // loc could be null when the default constructor is called during deserialization
+ require(loc == null || avgSize > 0 || numNonEmptyBlocks == 0,
+ "Average size can only be zero for map stages that produced no output")
- protected def this() = this(null, 0L) // For deserialization only
+ protected def this() = this(null, -1, null, -1) // For deserialization only
override def location: BlockManagerId = loc
- override def getSizeForBlock(reduceId: Int): Long = avgSize
+ override def getSizeForBlock(reduceId: Int): Long = {
+ if (emptyBlocks.contains(reduceId)) {
+ 0
+ } else {
+ avgSize
+ }
+ }
- override def writeExternal(out: ObjectOutput): Unit = {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
loc.writeExternal(out)
+ emptyBlocks.writeExternal(out)
out.writeLong(avgSize)
}
- override def readExternal(in: ObjectInput): Unit = {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
loc = BlockManagerId(in)
+ emptyBlocks = new RoaringBitmap()
+ emptyBlocks.readExternal(in)
avgSize = in.readLong()
}
}
+
+private[spark] object HighlyCompressedMapStatus {
+ def apply(loc: BlockManagerId, uncompressedSizes: Array[Long]): HighlyCompressedMapStatus = {
+ // We must keep track of which blocks are empty so that we don't report a zero-sized
+ // block as being non-empty (or vice-versa) when using the average block size.
+ var i = 0
+ var numNonEmptyBlocks: Int = 0
+ var totalSize: Long = 0
+ // From a compression standpoint, it shouldn't matter whether we track empty or non-empty
+ // blocks. From a performance standpoint, we benefit from tracking empty blocks because
+ // we expect that there will be far fewer of them, so we will perform fewer bitmap insertions.
+ val emptyBlocks = new RoaringBitmap()
+ val totalNumBlocks = uncompressedSizes.length
+ while (i < totalNumBlocks) {
+ var size = uncompressedSizes(i)
+ if (size > 0) {
+ numNonEmptyBlocks += 1
+ totalSize += size
+ } else {
+ emptyBlocks.add(i)
+ }
+ i += 1
+ }
+ val avgSize = if (numNonEmptyBlocks > 0) {
+ totalSize / numNonEmptyBlocks
+ } else {
+ 0
+ }
+ new HighlyCompressedMapStatus(loc, numNonEmptyBlocks, emptyBlocks, avgSize)
+ }
+}
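A worked example of the new encoding: only the average size of the non-empty blocks is stored, plus a RoaringBitmap of the empty ones, which preserves the documented invariant that a non-empty block never reports a zero size. The numbers below are illustrative:

// Illustrative sizes for four reduce blocks.
val uncompressedSizes = Array(0L, 100L, 300L, 0L)

// HighlyCompressedMapStatus.apply(loc, uncompressedSizes) would compute:
//   emptyBlocks       = {0, 3}              (bitmap of empty reduce ids)
//   numNonEmptyBlocks = 2
//   avgSize           = (100 + 300) / 2 = 200
//
// getSizeForBlock(0) == 0     // empty blocks stay zero, so fetchers may skip them
// getSizeForBlock(1) == 200   // every non-empty block reports the average size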
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
index 071568cdfb429..cc13f57a49b89 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala
@@ -102,6 +102,11 @@ private[spark] class Stage(
}
}
+ /**
+ * Removes all shuffle outputs associated with this executor. Note that this will also remove
+ * outputs which are served by an external shuffle server (if one exists), as they are still
+ * registered with this execId.
+ */
def removeOutputsOnExecutor(execId: String) {
var becameUnavailable = false
for (partition <- 0 until numPartitions) {
@@ -131,4 +136,9 @@ private[spark] class Stage(
override def toString = "Stage " + id
override def hashCode(): Int = id
+
+ override def equals(other: Any): Boolean = other match {
+ case stage: Stage => stage != null && stage.id == id
+ case _ => false
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index d49d8fb887007..1f114a0207f7b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -31,8 +31,8 @@ import org.apache.spark.util.Utils
private[spark] sealed trait TaskResult[T]
/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
-private[spark]
-case class IndirectTaskResult[T](blockId: BlockId) extends TaskResult[T] with Serializable
+private[spark] case class IndirectTaskResult[T](blockId: BlockId, size: Int)
+ extends TaskResult[T] with Serializable
/** A TaskResult that contains the task's return value and accumulator updates. */
private[spark]
@@ -42,7 +42,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
def this() = this(null.asInstanceOf[ByteBuffer], null, null)
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(valueBytes.remaining);
Utils.writeByteBuffer(valueBytes, out)
@@ -55,7 +55,7 @@ class DirectTaskResult[T](var valueBytes: ByteBuffer, var accumUpdates: Map[Long
out.writeObject(metrics)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val blen = in.readInt()
val byteVal = new Array[Byte](blen)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 3f345ceeaaf7a..819b51e12ad8c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -47,9 +47,18 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
getTaskResultExecutor.execute(new Runnable {
override def run(): Unit = Utils.logUncaughtExceptions {
try {
- val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case directResult: DirectTaskResult[_] => directResult
- case IndirectTaskResult(blockId) =>
+ val (result, size) = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+ case directResult: DirectTaskResult[_] =>
+ if (!taskSetManager.canFetchMoreResults(serializedData.limit())) {
+ return
+ }
+ (directResult, serializedData.limit())
+ case IndirectTaskResult(blockId, size) =>
+ if (!taskSetManager.canFetchMoreResults(size)) {
+ // dropped by executor if size is larger than maxResultSize
+ sparkEnv.blockManager.master.removeBlock(blockId)
+ return
+ }
logDebug("Fetching indirect task result for TID %s".format(tid))
scheduler.handleTaskGettingResult(taskSetManager, tid)
val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
@@ -64,9 +73,10 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
serializedTaskResult.get)
sparkEnv.blockManager.master.removeBlock(blockId)
- deserializedResult
+ (deserializedResult, size)
}
- result.metrics.resultSize = serializedData.limit()
+
+ result.metrics.resultSize = size
scheduler.handleSuccessfulTask(taskSetManager, tid, result)
} catch {
case cnf: ClassNotFoundException =>
@@ -93,7 +103,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
} catch {
case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastropic if we can't
+ // Log an error but keep going here -- the task failed, so not catastrophic if we can't
// deserialize the reason.
val loader = Utils.getContextOrSparkClassLoader
logError(
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index a129a434c9a1a..f095915352b17 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.storage.BlockManagerId
/**
* Low-level task scheduler interface, currently implemented exclusively by TaskSchedulerImpl.
- * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedules tasks
* for a single SparkContext. These schedulers get sets of tasks submitted to them from the
* DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
* them, retrying if there are failures, and mitigating stragglers. They return events to the
@@ -41,7 +41,7 @@ private[spark] trait TaskScheduler {
// Invoked after system has successfully initialized (typically in spark context).
// Yarn uses this to bootstrap allocation of resources based on preferred locations,
- // wait for slave registerations, etc.
+ // wait for slave registrations, etc.
def postStartHook() { }
// Disconnect from the cluster.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 6d697e3d003f6..cd3c015321e85 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -34,7 +34,6 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
import org.apache.spark.util.Utils
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
-import akka.actor.Props
/**
* Schedules tasks for multiple types of clusters by acting through a SchedulerBackend.
@@ -221,6 +220,7 @@ private[spark] class TaskSchedulerImpl(
var newExecAvail = false
for (o <- offers) {
executorIdToHost(o.executorId) = o.host
+ activeExecutorIds += o.executorId
if (!executorsByHost.contains(o.host)) {
executorsByHost(o.host) = new HashSet[String]()
executorAdded(o.executorId, o.host)
@@ -261,7 +261,6 @@ private[spark] class TaskSchedulerImpl(
val tid = task.taskId
taskIdToTaskSetId(tid) = taskSet.taskSet.id
taskIdToExecutorId(tid) = execId
- activeExecutorIds += execId
executorsByHost(host) += execId
availableCpus(i) -= CPUS_PER_TASK
assert(availableCpus(i) >= 0)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index a6c23fc85a1b0..d8fb640350343 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -23,13 +23,12 @@ import java.util.Arrays
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-import scala.math.max
-import scala.math.min
+import scala.math.{min, max}
import org.apache.spark._
-import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.TaskMetrics
-import org.apache.spark.util.{Clock, SystemClock}
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/**
* Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of
@@ -68,6 +67,9 @@ private[spark] class TaskSetManager(
val SPECULATION_QUANTILE = conf.getDouble("spark.speculation.quantile", 0.75)
val SPECULATION_MULTIPLIER = conf.getDouble("spark.speculation.multiplier", 1.5)
+ // Limit of bytes for total size of results (default is 1GB)
+ val maxResultSize = Utils.getMaxResultSize(conf)
+
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -89,6 +91,8 @@ private[spark] class TaskSetManager(
var stageId = taskSet.stageId
var name = "TaskSet_" + taskSet.stageId.toString
var parent: Pool = null
+ var totalResultSize = 0L
+ var calculatedTasks = 0
val runningTasksSet = new HashSet[Long]
override def runningTasks = runningTasksSet.size
@@ -515,12 +519,33 @@ private[spark] class TaskSetManager(
index
}
+ /**
+ * Marks the task as getting result and notifies the DAG Scheduler
+ */
def handleTaskGettingResult(tid: Long) = {
val info = taskInfos(tid)
info.markGettingResult()
sched.dagScheduler.taskGettingResult(info)
}
+ /**
+ * Check whether there is enough quota to fetch the result with `size` bytes
+ */
+ def canFetchMoreResults(size: Long): Boolean = synchronized {
+ totalResultSize += size
+ calculatedTasks += 1
+ if (maxResultSize > 0 && totalResultSize > maxResultSize) {
+ val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " +
+ s"(${Utils.bytesToString(maxResultSize)})"
+ logError(msg)
+ abort(msg)
+ false
+ } else {
+ true
+ }
+ }
+
/**
* Marks the task as successful and notifies the DAGScheduler that a task has ended.
*/
@@ -687,10 +712,11 @@ private[spark] class TaskSetManager(
addPendingTask(index, readding=true)
}
- // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage.
+ // Re-enqueue any tasks that ran on the failed executor if this is a shuffle map stage,
+ // and we are not using an external shuffle server which could serve the shuffle outputs.
// The reason is the next stage wouldn't be able to fetch the data from this dead executor
// so we would need to rerun these tasks on other executors.
- if (tasks(0).isInstanceOf[ShuffleMapTask]) {
+ if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) {
for ((tid, info) <- taskInfos if info.executorId == execId) {
val index = taskInfos(tid).index
if (successful(index)) {
@@ -706,7 +732,7 @@ private[spark] class TaskSetManager(
}
// Also re-enqueue any tasks that were running on the node
for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
- handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure)
+ handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure(execId))
}
// recalculate valid locality levels and waits when executor is lost
recomputeLocality()
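The result-size cap is cumulative per task set: every completed task adds its serialized result size, and the whole set is aborted once the running total exceeds the limit returned by Utils.getMaxResultSize. A hedged configuration sketch, assuming that limit is read from spark.driver.maxResultSize (the key itself is not shown in this hunk):

import org.apache.spark.SparkConf

// Illustrative: raise the cap on total serialized results collected by the driver to 2 GB.
val conf = new SparkConf().set("spark.driver.maxResultSize", "2g")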
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index fb8160abc59db..1da6fe976da5b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -66,7 +66,19 @@ private[spark] object CoarseGrainedClusterMessages {
case class RemoveExecutor(executorId: String, reason: String) extends CoarseGrainedClusterMessage
- case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase :String)
+ // Exchanged between the driver and the AM in Yarn client mode
+ case class AddWebUIFilter(filterName:String, filterParams: Map[String, String], proxyBase: String)
extends CoarseGrainedClusterMessage
+ // Messages exchanged between the driver and the cluster manager for executor allocation
+ // In Yarn mode, these are exchanged between the driver and the AM
+
+ case object RegisterClusterManager extends CoarseGrainedClusterMessage
+
+ // Request executors by specifying the new total number of executors desired
+ // This includes executors already pending or running
+ case class RequestExecutors(requestedTotal: Int) extends CoarseGrainedClusterMessage
+
+ case class KillExecutors(executorIds: Seq[String]) extends CoarseGrainedClusterMessage
+
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 59aed6b72fe42..7a6ee56f81689 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -31,7 +31,6 @@ import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
-import org.apache.spark.ui.JettyUtils
/**
* A scheduler backend that waits for coarse grained executors to connect to it through Akka.
@@ -42,7 +41,7 @@ import org.apache.spark.ui.JettyUtils
* (spark.deploy.*).
*/
private[spark]
-class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: ActorSystem)
+class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
extends SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
@@ -61,10 +60,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
conf.getInt("spark.scheduler.maxRegisteredResourcesWaitingTime", 30000)
val createTime = System.currentTimeMillis()
+ private val executorDataMap = new HashMap[String, ExecutorData]
+
+ // Number of executors requested from the cluster manager that have not registered yet
+ private var numPendingExecutors = 0
+
+ // Executors we have requested the cluster manager to kill that have not died yet
+ private val executorsPendingToRemove = new HashSet[String]
+
class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive {
override protected def log = CoarseGrainedSchedulerBackend.this.log
private val addressToExecutorId = new HashMap[Address, String]
- private val executorDataMap = new HashMap[String, ExecutorData]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
@@ -84,12 +90,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor
- executorDataMap.put(executorId, new ExecutorData(sender, sender.path.address,
- Utils.parseHostPort(hostPort)._1, cores, cores))
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
totalRegisteredExecutors.addAndGet(1)
+ val (host, _) = Utils.parseHostPort(hostPort)
+ val data = new ExecutorData(sender, sender.path.address, host, cores, cores)
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap.put(executorId, data)
+ if (numPendingExecutors > 0) {
+ numPendingExecutors -= 1
+ logDebug(s"Decremented number of pending executors ($numPendingExecutors left)")
+ }
+ }
makeOffers()
}
@@ -128,10 +143,6 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
removeExecutor(executorId, reason)
sender ! true
- case AddWebUIFilter(filterName, filterParams, proxyBase) =>
- addWebUIFilter(filterName, filterParams, proxyBase)
- sender ! true
-
case DisassociatedEvent(_, address, _) =>
addressToExecutorId.get(address).foreach(removeExecutor(_,
"remote Akka client disassociated"))
@@ -183,13 +194,18 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
}
// Remove a disconnected slave from the cluster
- def removeExecutor(executorId: String, reason: String) {
+ def removeExecutor(executorId: String, reason: String): Unit = {
executorDataMap.get(executorId) match {
case Some(executorInfo) =>
- executorDataMap -= executorId
+ // This must be synchronized because variables mutated
+ // in this block are read when requesting executors
+ CoarseGrainedSchedulerBackend.this.synchronized {
+ executorDataMap -= executorId
+ executorsPendingToRemove -= executorId
+ }
totalCoreCount.addAndGet(-executorInfo.totalCores)
scheduler.executorLost(executorId, SlaveLost(reason))
- case None => logError(s"Asked to remove non existant executor $executorId")
+ case None => logError(s"Asked to remove non-existent executor $executorId")
}
}
}
@@ -274,21 +290,62 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, actorSystem: A
false
}
- // Add filters to the SparkUI
- def addWebUIFilter(filterName: String, filterParams: Map[String, String], proxyBase: String) {
- if (proxyBase != null && proxyBase.nonEmpty) {
- System.setProperty("spark.ui.proxyBase", proxyBase)
- }
+ /**
+ * Return the number of executors currently registered with this backend.
+ */
+ def numExistingExecutors: Int = executorDataMap.size
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged.
+ */
+ final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
+ logDebug(s"Number of pending executors is now $numPendingExecutors")
+ numPendingExecutors += numAdditionalExecutors
+ // Account for executors pending to be added or removed
+ val newTotal = numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size
+ doRequestTotalExecutors(newTotal)
+ }
- val hasFilter = (filterName != null && filterName.nonEmpty &&
- filterParams != null && filterParams.nonEmpty)
- if (hasFilter) {
- logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
- conf.set("spark.ui.filters", filterName)
- filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
- scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ /**
+ * Request executors from the cluster manager by specifying the total number desired,
+ * including existing pending and running executors.
+ *
+ * The semantics here guarantee that we do not over-allocate executors for this application,
+ * since a later request overrides the value of any prior request. The alternative interface
+ * of requesting a delta of executors risks double counting new executors when there are
+ * insufficient resources to satisfy the first request. We make the assumption here that the
+ * cluster manager will eventually fulfill all requests when resources free up.
+ *
+ * Return whether the request is acknowledged.
+ */
+ protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the kill request is acknowledged.
+ */
+ final def killExecutors(executorIds: Seq[String]): Boolean = {
+ logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
+ val filteredExecutorIds = new ArrayBuffer[String]
+ executorIds.foreach { id =>
+ if (executorDataMap.contains(id)) {
+ filteredExecutorIds += id
+ } else {
+ logWarning(s"Executor to kill $id does not exist!")
+ }
}
+ executorsPendingToRemove ++= filteredExecutorIds
+ doKillExecutors(filteredExecutorIds)
}
+
+ /**
+ * Kill the given list of executors through the cluster manager.
+ * Return whether the kill request is acknowledged.
+ */
+ protected def doKillExecutors(executorIds: Seq[String]): Boolean = false
+
}
private[spark] object CoarseGrainedSchedulerBackend {
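The total-based request API is what prevents double counting: a later request simply overrides the previous total instead of stacking deltas. A worked example of the bookkeeping in requestExecutors, with illustrative numbers:

// State before a request (illustrative):
//   numExistingExecutors     = 5   (registered)
//   numPendingExecutors      = 2   (requested earlier, not yet registered)
//   executorsPendingToRemove = 1   (kill requested, not yet gone)
//
// requestExecutors(3) then performs:
//   numPendingExecutors += 3                 // -> 5
//   newTotal = 5 + 5 - 1                     // existing + pending - pending removals = 9
//   doRequestTotalExecutors(9)               // a later call simply overrides this total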
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
new file mode 100644
index 0000000000000..50721b9d6cd6c
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import akka.actor.{Actor, ActorRef, Props}
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
+
+import org.apache.spark.SparkContext
+import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
+import org.apache.spark.scheduler.TaskSchedulerImpl
+import org.apache.spark.ui.JettyUtils
+import org.apache.spark.util.AkkaUtils
+
+/**
+ * Abstract Yarn scheduler backend that contains common logic
+ * between the client and cluster Yarn scheduler backends.
+ */
+private[spark] abstract class YarnSchedulerBackend(
+ scheduler: TaskSchedulerImpl,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) {
+
+ if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) {
+ minRegisteredRatio = 0.8
+ }
+
+ protected var totalExpectedExecutors = 0
+
+ private val yarnSchedulerActor: ActorRef =
+ actorSystem.actorOf(
+ Props(new YarnSchedulerActor),
+ name = YarnSchedulerBackend.ACTOR_NAME)
+
+ private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf)
+
+ /**
+ * Request executors from the ApplicationMaster by specifying the total number desired.
+ * This includes executors already pending or running.
+ */
+ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout)
+ }
+
+ /**
+ * Request that the ApplicationMaster kill the specified executors.
+ */
+ override def doKillExecutors(executorIds: Seq[String]): Boolean = {
+ AkkaUtils.askWithReply[Boolean](
+ KillExecutors(executorIds), yarnSchedulerActor, askTimeout)
+ }
+
+ override def sufficientResourcesRegistered(): Boolean = {
+ totalRegisteredExecutors.get() >= totalExpectedExecutors * minRegisteredRatio
+ }
+
+ /**
+ * Add filters to the SparkUI.
+ */
+ private def addWebUIFilter(
+ filterName: String,
+ filterParams: Map[String, String],
+ proxyBase: String): Unit = {
+ if (proxyBase != null && proxyBase.nonEmpty) {
+ System.setProperty("spark.ui.proxyBase", proxyBase)
+ }
+
+ val hasFilter =
+ filterName != null && filterName.nonEmpty &&
+ filterParams != null && filterParams.nonEmpty
+ if (hasFilter) {
+ logInfo(s"Add WebUI Filter. $filterName, $filterParams, $proxyBase")
+ conf.set("spark.ui.filters", filterName)
+ filterParams.foreach { case (k, v) => conf.set(s"spark.$filterName.param.$k", v) }
+ scheduler.sc.ui.foreach { ui => JettyUtils.addFilters(ui.getHandlers, conf) }
+ }
+ }
+
+ /**
+ * An actor that communicates with the ApplicationMaster.
+ */
+ private class YarnSchedulerActor extends Actor {
+ private var amActor: Option[ActorRef] = None
+
+ override def preStart(): Unit = {
+ // Listen for disassociation events
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ }
+
+ override def receive = {
+ case RegisterClusterManager =>
+ logInfo(s"ApplicationMaster registered as $sender")
+ amActor = Some(sender)
+
+ case r: RequestExecutors =>
+ amActor match {
+ case Some(actor) =>
+ sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ case None =>
+ logWarning("Attempted to request executors before the AM has registered!")
+ sender ! false
+ }
+
+ case k: KillExecutors =>
+ amActor match {
+ case Some(actor) =>
+ sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ case None =>
+ logWarning("Attempted to kill executors before the AM has registered!")
+ sender ! false
+ }
+
+ case AddWebUIFilter(filterName, filterParams, proxyBase) =>
+ addWebUIFilter(filterName, filterParams, proxyBase)
+ sender ! true
+
+ case d: DisassociatedEvent =>
+ if (amActor.isDefined && sender == amActor.get) {
+ logWarning(s"ApplicationMaster has disassociated: $d")
+ }
+ }
+ }
+}
+
+private[spark] object YarnSchedulerBackend {
+ val ACTOR_NAME = "YarnScheduler"
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d7f88de4b40aa..d8c0e2f66df01 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -31,6 +31,7 @@ import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, TaskState => MesosTas
import org.apache.spark.{Logging, SparkContext, SparkEnv, SparkException}
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.util.Utils
/**
* A SchedulerBackend that runs tasks on Mesos, but uses "coarse-grained" tasks, where it holds
@@ -120,16 +121,18 @@ private[spark] class CoarseMesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = conf.getOption("spark.executor.extraJavaOptions")
+ val extraJavaOpts = conf.get("spark.executor.extraJavaOptions", "")
- val libraryPathOption = "spark.executor.extraLibraryPath"
- val extraLibraryPath = conf.getOption(libraryPathOption).map(p => s"-Djava.library.path=$p")
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ // Set the environment variable through a command prefix
+ // to append to the existing value of the variable
+ val prefixEnv = conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
@@ -150,16 +153,17 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "\"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
- runScript, driverUrl, offer.getSlaveId.getValue, offer.getHostname, numCores, appId))
+ "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
+ prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue,
+ offer.getHostname, numCores, appId))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- ("cd %s*; " +
+ ("cd %s*; %s " +
"./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
- .format(basename, driverUrl, offer.getSlaveId.getValue,
+ .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue,
offer.getHostname, numCores, appId))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
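The library path is now injected as a command prefix rather than folded into SPARK_EXECUTOR_OPTS, so it extends whatever value the environment already has instead of replacing it. A rough, hedged sketch of what such a prefix helper could produce; the exact variable name and quoting used by Utils.libraryPathEnvPrefix are assumptions here:

// Hedged sketch only; the real helper lives in org.apache.spark.util.Utils.
def libraryPathEnvPrefixSketch(paths: Seq[String]): String = {
  val envVar = "LD_LIBRARY_PATH"   // platform-specific in practice
  // e.g. LD_LIBRARY_PATH="/opt/native:$LD_LIBRARY_PATH", prepended to the launch command
  envVar + "=\"" + paths.mkString(":") + ":$" + envVar + "\""
}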
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index e0f2fd622f54c..8e2faff90f9b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -98,15 +98,16 @@ private[spark] class MesosSchedulerBackend(
environment.addVariables(
Environment.Variable.newBuilder().setName("SPARK_CLASSPATH").setValue(cp).build())
}
- val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
- val extraLibraryPath = sc.conf.getOption("spark.executor.extraLibraryPath").map { lp =>
- s"-Djava.library.path=$lp"
- }
- val extraOpts = Seq(extraJavaOpts, extraLibraryPath).flatten.mkString(" ")
+ val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions").getOrElse("")
+
+ val prefixEnv = sc.conf.getOption("spark.executor.extraLibraryPath").map { p =>
+ Utils.libraryPathEnvPrefix(Seq(p))
+ }.getOrElse("")
+
environment.addVariables(
Environment.Variable.newBuilder()
.setName("SPARK_EXECUTOR_OPTS")
- .setValue(extraOpts)
+ .setValue(extraJavaOpts)
.build())
sc.executorEnvs.foreach { case (key, value) =>
environment.addVariables(Environment.Variable.newBuilder()
@@ -118,12 +119,13 @@ private[spark] class MesosSchedulerBackend(
.setEnvironment(environment)
val uri = sc.conf.get("spark.executor.uri", null)
if (uri == null) {
- command.setValue(new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath)
+ val executorPath = new File(executorSparkHome, "/sbin/spark-executor").getCanonicalPath
+ command.setValue("%s %s".format(prefixEnv, executorPath))
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
- command.setValue("cd %s*; ./sbin/spark-executor".format(basename))
+ command.setValue("cd %s*; %s ./sbin/spark-executor".format(basename, prefixEnv))
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
val cpus = Resource.newBuilder()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index 58b78f041cd85..c0264836de738 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer
import akka.actor.{Actor, ActorRef, Props}
-import org.apache.spark.{Logging, SparkEnv, TaskState}
+import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState}
import org.apache.spark.TaskState.TaskState
import org.apache.spark.executor.{Executor, ExecutorBackend}
import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer}
@@ -47,7 +47,7 @@ private[spark] class LocalActor(
private var freeCores = totalCores
- private val localExecutorId = "localhost"
+ private val localExecutorId = SparkContext.DRIVER_IDENTIFIER
private val localExecutorHostname = "localhost"
val executor = new Executor(
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 554a33ce7f1a6..662a7b91248aa 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -117,11 +117,11 @@ class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable {
new JavaSerializerInstance(counterReset, classLoader)
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeInt(counterReset)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
counterReset = in.readInt()
}
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index d6386f8c06fff..621a951c27d07 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -53,7 +53,18 @@ class KryoSerializer(conf: SparkConf)
private val maxBufferSize = conf.getInt("spark.kryoserializer.buffer.max.mb", 64) * 1024 * 1024
private val referenceTracking = conf.getBoolean("spark.kryo.referenceTracking", true)
private val registrationRequired = conf.getBoolean("spark.kryo.registrationRequired", false)
- private val registrator = conf.getOption("spark.kryo.registrator")
+ private val userRegistrator = conf.getOption("spark.kryo.registrator")
+ private val classesToRegister = conf.get("spark.kryo.classesToRegister", "")
+ .split(',')
+ .filter(!_.isEmpty)
+ .map { className =>
+ try {
+ Class.forName(className)
+ } catch {
+ case e: Exception =>
+ throw new SparkException("Failed to load class to register with Kryo", e)
+ }
+ }
def newKryoOutput() = new KryoOutput(bufferSize, math.max(bufferSize, maxBufferSize))
@@ -80,22 +91,20 @@ class KryoSerializer(conf: SparkConf)
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
- // Allow the user to register their own classes by setting spark.kryo.registrator
- for (regCls <- registrator) {
- logDebug("Running user registrator: " + regCls)
- try {
- val reg = Class.forName(regCls, true, classLoader).newInstance()
- .asInstanceOf[KryoRegistrator]
-
- // Use the default classloader when calling the user registrator.
- Thread.currentThread.setContextClassLoader(classLoader)
- reg.registerClasses(kryo)
- } catch {
- case e: Exception =>
- throw new SparkException(s"Failed to invoke $regCls", e)
- } finally {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
+ try {
+ // Use the default classloader when calling the user registrator.
+ Thread.currentThread.setContextClassLoader(classLoader)
+ // Register classes given through spark.kryo.classesToRegister.
+ classesToRegister.foreach { clazz => kryo.register(clazz) }
+ // Allow the user to register their own classes by setting spark.kryo.registrator.
+ userRegistrator
+ .map(Class.forName(_, true, classLoader).newInstance().asInstanceOf[KryoRegistrator])
+ .foreach { reg => reg.registerClasses(kryo) }
+ } catch {
+ case e: Exception =>
+ throw new SparkException(s"Failed to register classes with Kryo", e)
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
}
// Register Chill's classes; we do this after our ranges and the user's own classes to let
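spark.kryo.classesToRegister gives a purely declarative alternative to writing a KryoRegistrator: each comma-separated class name is loaded and registered, and a load failure or registration error now fails fast with a SparkException. A hedged usage sketch (class names are hypothetical):

import org.apache.spark.SparkConf

// Hypothetical classes; this can be combined with spark.kryo.registrator if custom
// serializers are still needed.
val conf = new SparkConf()
  .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
  .set("spark.kryo.classesToRegister", "com.example.MyRecord,com.example.MyKey")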
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index a9144cdd97b8c..ca6e971d227fb 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -17,14 +17,14 @@
package org.apache.spark.serializer
-import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
+import java.io._
import java.nio.ByteBuffer
import scala.reflect.ClassTag
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
+import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator}
/**
* :: DeveloperApi ::
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 71c08e9d5a8c3..0c1b6f4defdb3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -19,6 +19,7 @@ package org.apache.spark.shuffle
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.{FetchFailed, TaskEndReason}
+import org.apache.spark.util.Utils
/**
* Failed to fetch a shuffle block. The executor catches this exception and propagates it
@@ -30,13 +31,11 @@ private[spark] class FetchFailedException(
bmAddress: BlockManagerId,
shuffleId: Int,
mapId: Int,
- reduceId: Int)
- extends Exception {
-
- override def getMessage: String =
- "Fetch failed: %s %d %d %d".format(bmAddress, shuffleId, mapId, reduceId)
+ reduceId: Int,
+ message: String)
+ extends Exception(message) {
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message)
}
/**
@@ -46,7 +45,4 @@ private[spark] class MetadataFetchFailedException(
shuffleId: Int,
reduceId: Int,
message: String)
- extends FetchFailedException(null, shuffleId, -1, reduceId) {
-
- override def getMessage: String = message
-}
+ extends FetchFailedException(null, shuffleId, -1, reduceId, message)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index 439981d232349..f03e8e4bf1b7e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -24,9 +24,9 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConversions._
-import org.apache.spark.{SparkEnv, SparkConf, Logging}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -62,7 +62,8 @@ private[spark] trait ShuffleWriterGroup {
* each block stored in each file. In order to find the location of a shuffle block, we search the
 * files within the ShuffleFileGroup associated with the block's reducer.
*/
-
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getHashBasedShuffleBlockData().
private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index 4ab34336d3f01..a48f0c9eceb5e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -20,8 +20,10 @@ package org.apache.spark.shuffle
import java.io._
import java.nio.ByteBuffer
+import com.google.common.io.ByteStreams
+
import org.apache.spark.SparkEnv
-import org.apache.spark.network.{ManagedBuffer, FileSegmentManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.storage._
/**
@@ -33,6 +35,8 @@ import org.apache.spark.storage._
* as the filename postfix for data file, and ".index" as the filename postfix for index file.
*
*/
+// Note: Changes to the format in this file should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
class IndexShuffleBlockManager extends ShuffleBlockManager {
@@ -101,7 +105,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
val in = new DataInputStream(new FileInputStream(indexFile))
try {
- in.skip(blockId.reduceId * 8)
+ ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
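
The switch from `InputStream.skip` to Guava's `ByteStreams.skipFully` matters because `skip` may advance fewer bytes than requested; a small sketch of the difference (the stream contents are illustrative):

```scala
import java.io.ByteArrayInputStream
import com.google.common.io.ByteStreams

// skip() may legally return fewer bytes than requested; skipFully() either
// advances the full amount or throws EOFException.
val in = new ByteArrayInputStream(new Array[Byte](32))
ByteStreams.skipFully(in, 16)  // guaranteed: the next read starts at offset 16
```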
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
index 63863cc0250a3..b521f0c7fc77e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockManager.scala
@@ -18,8 +18,7 @@
package org.apache.spark.shuffle
import java.nio.ByteBuffer
-
-import org.apache.spark.network.ManagedBuffer
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.storage.ShuffleBlockId
private[spark]
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
index b30e366d06006..292e48314ee10 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleReader.scala
@@ -24,6 +24,10 @@ private[spark] trait ShuffleReader[K, C] {
/** Read the combined key-values for this reduce task */
def read(): Iterator[Product2[K, C]]
- /** Close this reader */
- def stop(): Unit
+ /**
+ * Close this reader.
+ * TODO: Add this back when we make the ShuffleReader a developer API that others can implement
+ * (at which point this will likely be necessary).
+ */
+ // def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 6cf9305977a3c..0d5247f4176d4 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -19,12 +19,13 @@ package org.apache.spark.shuffle.hash
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.util.{Failure, Success, Try}
import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.CompletionIterator
+import org.apache.spark.util.{CompletionIterator, Utils}
private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
@@ -52,21 +53,22 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}
- def unpackBlock(blockPair: (BlockId, Option[Iterator[Any]])) : Iterator[T] = {
+ def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
- case Some(block) => {
+ case Success(block) => {
block.asInstanceOf[Iterator[T]]
}
- case None => {
+ case Failure(e) => {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId)
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId,
+ Utils.exceptionString(e))
case _ =>
throw new SparkException(
- "Failed to get block " + blockId + ", which is not a shuffle block")
+ "Failed to get block " + blockId + ", which is not a shuffle block", e)
}
}
}
@@ -74,7 +76,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
- SparkEnv.get.blockTransferService,
+ SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
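
For readers unfamiliar with `Try`, a generic sketch of the Success/Failure matching that `unpackBlock` now relies on (the helper name is illustrative, not Spark API):

```scala
import scala.util.{Failure, Success, Try}

// Sketch: turn a per-block Try into either the block's iterator or a rethrown,
// wrapped failure, mirroring unpackBlock above.
def unpack[T](result: Try[Iterator[T]]): Iterator[T] = result match {
  case Success(block) => block
  case Failure(e)     => throw new RuntimeException("failed to fetch block", e)
}
```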
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 88a5f1e5ddf58..5baf45db45c17 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -66,7 +66,4 @@ private[spark] class HashShuffleReader[K, C](
aggregatedIter
}
}
-
- /** Close this reader */
- override def stop(): Unit = ???
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 746ed33b54c00..183a30373b28c 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -107,7 +107,7 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
- MapStatus(blockManager.blockManagerId, sizes)
+ MapStatus(blockManager.shuffleServerId, sizes)
}
private def revertWrites(): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 927481b72cf4f..d75f9d7311fad 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -70,7 +70,7 @@ private[spark] class SortShuffleWriter[K, V, C](
val partitionLengths = sorter.writePartitionedFile(blockId, context, outputFile)
shuffleBlockManager.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
- mapStatus = MapStatus(blockManager.blockManagerId, partitionLengths)
+ mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
/** Close this writer, passing along whether the map completed */
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala b/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
deleted file mode 100644
index 5b6d086630834..0000000000000
--- a/core/src/main/scala/org/apache/spark/storage/BlockDataProvider.scala
+++ /dev/null
@@ -1,32 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.storage
-
-import java.nio.ByteBuffer
-
-
-/**
- * An interface for providing data for blocks.
- *
- * getBlockData returns either a FileSegment (for zero-copy send), or a ByteBuffer.
- *
- * Aside from unit tests, [[BlockManager]] is the main class that implements this.
- */
-private[spark] trait BlockDataProvider {
- def getBlockData(blockId: String): Either[FileSegment, ByteBuffer]
-}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index a83a3f468ae5f..1f012941c85ab 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -53,6 +53,8 @@ case class RDDBlockId(rddId: Int, splitIndex: Int) extends BlockId {
def name = "rdd_" + rddId + "_" + splitIndex
}
+// Format of the shuffle block ids (including data and index) should be kept in sync with
+// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getBlockData().
@DeveloperApi
case class ShuffleBlockId(shuffleId: Int, mapId: Int, reduceId: Int) extends BlockId {
def name = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
@@ -83,9 +85,14 @@ case class StreamBlockId(streamId: Int, uniqueId: Long) extends BlockId {
def name = "input-" + streamId + "-" + uniqueId
}
-/** Id associated with temporary data managed as blocks. Not serializable. */
-private[spark] case class TempBlockId(id: UUID) extends BlockId {
- def name = "temp_" + id
+/** Id associated with temporary local data managed as blocks. Not serializable. */
+private[spark] case class TempLocalBlockId(id: UUID) extends BlockId {
+ def name = "temp_local_" + id
+}
+
+/** Id associated with temporary shuffle data managed as blocks. Not serializable. */
+private[spark] case class TempShuffleBlockId(id: UUID) extends BlockId {
+ def name = "temp_shuffle_" + id
}
// Intended only for testing purposes
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3f5d06e1aeee7..5f5dd0dc1c63f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -17,14 +17,12 @@
package org.apache.spark.storage
-import java.io.{File, InputStream, OutputStream, BufferedOutputStream, ByteArrayOutputStream}
+import java.io.{BufferedOutputStream, ByteArrayOutputStream, File, InputStream, OutputStream}
import java.nio.{ByteBuffer, MappedByteBuffer}
-import scala.concurrent.ExecutionContext.Implicits.global
-
-import scala.collection.mutable
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.concurrent.{Await, Future}
+import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.Random
@@ -35,11 +33,16 @@ import org.apache.spark._
import org.apache.spark.executor._
import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
+import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.hash.HashShuffleManager
+import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.util._
-
private[spark] sealed trait BlockValues
private[spark] case class ByteBufferValues(buffer: ByteBuffer) extends BlockValues
private[spark] case class IteratorValues(iterator: Iterator[Any]) extends BlockValues
@@ -87,9 +90,38 @@ private[spark] class BlockManager(
new TachyonStore(this, tachyonBlockManager)
}
+ private[spark]
+ val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
+ private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337)
+ // Check that we're not using external shuffle service with consolidated shuffle files.
+ if (externalShuffleServiceEnabled
+ && conf.getBoolean("spark.shuffle.consolidateFiles", false)
+ && shuffleManager.isInstanceOf[HashShuffleManager]) {
+ throw new UnsupportedOperationException("Cannot use external shuffle service with consolidated"
+ + " shuffle files in hash-based shuffle. Please disable spark.shuffle.consolidateFiles or "
+ + " switch to sort-based shuffle.")
+ }
+
val blockManagerId = BlockManagerId(
executorId, blockTransferService.hostName, blockTransferService.port)
+ // Address of the server that serves this executor's shuffle files. This is either an external
+ // service, or just our own Executor's BlockManager.
+ private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
+ // Client to read other executors' shuffle files. This is either an external service, or just the
+ // standard BlockTransferService to directly connect to other Executors.
+ private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
+ val appId = conf.get("spark.app.id", "unknown-app-id")
+ new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ } else {
+ blockTransferService
+ }
+
// Whether to compress broadcast variables that are stored
private val compressBroadcast = conf.getBoolean("spark.broadcast.compress", true)
// Whether to compress shuffle output that are stored
@@ -145,10 +177,41 @@ private[spark] class BlockManager(
/**
* Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor.
+ * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
*/
private def initialize(): Unit = {
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
+
+ // Register Executors' configuration with the local shuffle service, if one should exist.
+ if (externalShuffleServiceEnabled && !blockManagerId.isDriver) {
+ registerWithExternalShuffleServer()
+ }
+ }
+
+ private def registerWithExternalShuffleServer() {
+ logInfo("Registering executor with local external shuffle service.")
+ val shuffleConfig = new ExecutorShuffleInfo(
+ diskBlockManager.localDirs.map(_.toString),
+ diskBlockManager.subDirsPerLocalDir,
+ shuffleManager.getClass.getName)
+
+ val MAX_ATTEMPTS = 3
+ val SLEEP_TIME_SECS = 5
+
+ for (i <- 1 to MAX_ATTEMPTS) {
+ try {
+ // Synchronous and will throw an exception if we cannot connect.
+ shuffleClient.asInstanceOf[ExternalShuffleClient].registerWithShuffleServer(
+ shuffleServerId.host, shuffleServerId.port, shuffleServerId.executorId, shuffleConfig)
+ return
+ } catch {
+ case e: Exception if i < MAX_ATTEMPTS =>
+ logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}"
+ + s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
+ Thread.sleep(SLEEP_TIME_SECS * 1000)
+ }
+ }
}
/**
@@ -212,21 +275,20 @@ private[spark] class BlockManager(
}
/**
- * Interface to get local block data.
- *
- * @return Some(buffer) if the block exists locally, and None if it doesn't.
+ * Interface to get local block data. Throws an exception if the block cannot be found or
+ * cannot be read successfully.
*/
- override def getBlockData(blockId: String): Option[ManagedBuffer] = {
- val bid = BlockId(blockId)
- if (bid.isShuffle) {
- Some(shuffleManager.shuffleBlockManager.getBlockData(bid.asInstanceOf[ShuffleBlockId]))
+ override def getBlockData(blockId: BlockId): ManagedBuffer = {
+ if (blockId.isShuffle) {
+ shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
} else {
- val blockBytesOpt = doGetLocal(bid, asBlockResult = false).asInstanceOf[Option[ByteBuffer]]
+ val blockBytesOpt = doGetLocal(blockId, asBlockResult = false)
+ .asInstanceOf[Option[ByteBuffer]]
if (blockBytesOpt.isDefined) {
val buffer = blockBytesOpt.get
- Some(new NioByteBufferManagedBuffer(buffer))
+ new NioManagedBuffer(buffer)
} else {
- None
+ throw new BlockNotFoundException(blockId.toString)
}
}
}
@@ -234,8 +296,8 @@ private[spark] class BlockManager(
/**
* Put the block locally, using the given storage level.
*/
- override def putBlockData(blockId: String, data: ManagedBuffer, level: StorageLevel): Unit = {
- putBytes(BlockId(blockId), data.nioByteBuffer(), level)
+ override def putBlockData(blockId: BlockId, data: ManagedBuffer, level: StorageLevel): Unit = {
+ putBytes(blockId, data.nioByteBuffer(), level)
}
/**
@@ -340,17 +402,6 @@ private[spark] class BlockManager(
locations
}
- /**
- * A short-circuited method to get blocks directly from disk. This is used for getting
- * shuffle blocks. It is safe to do so without a lock on block info since disk store
- * never deletes (recent) items.
- */
- def getLocalShuffleFromDisk(blockId: BlockId, serializer: Serializer): Option[Iterator[Any]] = {
- val buf = shuffleManager.shuffleBlockManager.getBlockData(blockId.asInstanceOf[ShuffleBlockId])
- val is = wrapForCompression(blockId, buf.inputStream())
- Some(serializer.newInstance().deserializeStream(is).asIterator)
- }
-
/**
* Get block from local block manager.
*/
@@ -520,7 +571,7 @@ private[spark] class BlockManager(
for (loc <- locations) {
logDebug(s"Getting remote block $blockId from $loc")
val data = blockTransferService.fetchBlockSync(
- loc.host, loc.port, blockId.toString).nioByteBuffer()
+ loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer()
if (data != null) {
if (asBlockResult) {
@@ -869,9 +920,9 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId.toString, new NioByteBufferManagedBuffer(data), tLevel)
- logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %f ms"
- .format((System.currentTimeMillis - onePeerStartTime)))
+ peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+ logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
+ .format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
peersForReplication -= peer
replicationFailed = false
@@ -1071,7 +1122,8 @@ private[spark] class BlockManager(
case _: ShuffleBlockId => compressShuffle
case _: BroadcastBlockId => compressBroadcast
case _: RDDBlockId => compressRdds
- case _: TempBlockId => compressShuffleSpill
+ case _: TempLocalBlockId => compressShuffleSpill
+ case _: TempShuffleBlockId => compressShuffle
case _ => false
}
}
@@ -1125,7 +1177,11 @@ private[spark] class BlockManager(
}
def stop(): Unit = {
- blockTransferService.stop()
+ blockTransferService.close()
+ if (shuffleClient ne blockTransferService) {
+ // Closing should be idempotent, but maybe not for the NioBlockTransferService.
+ shuffleClient.close()
+ }
diskBlockManager.stop()
actorSystem.stop(slaveActor)
blockInfo.clear()
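
A hedged configuration sketch showing the settings the BlockManager consults above when deciding whether to use the external shuffle service; the values are illustrative and match the defaults in the code.

```scala
import org.apache.spark.SparkConf

// Sketch: enable the external shuffle service on its default port. Note that the
// code above rejects spark.shuffle.consolidateFiles together with hash-based shuffle.
val conf = new SparkConf()
  .set("spark.shuffle.service.enabled", "true")
  .set("spark.shuffle.service.port", "7337")
  .set("spark.shuffle.manager", "sort")
```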
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
index 142285094342c..b177a59c721df 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerId.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import java.util.concurrent.ConcurrentHashMap
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.Utils
@@ -59,15 +60,15 @@ class BlockManagerId private (
def port: Int = port_
- def isDriver: Boolean = (executorId == "")
+ def isDriver: Boolean = { executorId == SparkContext.DRIVER_IDENTIFIER }
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeUTF(executorId_)
out.writeUTF(host_)
out.writeInt(port_)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
executorId_ = in.readUTF()
host_ = in.readUTF()
port_ = in.readInt()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 088f06e389d83..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -203,6 +206,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
}
}
listenerBus.post(SparkListenerBlockManagerRemoved(System.currentTimeMillis(), blockManagerId))
+ logInfo(s"Removing block manager $blockManagerId")
}
private def expireDeadHosts() {
@@ -327,20 +331,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
val time = System.currentTimeMillis()
if (!blockManagerInfo.contains(id)) {
blockManagerIdByExecutor.get(id.executorId) match {
- case Some(manager) =>
- // A block manager of the same executor already exists.
- // This should never happen. Let's just quit.
- logError("Got two different block manager registrations on " + id.executorId)
- System.exit(1)
+ case Some(oldId) =>
+ // A block manager of the same executor already exists, so remove it (assumed dead)
+ logError("Got two different block manager registrations on same executor - "
+ + s" will replace old one $oldId with new one $id")
+ removeExecutor(id.executorId)
case None =>
- blockManagerIdByExecutor(id.executorId) = id
}
-
- logInfo("Registering block manager %s with %s RAM".format(
- id.hostPort, Utils.bytesToString(maxMemSize)))
-
- blockManagerInfo(id) =
- new BlockManagerInfo(id, time, maxMemSize, slaveActor)
+ logInfo("Registering block manager %s with %s RAM, %s".format(
+ id.hostPort, Utils.bytesToString(maxMemSize), id))
+
+ blockManagerIdByExecutor(id.executorId) = id
+
+ blockManagerInfo(id) = new BlockManagerInfo(
+ id, System.currentTimeMillis(), maxMemSize, slaveActor)
}
listenerBus.post(SparkListenerBlockManagerAdded(time, id, maxMemSize))
}
@@ -411,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
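
The new lookup above chains several `Option`s in a for-comprehension; a generic sketch of that pattern (nothing Spark-specific):

```scala
// Sketch: the result is Some((host, port)) only if every intermediate lookup
// yields a value; any None short-circuits the whole comprehension.
def hostPort(host: Option[String], port: Option[Int]): Option[(String, Int)] =
  for (h <- host; p <- port) yield (h, p)
```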
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 3db5dd9774ae8..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -21,6 +21,8 @@ import java.io.{Externalizable, ObjectInput, ObjectOutput}
import akka.actor.ActorRef
+import org.apache.spark.util.Utils
+
private[spark] object BlockManagerMessages {
//////////////////////////////////////////////////////////////////////////////////
// Messages from the master to slaves.
@@ -65,7 +67,7 @@ private[spark] object BlockManagerMessages {
def this() = this(null, null, null, 0, 0, 0) // For deserialization only
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
blockManagerId.writeExternal(out)
out.writeUTF(blockId.name)
storageLevel.writeExternal(out)
@@ -74,7 +76,7 @@ private[spark] object BlockManagerMessages {
out.writeLong(tachyonSize)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
blockManagerId = BlockManagerId(in)
blockId = BlockId(in.readUTF())
storageLevel = StorageLevel(in)
@@ -90,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
index 9ef453605f4f1..81f5f2d31dbd8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockNotFoundException.scala
@@ -17,5 +17,4 @@
package org.apache.spark.storage
-
class BlockNotFoundException(blockId: String) extends Exception(s"Block $blockId not found")
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index a715594f198c2..58fba54710510 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -38,12 +38,13 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
extends Logging {
private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
- private val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
+ private[spark]
+ val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
/* Create one local directory for each path mentioned in spark.local.dir; then, inside this
* directory, create multiple subdirectories that we will hash files into, in order to avoid
* having really large inodes at the top level. */
- val localDirs: Array[File] = createLocalDirs(conf)
+ private[spark] val localDirs: Array[File] = createLocalDirs(conf)
if (localDirs.isEmpty) {
logError("Failed to create any local dir.")
System.exit(ExecutorExitCode.DISK_STORE_FAILED_TO_CREATE_DIR)
@@ -52,6 +53,9 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
addShutdownHook()
+ /** Looks up a file by hashing it into one of our local subdirectories. */
+ // This method should be kept in sync with
+ // org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getFile().
def getFile(filename: String): File = {
// Figure out which local directory it hashes to, and which subdirectory in that
val hash = Utils.nonNegativeHash(filename)
@@ -98,11 +102,20 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
getAllFiles().map(f => BlockId(f.getName))
}
- /** Produces a unique block id and File suitable for intermediate results. */
- def createTempBlock(): (TempBlockId, File) = {
- var blockId = new TempBlockId(UUID.randomUUID())
+ /** Produces a unique block id and File suitable for storing local intermediate results. */
+ def createTempLocalBlock(): (TempLocalBlockId, File) = {
+ var blockId = new TempLocalBlockId(UUID.randomUUID())
while (getFile(blockId).exists()) {
- blockId = new TempBlockId(UUID.randomUUID())
+ blockId = new TempLocalBlockId(UUID.randomUUID())
+ }
+ (blockId, getFile(blockId))
+ }
+
+ /** Produces a unique block id and File suitable for storing shuffled intermediate results. */
+ def createTempShuffleBlock(): (TempShuffleBlockId, File) = {
+ var blockId = new TempShuffleBlockId(UUID.randomUUID())
+ while (getFile(blockId).exists()) {
+ blockId = new TempShuffleBlockId(UUID.randomUUID())
}
(blockId, getFile(blockId))
}
@@ -140,7 +153,6 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def addShutdownHook() {
- localDirs.foreach(localDir => Utils.registerShutdownDeleteDir(localDir))
Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
override def run(): Unit = Utils.logUncaughtExceptions {
logDebug("Shutdown hook called")
@@ -151,13 +163,16 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
- localDirs.foreach { localDir =>
- if (localDir.isDirectory() && localDir.exists()) {
- try {
- if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
- } catch {
- case e: Exception =>
- logError(s"Exception while deleting local spark dir: $localDir", e)
+ // Only perform cleanup if an external service is not serving our shuffle files.
+ if (!blockManager.externalShuffleServiceEnabled) {
+ localDirs.foreach { localDir =>
+ if (localDir.isDirectory() && localDir.exists()) {
+ try {
+ if (!Utils.hasRootAsShutdownDeleteDir(localDir)) Utils.deleteRecursively(localDir)
+ } catch {
+ case e: Exception =>
+ logError(s"Exception while deleting local spark dir: $localDir", e)
+ }
}
}
}
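
A hedged sketch of the hashing scheme `getFile` documents above (a local dir chosen by hash, then a sub-directory within it); the helper below is purely illustrative and not the DiskBlockManager API.

```scala
import java.io.File

// Sketch: hash a block's file name to one of the local dirs, then to one of the
// subDirsPerLocalDir sub-directories inside it.
def hashedPath(filename: String, localDirs: Array[File], subDirsPerLocalDir: Int): File = {
  val hash = filename.hashCode & Integer.MAX_VALUE       // non-negative hash
  val dirId = hash % localDirs.length
  val subDirId = (hash / localDirs.length) % subDirsPerLocalDir
  new File(new File(localDirs(dirId), "%02x".format(subDirId)), filename)
}
```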
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
index bac459e835a3f..8dadf6794039e 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala
@@ -17,7 +17,7 @@
package org.apache.spark.storage
-import java.io.{File, FileOutputStream, RandomAccessFile}
+import java.io.{IOException, File, FileOutputStream, RandomAccessFile}
import java.nio.ByteBuffer
import java.nio.channels.FileChannel.MapMode
@@ -110,7 +110,13 @@ private[spark] class DiskStore(blockManager: BlockManager, diskManager: DiskBloc
// For small files, directly read rather than memory map
if (length < minMemoryMapBytes) {
val buf = ByteBuffer.allocate(length.toInt)
- channel.read(buf, offset)
+ channel.position(offset)
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException("Reached EOF before filling buffer\n" +
+ s"offset=$offset\nfile=${file.getAbsolutePath}\nbuf.remaining=${buf.remaining}")
+ }
+ }
buf.flip()
Some(buf)
} else {
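
The DiskStore fix above replaces a single positional `channel.read(buf, offset)` call, which may return a short read, with a loop; a stand-alone sketch of the same idea:

```scala
import java.io.IOException
import java.nio.ByteBuffer
import java.nio.channels.FileChannel

// Sketch: read exactly `length` bytes starting at `offset`, or fail loudly.
def readFully(channel: FileChannel, offset: Long, length: Int): ByteBuffer = {
  val buf = ByteBuffer.allocate(length)
  channel.position(offset)
  while (buf.remaining() != 0) {
    if (channel.read(buf) == -1) {
      throw new IOException(s"Reached EOF with ${buf.remaining()} bytes left to read")
    }
  }
  buf.flip()
  buf
}
```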
diff --git a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
index edbc729c17ade..71305a46bf570 100644
--- a/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/MemoryStore.scala
@@ -56,6 +56,16 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
(maxMemory * unrollFraction).toLong
}
+ // Initial memory to request before unrolling any block
+ private val unrollMemoryThreshold: Long =
+ conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+
+ if (maxMemory < unrollMemoryThreshold) {
+ logWarning(s"Max memory ${Utils.bytesToString(maxMemory)} is less than the initial memory " +
+ s"threshold ${Utils.bytesToString(unrollMemoryThreshold)} needed to store a block in " +
+ s"memory. Please configure Spark with more memory.")
+ }
+
logInfo("MemoryStore started with capacity %s".format(Utils.bytesToString(maxMemory)))
/** Free memory not occupied by existing blocks. Note that this does not include unroll memory. */
@@ -213,7 +223,7 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Whether there is still enough memory for us to continue unrolling this block
var keepUnrolling = true
// Initial per-thread memory to request for unrolling blocks (bytes). Exposed for testing.
- val initialMemoryThreshold = conf.getLong("spark.storage.unrollMemoryThreshold", 1024 * 1024)
+ val initialMemoryThreshold = unrollMemoryThreshold
// How often to check whether we need to request more memory
val memoryCheckPeriod = 16
// Memory currently reserved by this thread for this particular unrolling operation
@@ -228,6 +238,11 @@ private[spark] class MemoryStore(blockManager: BlockManager, maxMemory: Long)
// Request enough memory to begin unrolling
keepUnrolling = reserveUnrollMemoryForThisThread(initialMemoryThreshold)
+ if (!keepUnrolling) {
+ logWarning(s"Failed to reserve initial memory threshold of " +
+ s"${Utils.bytesToString(initialMemoryThreshold)} for computing block $blockId in memory.")
+ }
+
// Unroll this block safely, checking whether we have exceeded our threshold periodically
try {
while (values.hasNext && keepUnrolling) {
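
A brief hedged sketch of tuning the unroll threshold introduced above; the value is illustrative.

```scala
import org.apache.spark.SparkConf

// Sketch: raise the initial per-block unroll reservation from the 1 MB default.
val conf = new SparkConf().set("spark.storage.unrollMemoryThreshold", (4 * 1024 * 1024).toString)
```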
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 71b276b5f18e4..1e579187e4193 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -19,15 +19,15 @@ package org.apache.spark.storage
import java.util.concurrent.LinkedBlockingQueue
-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashSet
-import scala.collection.mutable.Queue
+import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.util.{Failure, Success, Try}
-import org.apache.spark.{TaskContext, Logging}
-import org.apache.spark.network.{ManagedBuffer, BlockFetchingListener, BlockTransferService}
+import org.apache.spark.{Logging, TaskContext}
+import org.apache.spark.network.BlockTransferService
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.Serializer
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{CompletionIterator, Utils}
/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
@@ -40,8 +40,8 @@ import org.apache.spark.util.Utils
* using too much memory.
*
* @param context [[TaskContext]], used for metrics update
- * @param blockTransferService [[BlockTransferService]] for fetching remote blocks
- * @param blockManager [[BlockManager]] for reading local blocks
+ * @param shuffleClient [[ShuffleClient]] for fetching remote blocks
+ * @param blockManager [[BlockManager]] for reading local blocks
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
@@ -51,12 +51,12 @@ import org.apache.spark.util.Utils
private[spark]
final class ShuffleBlockFetcherIterator(
context: TaskContext,
- blockTransferService: BlockTransferService,
+ shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
- extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
+ extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
import ShuffleBlockFetcherIterator._
@@ -88,17 +88,53 @@ final class ShuffleBlockFetcherIterator(
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]
- // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
- // the number of bytes in flight is limited to maxBytesInFlight
+ /**
+ * Current [[FetchResult]] being processed. We track this so we can release the current buffer
+ * in case of a runtime exception when processing the current buffer.
+ */
+ private[this] var currentResult: FetchResult = null
+
+ /**
+ * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+ * the number of bytes in flight is limited to maxBytesInFlight.
+ */
private[this] val fetchRequests = new Queue[FetchRequest]
- // Current bytes in flight from our requests
+ /** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L
private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+ /**
+ * Whether the iterator is still active. If isZombie is true, the callback interface will no
+ * longer place fetched blocks into [[results]].
+ */
+ @volatile private[this] var isZombie = false
+
initialize()
+ /**
+ * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+ */
+ private[this] def cleanup() {
+ isZombie = true
+ // Release the current buffer if necessary
+ currentResult match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+
+ // Release buffers in the results queue
+ val iter = results.iterator()
+ while (iter.hasNext) {
+ val result = iter.next()
+ result match {
+ case SuccessFetchResult(_, _, buf) => buf.release()
+ case _ =>
+ }
+ }
+ }
+
private[this] def sendRequest(req: FetchRequest) {
logDebug("Sending request for %d blocks (%s) from %s".format(
req.blocks.size, Utils.bytesToString(req.size), req.address.hostPort))
@@ -108,26 +144,26 @@ final class ShuffleBlockFetcherIterator(
val sizeMap = req.blocks.map { case (blockId, size) => (blockId.toString, size) }.toMap
val blockIds = req.blocks.map(_._1.toString)
- blockTransferService.fetchBlocks(req.address.host, req.address.port, blockIds,
+ val address = req.address
+ shuffleClient.fetchBlocks(address.host, address.port, address.executorId, blockIds.toArray,
new BlockFetchingListener {
- override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = {
- results.put(new FetchResult(BlockId(blockId), sizeMap(blockId),
- () => serializer.newInstance().deserializeStream(
- blockManager.wrapForCompression(BlockId(blockId), data.inputStream())).asIterator
- ))
- shuffleMetrics.remoteBytesRead += data.size
- shuffleMetrics.remoteBlocksFetched += 1
- logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+ override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): Unit = {
+ // Only add the buffer to results queue if the iterator is not zombie,
+ // i.e. cleanup() has not been called yet.
+ if (!isZombie) {
+ // Increment the ref count because we need to pass this to a different thread.
+ // This needs to be released after use.
+ buf.retain()
+ results.put(new SuccessFetchResult(BlockId(blockId), sizeMap(blockId), buf))
+ shuffleMetrics.remoteBytesRead += buf.size
+ shuffleMetrics.remoteBlocksFetched += 1
+ }
+ logTrace("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
}
- override def onBlockFetchFailure(e: Throwable): Unit = {
+ override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
logError(s"Failed to get block(s) from ${req.address.host}:${req.address.port}", e)
- // Note that there is a chance that some blocks have been fetched successfully, but we
- // still add them to the failed queue. This is fine because when the caller see a
- // FetchFailedException, it is going to fail the entire task anyway.
- for ((blockId, size) <- req.blocks) {
- results.put(new FetchResult(blockId, -1, null))
- }
+ results.put(new FailureFetchResult(BlockId(blockId), e))
}
}
)
@@ -138,7 +174,7 @@ final class ShuffleBlockFetcherIterator(
// smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
// nodes, rather than blocking on reading output from one node.
val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
- logInfo("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
+ logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " + targetRequestSize)
// Split local and remote blocks. Remote blocks are further split into FetchRequests of size
// at most maxBytesInFlight in order to limit the amount of data in flight.
@@ -148,7 +184,7 @@ final class ShuffleBlockFetcherIterator(
var totalBlocks = 0
for ((address, blockInfos) <- blocksByAddress) {
totalBlocks += blockInfos.size
- if (address == blockManager.blockManagerId) {
+ if (address.executorId == blockManager.blockManagerId.executorId) {
// Filter out zero-sized blocks
localBlocks ++= blockInfos.filter(_._2 != 0).map(_._1)
numBlocksToFetch += localBlocks.size
@@ -185,26 +221,34 @@ final class ShuffleBlockFetcherIterator(
remoteRequests
}
+ /**
+ * Fetch the local blocks while we are fetching remote blocks. This is okay because
+ * [[ManagedBuffer]]'s memory is allocated lazily when we create the input stream, so all we
+ * track in memory are the ManagedBuffer references themselves.
+ */
private[this] def fetchLocalBlocks() {
- // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
- // these all at once because they will just memory-map some files, so they won't consume
- // any memory that might exceed our maxBytesInFlight
- for (id <- localBlocks) {
+ val iter = localBlocks.iterator
+ while (iter.hasNext) {
+ val blockId = iter.next()
try {
+ val buf = blockManager.getBlockData(blockId)
shuffleMetrics.localBlocksFetched += 1
- results.put(new FetchResult(
- id, 0, () => blockManager.getLocalShuffleFromDisk(id, serializer).get))
- logDebug("Got local block " + id)
+ buf.retain()
+ results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
case e: Exception =>
+ // If we see an exception, stop immediately.
logError(s"Error occurred while fetching local blocks", e)
- results.put(new FetchResult(id, -1, null))
+ results.put(new FailureFetchResult(blockId, e))
return
}
}
}
private[this] def initialize(): Unit = {
+ // Add a task completion callback (called in both success case and failure case) to cleanup.
+ context.addTaskCompletionListener(_ => cleanup())
+
// Split local and remote blocks.
val remoteRequests = splitLocalRemoteBlocks()
// Add the remote requests into our queue in a random order
@@ -226,21 +270,39 @@ final class ShuffleBlockFetcherIterator(
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
- override def next(): (BlockId, Option[Iterator[Any]]) = {
+ override def next(): (BlockId, Try[Iterator[Any]]) = {
numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
- val result = results.take()
+ currentResult = results.take()
+ val result = currentResult
val stopFetchWait = System.currentTimeMillis()
shuffleMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
- if (!result.failed) {
- bytesInFlight -= result.size
+
+ result match {
+ case SuccessFetchResult(_, size, _) => bytesInFlight -= size
+ case _ =>
}
// Send fetch requests up to maxBytesInFlight
while (fetchRequests.nonEmpty &&
(bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
sendRequest(fetchRequests.dequeue())
}
- (result.blockId, if (result.failed) None else Some(result.deserialize()))
+
+ val iteratorTry: Try[Iterator[Any]] = result match {
+ case FailureFetchResult(_, e) => Failure(e)
+ case SuccessFetchResult(blockId, _, buf) => {
+ val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ Success(CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ }))
+ }
+ }
+
+ (result.blockId, iteratorTry)
}
}
@@ -254,18 +316,35 @@ object ShuffleBlockFetcherIterator {
* @param blocks Sequence of tuple, where the first element is the block id,
* and the second element is the estimated size, used to calculate bytesInFlight.
*/
- class FetchRequest(val address: BlockManagerId, val blocks: Seq[(BlockId, Long)]) {
+ case class FetchRequest(address: BlockManagerId, blocks: Seq[(BlockId, Long)]) {
val size = blocks.map(_._2).sum
}
/**
- * Result of a fetch from a remote block. A failure is represented as size == -1.
+ * Result of a fetch from a remote block.
+ */
+ private[storage] sealed trait FetchResult {
+ val blockId: BlockId
+ }
+
+ /**
+ * Result of a successful fetch of a remote block.
* @param blockId block id
* @param size estimated size of the block, used to calculate bytesInFlight.
* Note that this is NOT the exact bytes.
- * @param deserialize closure to return the result in the form of an Iterator.
+ * @param buf [[ManagedBuffer]] for the content.
*/
- class FetchResult(val blockId: BlockId, val size: Long, val deserialize: () => Iterator[Any]) {
- def failed: Boolean = size == -1
+ private[storage] case class SuccessFetchResult(blockId: BlockId, size: Long, buf: ManagedBuffer)
+ extends FetchResult {
+ require(buf != null)
+ require(size >= 0)
}
+
+ /**
+ * Result of a failed fetch of a remote block.
+ * @param blockId block id
+ * @param e the failure exception
+ */
+ private[storage] case class FailureFetchResult(blockId: BlockId, e: Throwable)
+ extends FetchResult
}
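
The fetcher now wraps each deserialized stream so the underlying buffer is released exactly once when the iterator is exhausted; a generic sketch of that wrapper pattern (not Spark's `CompletionIterator` itself):

```scala
// Sketch: run `cleanup` once, the first time hasNext observes exhaustion.
def onCompletion[A](it: Iterator[A])(cleanup: => Unit): Iterator[A] = new Iterator[A] {
  private var completed = false
  def hasNext: Boolean = {
    val more = it.hasNext
    if (!more && !completed) { completed = true; cleanup }
    more
  }
  def next(): A = it.next()
}
```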
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 1e35abaab5353..56edc4fe2e4ad 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput}
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.util.Utils
/**
* :: DeveloperApi ::
@@ -97,12 +98,12 @@ class StorageLevel private(
ret
}
- override def writeExternal(out: ObjectOutput) {
+ override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
out.writeByte(toInt)
out.writeByte(_replication)
}
- override def readExternal(in: ObjectInput) {
+ override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val flags = in.readByte()
_useDisk = (flags & 8) != 0
_useMemory = (flags & 4) != 0
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
index d9066f766476e..def49e80a3605 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala
@@ -19,6 +19,7 @@ package org.apache.spark.storage
import scala.collection.mutable
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.scheduler._
@@ -59,10 +60,9 @@ class StorageStatusListener extends SparkListener {
val info = taskEnd.taskInfo
val metrics = taskEnd.taskMetrics
if (info != null && metrics != null) {
- val execId = formatExecutorId(info.executorId)
val updatedBlocks = metrics.updatedBlocks.getOrElse(Seq[(BlockId, BlockStatus)]())
if (updatedBlocks.length > 0) {
- updateStorageStatus(execId, updatedBlocks)
+ updateStorageStatus(info.executorId, updatedBlocks)
}
}
}
@@ -88,13 +88,4 @@ class StorageStatusListener extends SparkListener {
}
}
- /**
- * In the local mode, there is a discrepancy between the executor ID according to the
- * task ("localhost") and that according to SparkEnv (""). In the UI, this
- * results in duplicate rows for the same executor. Thus, in this mode, we aggregate
- * these two rows and use the executor ID of "" to be consistent.
- */
- def formatExecutorId(execId: String): String = {
- if (execId == "localhost") "" else execId
- }
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
index 932b5616043b4..6dbad5ff0518e 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -20,6 +20,7 @@ package org.apache.spark.storage
import java.io.IOException
import java.nio.ByteBuffer
+import com.google.common.io.ByteStreams
import tachyon.client.{ReadType, WriteType}
import org.apache.spark.Logging
@@ -105,25 +106,17 @@ private[spark] class TachyonStore(
return None
}
val is = file.getInStream(ReadType.CACHE)
- var buffer: ByteBuffer = null
+ assert(is != null)
try {
- if (is != null) {
- val size = file.length
- val bs = new Array[Byte](size.asInstanceOf[Int])
- val fetchSize = is.read(bs, 0, size.asInstanceOf[Int])
- buffer = ByteBuffer.wrap(bs)
- if (fetchSize != size) {
- logWarning(s"Failed to fetch the block $blockId from Tachyon: Size $size " +
- s"is not equal to fetched size $fetchSize")
- return None
- }
- }
+ val size = file.length
+ val bs = new Array[Byte](size.asInstanceOf[Int])
+ ByteStreams.readFully(is, bs)
+ Some(ByteBuffer.wrap(bs))
} catch {
case ioe: IOException =>
logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe)
- return None
+ None
}
- Some(buffer)
}
override def contains(blockId: BlockId): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index cccd59d122a92..049938f827291 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -21,47 +21,30 @@ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext}
import org.apache.spark.scheduler._
import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.env.EnvironmentTab
-import org.apache.spark.ui.exec.ExecutorsTab
-import org.apache.spark.ui.jobs.JobProgressTab
-import org.apache.spark.ui.storage.StorageTab
+import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab}
+import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab}
+import org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab}
+import org.apache.spark.ui.storage.{StorageListener, StorageTab}
/**
* Top level user interface for a Spark application.
*/
-private[spark] class SparkUI(
- val sc: SparkContext,
+private[spark] class SparkUI private (
+ val sc: Option[SparkContext],
val conf: SparkConf,
val securityManager: SecurityManager,
- val listenerBus: SparkListenerBus,
+ val environmentListener: EnvironmentListener,
+ val storageStatusListener: StorageStatusListener,
+ val executorsListener: ExecutorsListener,
+ val jobProgressListener: JobProgressListener,
+ val storageListener: StorageListener,
var appName: String,
- val basePath: String = "")
+ val basePath: String)
extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
with Logging {
- def this(sc: SparkContext) = this(sc, sc.conf, sc.env.securityManager, sc.listenerBus, sc.appName)
- def this(conf: SparkConf, listenerBus: SparkListenerBus, appName: String, basePath: String) =
- this(null, conf, new SecurityManager(conf), listenerBus, appName, basePath)
-
- def this(
- conf: SparkConf,
- securityManager: SecurityManager,
- listenerBus: SparkListenerBus,
- appName: String,
- basePath: String) =
- this(null, conf, securityManager, listenerBus, appName, basePath)
-
- // If SparkContext is not provided, assume the associated application is not live
- val live = sc != null
-
- // Maintain executor storage status through Spark events
- val storageStatusListener = new StorageStatusListener
-
- initialize()
-
/** Initialize all components of the server. */
def initialize() {
- listenerBus.addListener(storageStatusListener)
val jobProgressTab = new JobProgressTab(this)
attachTab(jobProgressTab)
attachTab(new StorageTab(this))
@@ -71,10 +54,10 @@ private[spark] class SparkUI(
attachHandler(createRedirectHandler("/", "/stages", basePath = basePath))
attachHandler(
createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest))
- if (live) {
- sc.env.metricsSystem.getServletHandlers.foreach(attachHandler)
- }
+ // If the UI is live, attach the metrics servlet handlers to the web UI.
+ sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) }
}
+ initialize()
def getAppName = appName
@@ -83,11 +66,6 @@ private[spark] class SparkUI(
appName = name
}
- /** Register the given listener with the listener bus. */
- def registerListener(listener: SparkListener) {
- listenerBus.addListener(listener)
- }
-
/** Stop the server behind this web interface. Only valid after bind(). */
override def stop() {
super.stop()
@@ -116,4 +94,60 @@ private[spark] object SparkUI {
def getUIPort(conf: SparkConf): Int = {
conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT)
}
+
+ def createLiveUI(
+ sc: SparkContext,
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ jobProgressListener: JobProgressListener,
+ securityManager: SecurityManager,
+ appName: String): SparkUI = {
+ create(Some(sc), conf, listenerBus, securityManager, appName,
+ jobProgressListener = Some(jobProgressListener))
+ }
+
+ def createHistoryUI(
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String): SparkUI = {
+ create(None, conf, listenerBus, securityManager, appName, basePath)
+ }
+
+ /**
+ * Create a new Spark UI.
+ *
+ * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs.
+ * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the
+ * web UI will create and register its own JobProgressListener.
+ */
+ private def create(
+ sc: Option[SparkContext],
+ conf: SparkConf,
+ listenerBus: SparkListenerBus,
+ securityManager: SecurityManager,
+ appName: String,
+ basePath: String = "",
+ jobProgressListener: Option[JobProgressListener] = None): SparkUI = {
+
+ val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse {
+ val listener = new JobProgressListener(conf)
+ listenerBus.addListener(listener)
+ listener
+ }
+
+ val environmentListener = new EnvironmentListener
+ val storageStatusListener = new StorageStatusListener
+ val executorsListener = new ExecutorsListener(storageStatusListener)
+ val storageListener = new StorageListener(storageStatusListener)
+
+ listenerBus.addListener(environmentListener)
+ listenerBus.addListener(storageStatusListener)
+ listenerBus.addListener(executorsListener)
+ listenerBus.addListener(storageListener)
+
+ new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener,
+ executorsListener, _jobProgressListener, storageListener, appName, basePath)
+ }
}
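The SparkUI constructor is now private, so construction goes through the factory methods above. As a rough sketch of the intended call pattern (not part of this patch): a history-style caller supplies its own listener bus and base path, while a live caller also passes the SparkContext and an existing JobProgressListener. The LiveListenerBus stand-in, app name, and base path below are illustrative, and these are private[spark] APIs, so the sketch only compiles inside Spark itself.

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.scheduler.LiveListenerBus
    import org.apache.spark.ui.SparkUI

    val conf = new SparkConf()
    val securityManager = new SecurityManager(conf)
    val listenerBus = new LiveListenerBus   // stand-in for whatever bus replays or receives events
    val ui = SparkUI.createHistoryUI(conf, listenerBus, securityManager,
      appName = "replayed-app", basePath = "/history/app-1")
    ui.bind()   // the listeners registered by create() populate the pages as events arrive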
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 9ced9b8107ebf..f02904df31fcf 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -31,4 +31,16 @@ private[spark] object ToolTips {
val SHUFFLE_READ =
"""Bytes read from remote executors. Typically less than shuffle write bytes
because this does not include shuffle data read locally."""
+
+ val GETTING_RESULT_TIME =
+ """Time that the driver spends fetching task results from workers. If this is large, consider
+ decreasing the amount of data returned from each task."""
+
+ val RESULT_SERIALIZATION_TIME =
+ """Time spent serializing the task result on the executor before sending it back to the
+ driver."""
+
+ val GC_TIME =
+ """Time that the executor spent paused for Java garbage collection while the task was
+ running."""
}
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 32e6b15bb0999..3312671b6f885 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -20,13 +20,13 @@ package org.apache.spark.ui
import java.text.SimpleDateFormat
import java.util.{Locale, Date}
-import scala.xml.Node
+import scala.xml.{Node, Text}
import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS = "table table-bordered table-striped table-condensed sortable"
+ val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
@@ -160,6 +160,8 @@ private[spark] object UIUtils extends Logging {
+
+
}
/** Returns a spark page with correctly formatted headers */
@@ -239,7 +241,9 @@ private[spark] object UIUtils extends Logging {
headers: Seq[String],
generateDataRow: T => Seq[Node],
data: Iterable[T],
- fixedWidth: Boolean = false): Seq[Node] = {
+ fixedWidth: Boolean = false,
+ id: Option[String] = None,
+ headerClasses: Seq[String] = Seq.empty): Seq[Node] = {
var listingTableClass = TABLE_CLASS
if (fixedWidth) {
@@ -247,23 +251,32 @@ private[spark] object UIUtils extends Logging {
}
val colWidth = 100.toDouble / headers.size
val colWidthAttr = if (fixedWidth) colWidth + "%" else ""
- val headerRow: Seq[Node] = {
- // if none of the headers have "\n" in them
- if (headers.forall(!_.contains("\n"))) {
- // represent header as simple text
- headers.map(h =>
{h}
)
+
+ def getClass(index: Int): String = {
+ if (index < headerClasses.size) {
+ headerClasses(index)
} else {
- // represent header text as list while respecting "\n"
- headers.map { case h =>
-
{headerRow}
{data.map(r => generateDataRow(r))}
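For context on the two new listingTable parameters, here is a hedged sketch of a call that uses them; the row type, DOM id, and header classes are made up for illustration (sorttable_numeric is just an example of a per-column class), and UIUtils is private[spark], so this only shows the shape of the call.

    import scala.xml.Node
    import org.apache.spark.ui.UIUtils

    case class Row(name: String, value: Long)
    def renderRow(r: Row): Seq[Node] = <tr><td>{r.name}</td><td>{r.value}</td></tr>
    val table = UIUtils.listingTable(
      headers = Seq("Name", "Value"),
      generateDataRow = renderRow _,
      data = Seq(Row("a", 1L), Row("b", 2L)),
      fixedWidth = true,
      id = Some("example-table"),                   // becomes the table's DOM id
      headerClasses = Seq("", "sorttable_numeric")) // per-column header classes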
diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
index 5d88ca403a674..9be65a4a39a09 100644
--- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala
@@ -82,7 +82,7 @@ private[spark] abstract class WebUI(
}
/** Detach a handler from this UI. */
- def detachHandler(handler: ServletContextHandler) {
+ protected def detachHandler(handler: ServletContextHandler) {
handlers -= handler
serverInfo.foreach { info =>
info.rootHandler.removeHandler(handler)
diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
index 0d158fbe638d3..f62260c6f6e1d 100644
--- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala
@@ -22,10 +22,8 @@ import org.apache.spark.scheduler._
import org.apache.spark.ui._
private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "environment") {
- val listener = new EnvironmentListener
-
+ val listener = parent.environmentListener
attachPage(new EnvironmentPage(this))
- parent.registerListener(listener)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..e9c755e36f716
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
++
failedStagesTable.toNodeSeq
UIUtils.headerSparkPage("Spark Stages", content, parent)
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
index c16542c9db30f..03ca918e2e8b3 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressTab.scala
@@ -25,16 +25,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
/** Web UI showing progress status of all jobs in the given SparkContext. */
private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "stages") {
- val live = parent.live
val sc = parent.sc
- val conf = if (live) sc.conf else new SparkConf
- val killEnabled = conf.getBoolean("spark.ui.killEnabled", true)
- val listener = new JobProgressListener(conf)
+ val conf = sc.map(_.conf).getOrElse(new SparkConf)
+ val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false)
+ val listener = parent.jobProgressListener
attachPage(new JobProgressPage(this))
attachPage(new StagePage(this))
attachPage(new PoolPage(this))
- parent.registerListener(listener)
def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
@@ -43,7 +41,7 @@ private[ui] class JobProgressTab(parent: SparkUI) extends SparkUITab(parent, "st
val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
val stageId = Option(request.getParameter("id")).getOrElse("-1").toInt
if (stageId >= 0 && killFlag && listener.activeStages.contains(stageId)) {
- sc.cancelStage(stageId)
+ sc.get.cancelStage(stageId)
}
// Do a quick pause here to give Spark time to kill the stage so it shows up as
// killed after the refresh. Note that this will block the serving thread so the
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 7a6c7d1a497ed..770d99eea1c9d 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -26,7 +26,6 @@ import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing specific pool details */
private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
- private val live = parent.live
private val sc = parent.sc
private val listener = parent.listener
@@ -42,7 +41,7 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent)
// For now, pool information is only accessible in live UIs
- val pools = if (live) Seq(sc.getPoolForName(poolName).get) else Seq[Schedulable]()
+ val pools = sc.map(_.getPoolForName(poolName).get).toSeq
val poolTable = new PoolTable(pools, parent)
val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 2414e4c65237e..7cc03b7d333df 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,10 +22,11 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData._
import org.apache.spark.util.{Utils, Distribution}
-import org.apache.spark.scheduler.AccumulableInfo
+import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
/** Page showing statistics and task list for a given stage */
private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
@@ -52,12 +53,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val numCompleted = tasks.count(_.taskInfo.finished)
val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
+ val hasAccumulators = accumulables.size > 0
val hasInput = stageData.inputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0
- // scalastyle:off
val summary =
@@ -65,55 +66,105 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
Total task time across all tasks:
{UIUtils.formatDuration(stageData.executorRunTime)}
- {if (hasInput)
+ {if (hasInput) {
+:
+ getFormattedTimeQuantiles(gettingResultTimes)
// The scheduler delay includes the network delay to send the task to the worker
// machine and to send back the result (but not the time to fetch the task result,
// if it needed to be fetched from the block manager on the worker).
val schedulerDelays = validTasks.map { case TaskUIData(info, metrics, _) =>
- val totalExecutionTime = {
- if (info.gettingResultTime > 0) {
- (info.gettingResultTime - info.launchTime).toDouble
- } else {
- (info.finishTime - info.launchTime).toDouble
- }
- }
- totalExecutionTime - metrics.get.executorRunTime
+ getSchedulerDelay(info, metrics.get).toDouble
}
val schedulerDelayTitle =
{inputReadable}
@@ -333,4 +415,15 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
}
}
+
+ private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
+ val totalExecutionTime = {
+ if (info.gettingResultTime > 0) {
+ (info.gettingResultTime - info.launchTime)
+ } else {
+ (info.finishTime - info.launchTime)
+ }
+ }
+ totalExecutionTime - metrics.executorRunTime
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
new file mode 100644
index 0000000000000..23d672cabda07
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/TaskDetailsClassNames.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+/**
+ * Names of the CSS classes corresponding to each type of task detail. Used to allow users
+ * to optionally show/hide columns.
+ */
+private object TaskDetailsClassNames {
+ val SCHEDULER_DELAY = "scheduler_delay"
+ val GC_TIME = "gc_time"
+ val RESULT_SERIALIZATION_TIME = "serialization_time"
+ val GETTING_RESULT_TIME = "getting_result_time"
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
index a336bf7e1ed02..e2813f8eb5ab9 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala
@@ -17,6 +17,7 @@
package org.apache.spark.ui.jobs
+import org.apache.spark.JobExecutionStatus
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
import org.apache.spark.util.collection.OpenHashSet
@@ -36,6 +37,13 @@ private[jobs] object UIData {
var diskBytesSpilled : Long = 0
}
+ class JobUIData(
+ var jobId: Int = -1,
+ var stageIds: Seq[Int] = Seq.empty,
+ var jobGroup: Option[String] = None,
+ var status: JobExecutionStatus = JobExecutionStatus.UNKNOWN
+ )
+
class StageUIData {
var numActiveTasks: Int = _
var numCompleteTasks: Int = _
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
index 8a0075ae8daf7..12d23a92878cf 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala
@@ -39,7 +39,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
// Worker table
val workers = storageStatusList.map((rddId, _))
- val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers)
+ val workerTable = UIUtils.listingTable(workerHeader, workerRow, workers,
+ id = Some("rdd-storage-by-worker-table"))
// Block table
val blockLocations = StorageUtils.getRddBlockLocations(rddId, storageStatusList)
@@ -49,7 +50,8 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") {
.map { case (blockId, status) =>
(blockId, status, blockLocations.get(blockId).getOrElse(Seq[String]("Unknown")))
}
- val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks)
+ val blockTable = UIUtils.listingTable(blockHeader, blockRow, blocks,
+ id = Some("rdd-storage-by-block-table"))
val content =
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
index 83489ca0679ee..6ced6052d2b18 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StoragePage.scala
@@ -31,7 +31,7 @@ private[ui] class StoragePage(parent: StorageTab) extends WebUIPage("") {
def render(request: HttpServletRequest): Seq[Node] = {
val rdds = listener.rddInfoList
- val content = UIUtils.listingTable(rddHeader, rddRow, rdds)
+ val content = UIUtils.listingTable(rddHeader, rddRow, rdds, id = Some("storage-by-rdd-table"))
UIUtils.headerSparkPage("Storage", content, parent)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
index 76097f1c51f8e..a81291d505583 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala
@@ -26,11 +26,10 @@ import org.apache.spark.storage._
/** Web UI showing storage status of all RDD's in the given SparkContext. */
private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") {
- val listener = new StorageListener(parent.storageStatusListener)
+ val listener = parent.storageListener
attachPage(new StoragePage(this))
attachPage(new RDDPage(this))
- parent.registerListener(listener)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index f41c8d0315cb3..10010bdfa1a51 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -159,17 +159,28 @@ private[spark] object AkkaUtils extends Logging {
def askWithReply[T](
message: Any,
actor: ActorRef,
- retryAttempts: Int,
+ timeout: FiniteDuration): T = {
+ askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout)
+ }
+
+ /**
+   * Send a message to the given actor and get its result within the specified timeout, or
+ * throw a SparkException if this fails even after the specified number of retries.
+ */
+ def askWithReply[T](
+ message: Any,
+ actor: ActorRef,
+ maxAttempts: Int,
retryInterval: Int,
timeout: FiniteDuration): T = {
// TODO: Consider removing multiple attempts
if (actor == null) {
- throw new SparkException("Error sending message as driverActor is null " +
+ throw new SparkException("Error sending message as actor is null " +
"[message = " + message + "]")
}
var attempts = 0
var lastException: Exception = null
- while (attempts < retryAttempts) {
+ while (attempts < maxAttempts) {
attempts += 1
try {
val future = actor.ask(message)(timeout)
@@ -201,4 +212,18 @@ private[spark] object AkkaUtils extends Logging {
logInfo(s"Connecting to $name: $url")
Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
}
+
+ def makeExecutorRef(
+ name: String,
+ conf: SparkConf,
+ host: String,
+ port: Int,
+ actorSystem: ActorSystem): ActorRef = {
+ val executorActorSystemName = SparkEnv.executorActorSystemName
+ Utils.checkHost(host, "Expected hostname")
+ val url = s"akka.tcp://$executorActorSystemName@$host:$port/user/$name"
+ val timeout = AkkaUtils.lookupTimeout(conf)
+ logInfo(s"Connecting to $name: $url")
+ Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout)
+ }
}
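A brief sketch of the two askWithReply overloads above; the actor reference and message are placeholders, and askTimeout is the existing helper in this file used to pick the timeout.

    import akka.actor.ActorRef
    import org.apache.spark.SparkConf
    import org.apache.spark.util.AkkaUtils

    def queryStatus(driverActor: ActorRef, conf: SparkConf): String = {
      val timeout = AkkaUtils.askTimeout(conf)
      // Single attempt with the new short overload:
      AkkaUtils.askWithReply[String]("status", driverActor, timeout)
      // Or, to retry: askWithReply[String]("status", driverActor, 3, 3000, timeout)
      // for up to 3 attempts with 3000 ms between them.
    }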
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index 5b2e7d3a7edb9..f7ae1f7f334de 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -272,14 +272,15 @@ private[spark] object JsonProtocol {
def taskEndReasonToJson(taskEndReason: TaskEndReason): JValue = {
val reason = Utils.getFormattedClassName(taskEndReason)
- val json = taskEndReason match {
+ val json: JObject = taskEndReason match {
case fetchFailed: FetchFailed =>
val blockManagerAddress = Option(fetchFailed.bmAddress).
map(blockManagerIdToJson).getOrElse(JNothing)
("Block Manager Address" -> blockManagerAddress) ~
("Shuffle ID" -> fetchFailed.shuffleId) ~
("Map ID" -> fetchFailed.mapId) ~
- ("Reduce ID" -> fetchFailed.reduceId)
+ ("Reduce ID" -> fetchFailed.reduceId) ~
+ ("Message" -> fetchFailed.message)
case exceptionFailure: ExceptionFailure =>
val stackTrace = stackTraceToJson(exceptionFailure.stackTrace)
val metrics = exceptionFailure.metrics.map(taskMetricsToJson).getOrElse(JNothing)
@@ -287,6 +288,8 @@ private[spark] object JsonProtocol {
("Description" -> exceptionFailure.description) ~
("Stack Trace" -> stackTrace) ~
("Metrics" -> metrics)
+ case ExecutorLostFailure(executorId) =>
+ ("Executor ID" -> executorId)
case _ => Utils.emptyJson
}
("Reason" -> reason) ~ json
@@ -627,7 +630,9 @@ private[spark] object JsonProtocol {
val shuffleId = (json \ "Shuffle ID").extract[Int]
val mapId = (json \ "Map ID").extract[Int]
val reduceId = (json \ "Reduce ID").extract[Int]
- new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId)
+ val message = Utils.jsonOption(json \ "Message").map(_.extract[String])
+ new FetchFailed(blockManagerAddress, shuffleId, mapId, reduceId,
+ message.getOrElse("Unknown reason"))
case `exceptionFailure` =>
val className = (json \ "Class Name").extract[String]
val description = (json \ "Description").extract[String]
@@ -636,7 +641,9 @@ private[spark] object JsonProtocol {
new ExceptionFailure(className, description, stackTrace, metrics)
case `taskResultLost` => TaskResultLost
case `taskKilled` => TaskKilled
- case `executorLostFailure` => ExecutorLostFailure
+ case `executorLostFailure` =>
+ val executorId = Utils.jsonOption(json \ "Executor ID").map(_.extract[String])
+ ExecutorLostFailure(executorId.getOrElse("Unknown"))
case `unknownReason` => UnknownReason
}
}
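A quick sketch of the round trip these JsonProtocol changes enable; the executor id is illustrative. With ExecutorLostFailure now carrying an executor id (and FetchFailed a message), the reason survives serialization to the event log and back.

    import org.apache.spark.ExecutorLostFailure
    import org.apache.spark.util.JsonProtocol

    val reason = ExecutorLostFailure("exec-1")
    val json = JsonProtocol.taskEndReasonToJson(reason)
    assert(JsonProtocol.taskEndReasonFromJson(json) == reason)  // "Executor ID" is preserved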
diff --git a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
index 2b452ad33b021..770ff9d5ad6ae 100644
--- a/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/SerializableBuffer.scala
@@ -29,7 +29,7 @@ private[spark]
class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable {
def value = buffer
- private def readObject(in: ObjectInputStream) {
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
val length = in.readInt()
buffer = ByteBuffer.allocate(length)
var amountRead = 0
@@ -44,7 +44,7 @@ class SerializableBuffer(@transient var buffer: ByteBuffer) extends Serializable
buffer.rewind() // Allow us to read it later
}
- private def writeObject(out: ObjectOutputStream) {
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
out.writeInt(buffer.limit())
if (Channels.newChannel(out).write(buffer) != buffer.limit()) {
throw new IOException("Could not fully write buffer to output stream")
diff --git a/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala b/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala
new file mode 100644
index 0000000000000..c93b1cca9f564
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkExitCode.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+private[spark] object SparkExitCode {
+ /** The default uncaught exception handler was reached. */
+ val UNCAUGHT_EXCEPTION = 50
+
+ /** The default uncaught exception handler was called and an exception was encountered while
+ logging the exception. */
+ val UNCAUGHT_EXCEPTION_TWICE = 51
+
+ /** The default uncaught exception handler was reached, and the uncaught exception was an
+ OutOfMemoryError. */
+ val OOM = 52
+
+}
diff --git a/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
new file mode 100644
index 0000000000000..ad3db1fbb57ed
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/SparkUncaughtExceptionHandler.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import org.apache.spark.Logging
+
+/**
+ * The default uncaught exception handler for Executors terminates the whole process, to avoid
+ * getting into a bad state indefinitely. Since Executors are relatively lightweight, it's better
+ * to fail fast when things go wrong.
+ */
+private[spark] object SparkUncaughtExceptionHandler
+ extends Thread.UncaughtExceptionHandler with Logging {
+
+ override def uncaughtException(thread: Thread, exception: Throwable) {
+ try {
+ logError("Uncaught exception in thread " + thread, exception)
+
+ // We may have been called from a shutdown hook. If so, we must not call System.exit().
+ // (If we do, we will deadlock.)
+ if (!Utils.inShutdown()) {
+ if (exception.isInstanceOf[OutOfMemoryError]) {
+ System.exit(SparkExitCode.OOM)
+ } else {
+ System.exit(SparkExitCode.UNCAUGHT_EXCEPTION)
+ }
+ }
+ } catch {
+ case oom: OutOfMemoryError => Runtime.getRuntime.halt(SparkExitCode.OOM)
+ case t: Throwable => Runtime.getRuntime.halt(SparkExitCode.UNCAUGHT_EXCEPTION_TWICE)
+ }
+ }
+
+ def uncaughtException(exception: Throwable) {
+ uncaughtException(Thread.currentThread(), exception)
+ }
+}
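A minimal sketch of how a JVM entry point might install the handler defined above; the object name is hypothetical, and Thread.setDefaultUncaughtExceptionHandler is standard JDK API.

    import org.apache.spark.util.SparkUncaughtExceptionHandler

    object IllustrativeWorker {
      def main(args: Array[String]): Unit = {
        Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler)
        // Any thread that later dies with an uncaught exception now logs it and exits the
        // JVM with SparkExitCode.UNCAUGHT_EXCEPTION (or OOM), unless a shutdown is in progress.
      }
    }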
diff --git a/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
new file mode 100644
index 0000000000000..d4e0ad93b966a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/ThreadStackTrace.scala
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+/**
+ * Used for shipping per-thread stack traces from the executors to the driver.
+ */
+private[spark] case class ThreadStackTrace(
+ threadId: Long,
+ threadName: String,
+ threadState: Thread.State,
+ stackTrace: String)
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 53a7512edd852..6ab94af9f3739 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -18,12 +18,13 @@
package org.apache.spark.util
import java.io._
+import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
+import java.util.jar.Attributes.Name
import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent.{ThreadFactory, ConcurrentHashMap, Executors, ThreadPoolExecutor}
-
-import org.eclipse.jetty.util.MultiException
+import java.util.jar.{Manifest => JarManifest}
import scala.collection.JavaConversions._
import scala.collection.Map
@@ -33,17 +34,17 @@ import scala.reflect.ClassTag
import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
-import com.google.common.io.Files
+import com.google.common.io.{ByteStreams, Files}
import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.commons.lang3.SystemUtils
import org.apache.hadoop.conf.Configuration
import org.apache.log4j.PropertyConfigurator
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
+import org.eclipse.jetty.util.MultiException
import org.json4s._
import tachyon.client.{TachyonFile,TachyonFS}
import org.apache.spark._
-import org.apache.spark.executor.ExecutorUncaughtExceptionHandler
import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance}
/** CallSite represents a place in user code. It can have a short and a long form. */
@@ -269,23 +270,44 @@ private[spark] object Utils extends Logging {
dir
}
- /** Copy all data from an InputStream to an OutputStream */
+  /** Copy all data from an InputStream to an OutputStream. NIO-based copying between file
+    * streams is disabled by default; it is used only when transferToEnabled is explicitly set
+    * to true, which callers should wire to the spark.file.transferTo setting ([true|false]).
+ */
def copyStream(in: InputStream,
out: OutputStream,
- closeStreams: Boolean = false): Long =
+ closeStreams: Boolean = false,
+ transferToEnabled: Boolean = false): Long =
{
var count = 0L
try {
- if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]) {
+ if (in.isInstanceOf[FileInputStream] && out.isInstanceOf[FileOutputStream]
+ && transferToEnabled) {
// When both streams are File stream, use transferTo to improve copy performance.
val inChannel = in.asInstanceOf[FileInputStream].getChannel()
val outChannel = out.asInstanceOf[FileOutputStream].getChannel()
+ val initialPos = outChannel.position()
val size = inChannel.size()
// In case transferTo method transferred less data than we have required.
while (count < size) {
count += inChannel.transferTo(count, size - count, outChannel)
}
+
+      // Check the position after the transferTo loop to verify that it ends up where expected,
+      // and inform the user if it does not.
+      // On kernel version 2.6.32 the position is not advanced to the expected length after
+      // calling transferTo; see https://bugs.openjdk.java.net/browse/JDK-7052359.
+      // This leads to stream corruption when using sort-based shuffle (SPARK-3948).
+ val finalPos = outChannel.position()
+ assert(finalPos == initialPos + size,
+ s"""
+          |Current position $finalPos does not equal the expected position ${initialPos + size}
+          |after transferTo. Please check whether your kernel version is 2.6.32, which has a
+          |kernel bug that leads to unexpected behavior when using transferTo.
+ |You can set spark.file.transferTo = false to disable this NIO feature.
+ """.stripMargin)
} else {
val buf = new Array[Byte](8192)
var n = 0
@@ -326,15 +348,84 @@ private[spark] object Utils extends Logging {
}
/**
- * Download a file requested by the executor. Supports fetching the file in a variety of ways,
+ * Download a file to target directory. Supports fetching the file in a variety of ways,
* including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
*
+ * If `useCache` is true, first attempts to fetch the file to a local cache that's shared
+ * across executors running the same application. `useCache` is used mainly for
+ * the executors, and not in local mode.
+ *
* Throws SparkException if the target file already exists and has different contents than
* the requested file.
*/
- def fetchFile(url: String, targetDir: File, conf: SparkConf, securityMgr: SecurityManager,
- hadoopConf: Configuration) {
- val filename = url.split("/").last
+ def fetchFile(
+ url: String,
+ targetDir: File,
+ conf: SparkConf,
+ securityMgr: SecurityManager,
+ hadoopConf: Configuration,
+ timestamp: Long,
+ useCache: Boolean) {
+ val fileName = url.split("/").last
+ val targetFile = new File(targetDir, fileName)
+ if (useCache) {
+ val cachedFileName = s"${url.hashCode}${timestamp}_cache"
+ val lockFileName = s"${url.hashCode}${timestamp}_lock"
+ val localDir = new File(getLocalDir(conf))
+ val lockFile = new File(localDir, lockFileName)
+ val raf = new RandomAccessFile(lockFile, "rw")
+      // Only one executor downloads the file at a time.
+      // The FileLock is only used to synchronize executors fetching the same file,
+      // so it is always safe regardless of lock type (mandatory or advisory).
+ val lock = raf.getChannel().lock()
+ val cachedFile = new File(localDir, cachedFileName)
+ try {
+ if (!cachedFile.exists()) {
+ doFetchFile(url, localDir, cachedFileName, conf, securityMgr, hadoopConf)
+ }
+ } finally {
+ lock.release()
+ }
+ if (targetFile.exists && !Files.equal(cachedFile, targetFile)) {
+ if (conf.getBoolean("spark.files.overwrite", false)) {
+ targetFile.delete()
+ logInfo((s"File $targetFile exists and does not match contents of $url, " +
+ s"replacing it with $url"))
+ } else {
+ throw new SparkException(s"File $targetFile exists and does not match contents of $url")
+ }
+ }
+ Files.copy(cachedFile, targetFile)
+ } else {
+ doFetchFile(url, targetDir, fileName, conf, securityMgr, hadoopConf)
+ }
+
+ // Decompress the file if it's a .tar or .tar.gz
+ if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) {
+ logInfo("Untarring " + fileName)
+ Utils.execute(Seq("tar", "-xzf", fileName), targetDir)
+ } else if (fileName.endsWith(".tar")) {
+ logInfo("Untarring " + fileName)
+ Utils.execute(Seq("tar", "-xf", fileName), targetDir)
+ }
+ // Make the file executable - That's necessary for scripts
+ FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
+ }
+
+ /**
+ * Download a file to target directory. Supports fetching the file in a variety of ways,
+ * including HTTP, HDFS and files on a standard filesystem, based on the URL parameter.
+ *
+ * Throws SparkException if the target file already exists and has different contents than
+ * the requested file.
+ */
+ private def doFetchFile(
+ url: String,
+ targetDir: File,
+ filename: String,
+ conf: SparkConf,
+ securityMgr: SecurityManager,
+ hadoopConf: Configuration) {
val tempDir = getLocalDir(conf)
val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir))
val targetFile = new File(targetDir, filename)
@@ -422,16 +513,6 @@ private[spark] object Utils extends Logging {
}
Files.move(tempFile, targetFile)
}
- // Decompress the file if it's a .tar or .tar.gz
- if (filename.endsWith(".tar.gz") || filename.endsWith(".tgz")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xzf", filename), targetDir)
- } else if (filename.endsWith(".tar")) {
- logInfo("Untarring " + filename)
- Utils.execute(Seq("tar", "-xf", filename), targetDir)
- }
- // Make the file executable - That's necessary for scripts
- FileUtil.chmod(targetFile.getAbsolutePath, "a+x")
}
/**
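A short sketch of the new copyStream flag; the file paths are illustrative. transferTo-based copying only kicks in when both streams are file streams and the caller passes transferToEnabled = true, which is expected to be wired to spark.file.transferTo.

    import java.io.{FileInputStream, FileOutputStream}
    import org.apache.spark.SparkConf
    import org.apache.spark.util.Utils

    val conf = new SparkConf()
    val in = new FileInputStream("/tmp/part-00000")
    val out = new FileOutputStream("/tmp/merged", true)  // append
    val copied = Utils.copyStream(in, out, closeStreams = true,
      transferToEnabled = conf.getBoolean("spark.file.transferTo", true))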
@@ -659,11 +740,15 @@ private[spark] object Utils extends Logging {
}
private def listFilesSafely(file: File): Seq[File] = {
- val files = file.listFiles()
- if (files == null) {
- throw new IOException("Failed to list files for dir: " + file)
+ if (file.exists()) {
+ val files = file.listFiles()
+ if (files == null) {
+ throw new IOException("Failed to list files for dir: " + file)
+ }
+ files
+ } else {
+ List()
}
- files
}
/**
@@ -727,7 +812,7 @@ private[spark] object Utils extends Logging {
/**
* Determines if a directory contains any files newer than cutoff seconds.
- *
+ *
* @param dir must be the path to a directory, or IllegalArgumentException is thrown
* @param cutoff measured in seconds. Returns true if there are any files or directories in the
* given directory whose last modified time is later than this many seconds ago
@@ -885,7 +970,37 @@ private[spark] object Utils extends Logging {
block
} catch {
case e: ControlThrowable => throw e
- case t: Throwable => ExecutorUncaughtExceptionHandler.uncaughtException(t)
+ case t: Throwable => SparkUncaughtExceptionHandler.uncaughtException(t)
+ }
+ }
+
+ /**
+ * Execute a block of code that evaluates to Unit, re-throwing any non-fatal uncaught
+ * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+ * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+ * see SPARK-4080 for more context.
+ */
+ def tryOrIOException(block: => Unit) {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
+ }
+ }
+
+ /**
+ * Execute a block of code that returns a value, re-throwing any non-fatal uncaught
+ * exceptions as IOException. This is used when implementing Externalizable and Serializable's
+ * read and write methods, since Java's serializer will not report non-IOExceptions properly;
+ * see SPARK-4080 for more context.
+ */
+ def tryOrIOException[T](block: => T): T = {
+ try {
+ block
+ } catch {
+ case e: IOException => throw e
+ case NonFatal(t) => throw new IOException(t)
}
}
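A minimal sketch of the intended use of tryOrIOException in custom Java-serialization hooks, mirroring the SerializableBuffer change earlier in this patch; the class itself is hypothetical.

    import java.io.{ObjectInputStream, ObjectOutputStream}
    import org.apache.spark.util.Utils

    class CachedBlob(@transient var payload: Array[Byte]) extends Serializable {
      private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
        out.writeInt(payload.length)
        out.write(payload)
      }
      private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
        payload = new Array[Byte](in.readInt())
        in.readFully(payload)
      }
    }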
@@ -893,7 +1008,8 @@ private[spark] object Utils extends Logging {
private def coreExclusionFunction(className: String): Boolean = {
// A regular expression to match classes of the "core" Spark API that we want to skip when
// finding the call site of a method.
- val SPARK_CORE_CLASS_REGEX = """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?\.[A-Z]""".r
+ val SPARK_CORE_CLASS_REGEX =
+ """^org\.apache\.spark(\.api\.java)?(\.util)?(\.rdd)?(\.broadcast)?\.[A-Z]""".r
val SCALA_CLASS_REGEX = """^scala""".r
val isSparkCoreClass = SPARK_CORE_CLASS_REGEX.findFirstIn(className).isDefined
val isScalaClass = SCALA_CLASS_REGEX.findFirstIn(className).isDefined
@@ -962,8 +1078,8 @@ private[spark] object Utils extends Logging {
val stream = new FileInputStream(file)
try {
- stream.skip(effectiveStart)
- stream.read(buff)
+ ByteStreams.skipFully(stream, effectiveStart)
+ ByteStreams.readFully(stream, buff)
} finally {
stream.close()
}
@@ -1124,6 +1240,8 @@ private[spark] object Utils extends Logging {
}
// Handles idiosyncracies with hash (add more as required)
+ // This method should be kept in sync with
+ // org.apache.spark.network.util.JavaUtils#nonNegativeHash().
def nonNegativeHash(obj: AnyRef): Int = {
// Required ?
@@ -1157,12 +1275,28 @@ private[spark] object Utils extends Logging {
/**
* Timing method based on iterations that permit JVM JIT optimization.
* @param numIters number of iterations
- * @param f function to be executed
+ * @param f function to be executed. If prepare is not None, the running time of each call to f
+ * must be an order of magnitude longer than one millisecond for accurate timing.
+ * @param prepare function to be executed before each call to f. Its running time doesn't count.
+   * @return the total time across all iterations (not counting preparation time)
*/
- def timeIt(numIters: Int)(f: => Unit): Long = {
- val start = System.currentTimeMillis
- times(numIters)(f)
- System.currentTimeMillis - start
+ def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
+ if (prepare.isEmpty) {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ } else {
+ var i = 0
+ var sum = 0L
+ while (i < numIters) {
+ prepare.get.apply()
+ val start = System.currentTimeMillis
+ f
+ sum += System.currentTimeMillis - start
+ i += 1
+ }
+ sum
+ }
}
/**
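A small sketch of timing with a per-iteration prepare step, which the new signature excludes from the total; the workload below is arbitrary.

    import org.apache.spark.util.Utils

    var buf: Array[Int] = null
    val totalMs = Utils.timeIt(10)(
      java.util.Arrays.sort(buf),
      prepare = Some(() => buf = Array.fill(1000000)(scala.util.Random.nextInt())))
    // totalMs counts only the sort calls; regenerating the array each iteration is excluded.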
@@ -1251,6 +1385,11 @@ private[spark] object Utils extends Logging {
*/
val isWindows = SystemUtils.IS_OS_WINDOWS
+ /**
+ * Whether the underlying operating system is Mac OS X.
+ */
+ val isMac = SystemUtils.IS_OS_MAC_OSX
+
/**
* Pattern for matching a Windows drive, which contains only a single alphabet character.
*/
@@ -1459,7 +1598,7 @@ private[spark] object Utils extends Logging {
}
/** Return a nice string representation of the exception, including the stack trace. */
- def exceptionString(e: Exception): String = {
+ def exceptionString(e: Throwable): String = {
if (e == null) "" else exceptionString(getFormattedClassName(e), e.getMessage, e.getStackTrace)
}
@@ -1473,6 +1612,18 @@ private[spark] object Utils extends Logging {
s"$className: $desc\n$st"
}
+ /** Return a thread dump of all threads' stacktraces. Used to capture dumps for the web UI */
+ def getThreadDump(): Array[ThreadStackTrace] = {
+ // We need to filter out null values here because dumpAllThreads() may return null array
+ // elements for threads that are dead / don't exist.
+ val threadInfos = ManagementFactory.getThreadMXBean.dumpAllThreads(true, true).filter(_ != null)
+ threadInfos.sortBy(_.getThreadId).map { case threadInfo =>
+ val stackTrace = threadInfo.getStackTrace.map(_.toString).mkString("\n")
+ ThreadStackTrace(threadInfo.getThreadId, threadInfo.getThreadName,
+ threadInfo.getThreadState, stackTrace)
+ }
+ }
+
/**
* Convert all spark properties set in the given SparkConf to a sequence of java options.
*/
@@ -1573,6 +1724,62 @@ private[spark] object Utils extends Logging {
PropertyConfigurator.configure(pro)
}
+ def invoke(
+ clazz: Class[_],
+ obj: AnyRef,
+ methodName: String,
+ args: (Class[_], AnyRef)*): AnyRef = {
+ val (types, values) = args.unzip
+ val method = clazz.getDeclaredMethod(methodName, types: _*)
+ method.setAccessible(true)
+ method.invoke(obj, values.toSeq: _*)
+ }
+
+ // Limit of bytes for total size of results (default is 1GB)
+ def getMaxResultSize(conf: SparkConf): Long = {
+ memoryStringToMb(conf.get("spark.driver.maxResultSize", "1g")).toLong << 20
+ }
+
+ /**
+   * Return the name of the environment variable that holds the library path on the
+   * current OS: LD_LIBRARY_PATH on Linux, DYLD_LIBRARY_PATH on Mac OS X, PATH on Windows.
+ */
+ def libraryPathEnvName: String = {
+ if (isWindows) {
+ "PATH"
+ } else if (isMac) {
+ "DYLD_LIBRARY_PATH"
+ } else {
+ "LD_LIBRARY_PATH"
+ }
+ }
+
+ /**
+ * Return the prefix of a command that appends the given library paths to the
+ * system-specific library path environment variable. On Unix, for instance,
+ * this returns the string LD_LIBRARY_PATH="path1:path2:$LD_LIBRARY_PATH".
+ */
+ def libraryPathEnvPrefix(libraryPaths: Seq[String]): String = {
+ val libraryPathScriptVar = if (isWindows) {
+ s"%${libraryPathEnvName}%"
+ } else {
+ "$" + libraryPathEnvName
+ }
+ val libraryPath = (libraryPaths :+ libraryPathScriptVar).mkString("\"",
+ File.pathSeparator, "\"")
+ val ampersand = if (Utils.isWindows) {
+ " &"
+ } else {
+ ""
+ }
+ s"$libraryPathEnvName=$libraryPath$ampersand"
+ }
+
+ lazy val sparkVersion =
+ SparkContext.jarOfObject(this).map { path =>
+ val manifestUrl = new URL(s"jar:file:$path!/META-INF/MANIFEST.MF")
+ val manifest = new JarManifest(manifestUrl.openStream())
+ manifest.getMainAttributes.getValue(Name.IMPLEMENTATION_VERSION)
+ }.getOrElse("Unknown")
}
/**
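A brief sketch of the thread-dump helper added above, roughly as the executors tab consumes it; the formatting is illustrative.

    import org.apache.spark.util.Utils

    val traces = Utils.getThreadDump()  // Array[ThreadStackTrace], sorted by thread id
    traces.foreach { t =>
      println(s"Thread ${t.threadId} '${t.threadName}' (${t.threadState})")
      println(t.stackTrace)
    }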
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 0c088da46aa5e..26fa0cb6d7bde 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -153,7 +153,7 @@ class ExternalAppendOnlyMap[K, V, C](
* Sort the existing contents of the in-memory map and spill them to a temporary file on disk.
*/
override protected[this] def spill(collection: SizeTracker): Unit = {
- val (blockId, file) = diskBlockManager.createTempBlock()
+ val (blockId, file) = diskBlockManager.createTempLocalBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
curWriteMetrics)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 644fa36818647..c1ce13683b569 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -38,6 +38,11 @@ import org.apache.spark.storage.{BlockObjectWriter, BlockId}
*
* If combining is disabled, the type C must equal V -- we'll cast the objects at the end.
*
+ * Note: Although ExternalSorter is a fairly generic sorter, some of its configuration is tied
+ * to its use in sort-based shuffle (for example, its block compression is controlled by
+ * `spark.shuffle.compress`). We may need to revisit this if ExternalSorter is used in other
+ * non-shuffle contexts where we might want to use different configuration settings.
+ *
* @param aggregator optional Aggregator with combine functions to use for merging data
* @param partitioner optional Partitioner; if given, sort by partition ID and then key
* @param ordering optional Ordering to sort keys within each partition; should be a total ordering
@@ -93,6 +98,7 @@ private[spark] class ExternalSorter[K, V, C](
private val conf = SparkEnv.get.conf
private val spillingEnabled = conf.getBoolean("spark.shuffle.spill", true)
private val fileBufferSize = conf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
+ private val transferToEnabled = conf.getBoolean("spark.file.transferTo", true)
// Size of object batches when reading/writing from serializers.
//
@@ -258,7 +264,10 @@ private[spark] class ExternalSorter[K, V, C](
private def spillToMergeableFile(collection: SizeTrackingPairCollection[(Int, K), C]): Unit = {
assert(!bypassMergeSort)
- val (blockId, file) = diskBlockManager.createTempBlock()
+ // Because these files may be read during shuffle, their compression must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more context.
+ val (blockId, file) = diskBlockManager.createTempShuffleBlock()
curWriteMetrics = new ShuffleWriteMetrics()
var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
var objectsWritten = 0 // Objects written since the last flush
@@ -337,7 +346,10 @@ private[spark] class ExternalSorter[K, V, C](
if (partitionWriters == null) {
curWriteMetrics = new ShuffleWriteMetrics()
partitionWriters = Array.fill(numPartitions) {
- val (blockId, file) = diskBlockManager.createTempBlock()
+ // Because these files may be read during shuffle, their compression must be controlled by
+ // spark.shuffle.compress instead of spark.shuffle.spill.compress, so we need to use
+ // createTempShuffleBlock here; see SPARK-3426 for more context.
+ val (blockId, file) = diskBlockManager.createTempShuffleBlock()
blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics).open()
}
}
@@ -705,10 +717,10 @@ private[spark] class ExternalSorter[K, V, C](
var out: FileOutputStream = null
var in: FileInputStream = null
try {
- out = new FileOutputStream(outputFile)
+ out = new FileOutputStream(outputFile, true)
for (i <- 0 until numPartitions) {
in = new FileInputStream(partitionWriters(i).fileSegment().file)
- val size = org.apache.spark.util.Utils.copyStream(in, out, false)
+ val size = org.apache.spark.util.Utils.copyStream(in, out, false, transferToEnabled)
in.close()
in = null
lengths(i) = size
diff --git a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
index ac1528969f0be..4f0bf8384afc9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/SortDataFormat.scala
@@ -27,33 +27,51 @@ import scala.reflect.ClassTag
* Example format: an array of numbers, where each element is also the key.
* See [[KVArraySortDataFormat]] for a more exciting format.
*
- * This trait extends Any to ensure it is universal (and thus compiled to a Java interface).
+ * Note: Declaring and instantiating multiple subclasses of this class would prevent the JIT
+ * from inlining the overridden methods and hence decrease shuffle performance.
*
* @tparam K Type of the sort key of each element
* @tparam Buffer Internal data structure used by a particular format (e.g., Array[Int]).
*/
// TODO: Making Buffer a real trait would be a better abstraction, but adds some complexity.
-private[spark] trait SortDataFormat[K, Buffer] extends Any {
+private[spark]
+abstract class SortDataFormat[K, Buffer] {
+
+ /**
+ * Creates a new mutable key for reuse. This should be implemented if you want to override
+ * [[getKey(Buffer, Int, K)]].
+ */
+ def newKey(): K = null.asInstanceOf[K]
+
/** Return the sort key for the element at the given index. */
protected def getKey(data: Buffer, pos: Int): K
+ /**
+   * Returns the sort key for the element at the given index, reusing the input key if possible.
+   * The default implementation ignores the reuse parameter and invokes [[getKey(Buffer, Int)]].
+ * If you want to override this method, you must implement [[newKey()]].
+ */
+ def getKey(data: Buffer, pos: Int, reuse: K): K = {
+ getKey(data, pos)
+ }
+
/** Swap two elements. */
- protected def swap(data: Buffer, pos0: Int, pos1: Int): Unit
+ def swap(data: Buffer, pos0: Int, pos1: Int): Unit
/** Copy a single element from src(srcPos) to dst(dstPos). */
- protected def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
+ def copyElement(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int): Unit
/**
* Copy a range of elements starting at src(srcPos) to dst, starting at dstPos.
* Overlapping ranges are allowed.
*/
- protected def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
+ def copyRange(src: Buffer, srcPos: Int, dst: Buffer, dstPos: Int, length: Int): Unit
/**
* Allocates a Buffer that can hold up to 'length' elements.
* All elements of the buffer should be considered invalid until data is explicitly copied in.
*/
- protected def allocate(length: Int): Buffer
+ def allocate(length: Int): Buffer
}
/**
@@ -67,9 +85,9 @@ private[spark] trait SortDataFormat[K, Buffer] extends Any {
private[spark]
class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K, Array[T]] {
- override protected def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
+ override def getKey(data: Array[T], pos: Int): K = data(2 * pos).asInstanceOf[K]
- override protected def swap(data: Array[T], pos0: Int, pos1: Int) {
+ override def swap(data: Array[T], pos0: Int, pos1: Int) {
val tmpKey = data(2 * pos0)
val tmpVal = data(2 * pos0 + 1)
data(2 * pos0) = data(2 * pos1)
@@ -78,17 +96,16 @@ class KVArraySortDataFormat[K, T <: AnyRef : ClassTag] extends SortDataFormat[K,
data(2 * pos1 + 1) = tmpVal
}
- override protected def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
+ override def copyElement(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int) {
dst(2 * dstPos) = src(2 * srcPos)
dst(2 * dstPos + 1) = src(2 * srcPos + 1)
}
- override protected def copyRange(src: Array[T], srcPos: Int,
- dst: Array[T], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[T], srcPos: Int, dst: Array[T], dstPos: Int, length: Int) {
System.arraycopy(src, 2 * srcPos, dst, 2 * dstPos, 2 * length)
}
- override protected def allocate(length: Int): Array[T] = {
+ override def allocate(length: Int): Array[T] = {
new Array[T](2 * length)
}
}
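As a rough illustration of the new abstract-class shape (not part of this patch), here is a format over a flat Array[Long] where each element is its own key. Long keys are immutable, so the default newKey()/getKey(data, pos, reuse) behavior is kept; formats with mutable keys would override both. SortDataFormat is private[spark], so such a subclass would live inside Spark.

    import org.apache.spark.util.collection.SortDataFormat

    class LongArraySortDataFormat extends SortDataFormat[Long, Array[Long]] {
      override protected def getKey(data: Array[Long], pos: Int): Long = data(pos)
      override def swap(data: Array[Long], pos0: Int, pos1: Int): Unit = {
        val tmp = data(pos0); data(pos0) = data(pos1); data(pos1) = tmp
      }
      override def copyElement(src: Array[Long], srcPos: Int, dst: Array[Long], dstPos: Int): Unit = {
        dst(dstPos) = src(srcPos)
      }
      override def copyRange(src: Array[Long], srcPos: Int,
          dst: Array[Long], dstPos: Int, length: Int): Unit = {
        System.arraycopy(src, srcPos, dst, dstPos, length)
      }
      override def allocate(length: Int): Array[Long] = new Array[Long](length)
    }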
diff --git a/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
new file mode 100644
index 0000000000000..39f66b8c428c6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/collection/Sorter.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util.collection
+
+import java.util.Comparator
+
+/**
+ * A simple wrapper over the Java implementation [[TimSort]].
+ *
+ * The Java implementation is package private, and hence it cannot be called outside package
+ * org.apache.spark.util.collection. This is a simple wrapper that exposes it to the rest of Spark.
+ */
+private[spark]
+class Sorter[K, Buffer](private val s: SortDataFormat[K, Buffer]) {
+
+ private val timSort = new TimSort(s)
+
+ /**
+ * Sorts the input buffer within range [lo, hi).
+ */
+ def sort(a: Buffer, lo: Int, hi: Int, c: Comparator[_ >: K]): Unit = {
+ timSort.sort(a, lo, hi, c)
+ }
+}
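A hedged usage sketch of the wrapper with the existing KVArraySortDataFormat, where keys and values are interleaved in one flat array; the data below is made up and these classes are private[spark].

    import java.util.Comparator
    import org.apache.spark.util.collection.{KVArraySortDataFormat, Sorter}

    // Layout is [k0, v0, k1, v1, ...]; the sort range is expressed in records, not array slots.
    val kv = Array[AnyRef](3: Integer, "c", 1: Integer, "a", 2: Integer, "b")
    val sorter = new Sorter(new KVArraySortDataFormat[Integer, AnyRef])
    sorter.sort(kv, 0, kv.length / 2, new Comparator[Integer] {
      override def compare(a: Integer, b: Integer): Int = a.compareTo(b)
    })
    // kv is now Array(1, "a", 2, "b", 3, "c")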
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 32c5fdad75e58..76e7a2760bcd1 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -19,8 +19,10 @@ package org.apache.spark.util.random
import java.util.Random
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import scala.reflect.ClassTag
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.annotation.DeveloperApi
@@ -39,13 +41,47 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
/** take a random sample */
def sample(items: Iterator[T]): Iterator[U]
+ /** return a copy of the RandomSampler object */
override def clone: RandomSampler[T, U] =
throw new NotImplementedError("clone() is not implemented.")
}
+private[spark]
+object RandomSampler {
+ /** Default random number generator used by random samplers. */
+ def newDefaultRNG: Random = new XORShiftRandom
+
+ /**
+ * Default maximum gap-sampling fraction.
+ * For sampling fractions <= this value, the gap sampling optimization will be applied.
+   * Above this value, it is assumed that "traditional" Bernoulli sampling is faster.  The
+ * optimal value for this will depend on the RNG. More expensive RNGs will tend to make
+ * the optimal value higher. The most reliable way to determine this value for a new RNG
+   * is to experiment.  When tuning for a new RNG, a value of 0.5 is a reasonable
+   * initial guess in most cases.
+ */
+ val defaultMaxGapSamplingFraction = 0.4
+
+ /**
+ * Default epsilon for floating point numbers sampled from the RNG.
+ * The gap-sampling compute logic requires taking log(x), where x is sampled from an RNG.
+ * To guard against errors from taking log(0), a positive epsilon lower bound is applied.
+ * A good value for this parameter is at or near the minimum positive floating
+ * point value returned by "nextDouble()" (or equivalent), for the RNG being used.
+ */
+ val rngEpsilon = 5e-11
+
+ /**
+ * Sampling fraction arguments may be results of computation, and subject to floating
+ * point jitter. The arguments are checked against this epsilon slop factor to prevent
+ * spurious failures for cases such as summing weights to obtain a sampling fraction of
+ * 1.000000001.
+ */
+ val roundingEpsilon = 1e-6
+}
+
/**
* :: DeveloperApi ::
- * A sampler based on Bernoulli trials.
+ * A sampler based on Bernoulli trials for partitioning a data sequence.
*
* @param lb lower bound of the acceptance range
* @param ub upper bound of the acceptance range
@@ -53,56 +89,262 @@ trait RandomSampler[T, U] extends Pseudorandom with Cloneable with Serializable
* @tparam T item type
*/
@DeveloperApi
-class BernoulliSampler[T](lb: Double, ub: Double, complement: Boolean = false)
+class BernoulliCellSampler[T](lb: Double, ub: Double, complement: Boolean = false)
extends RandomSampler[T, T] {
- private[random] var rng: Random = new XORShiftRandom
+ /** epsilon slop to avoid failure from floating point jitter. */
+ require(
+ lb <= (ub + RandomSampler.roundingEpsilon),
+ s"Lower bound ($lb) must be <= upper bound ($ub)")
+ require(
+ lb >= (0.0 - RandomSampler.roundingEpsilon),
+ s"Lower bound ($lb) must be >= 0.0")
+ require(
+ ub <= (1.0 + RandomSampler.roundingEpsilon),
+ s"Upper bound ($ub) must be <= 1.0")
- def this(ratio: Double) = this(0.0d, ratio)
+ private val rng: Random = new XORShiftRandom
override def setSeed(seed: Long) = rng.setSeed(seed)
override def sample(items: Iterator[T]): Iterator[T] = {
- items.filter { item =>
- val x = rng.nextDouble()
- (x >= lb && x < ub) ^ complement
+ if (ub - lb <= 0.0) {
+ if (complement) items else Iterator.empty
+ } else {
+ if (complement) {
+ items.filter { item => {
+ val x = rng.nextDouble()
+ (x < lb) || (x >= ub)
+ }}
+ } else {
+ items.filter { item => {
+ val x = rng.nextDouble()
+ (x >= lb) && (x < ub)
+ }}
+ }
}
}
/**
 * Return a sampler that is the complement of the range specified for the current sampler.
*/
- def cloneComplement(): BernoulliSampler[T] = new BernoulliSampler[T](lb, ub, !complement)
+ def cloneComplement(): BernoulliCellSampler[T] =
+ new BernoulliCellSampler[T](lb, ub, !complement)
+
+ override def clone = new BernoulliCellSampler[T](lb, ub, complement)
+}
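
The acceptance rule above keeps an element when the RNG draw falls in `[lb, ub)`, and `cloneComplement()` yields the sampler that keeps exactly the remaining elements, which is what makes complementary cell samplers usable for partitioning a data sequence. A minimal, self-contained sketch of that rule, using `scala.util.Random` in place of Spark's XORShiftRandom:

```scala
import scala.util.Random

// A minimal sketch of the cell-acceptance rule: keep an element when its uniform draw
// lands in [lb, ub); the complement keeps exactly the other elements.
object BernoulliCellDemo {
  def split[T](items: Seq[T], lb: Double, ub: Double, seed: Long): (Seq[T], Seq[T]) = {
    val rng = new Random(seed)
    val tagged = items.map(t => (t, rng.nextDouble()))
    val accepted = tagged.collect { case (t, x) if x >= lb && x < ub => t }
    val rejected = tagged.collect { case (t, x) if x < lb || x >= ub => t }
    (accepted, rejected)
  }

  def main(args: Array[String]): Unit = {
    val (kept, dropped) = split(1 to 20, 0.0, 0.7, seed = 42L)
    // The two parts are disjoint and together cover the input.
    assert((kept ++ dropped).toSet == (1 to 20).toSet)
    println(s"kept=${kept.size}, dropped=${dropped.size}")
  }
}
```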
+
+
+/**
+ * :: DeveloperApi ::
+ * A sampler based on Bernoulli trials.
+ *
+ * @param fraction the sampling fraction, aka Bernoulli sampling probability
+ * @tparam T item type
+ */
+@DeveloperApi
+class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+
+ /** epsilon slop to avoid failure from floating point jitter */
+ require(
+ fraction >= (0.0 - RandomSampler.roundingEpsilon)
+ && fraction <= (1.0 + RandomSampler.roundingEpsilon),
+ s"Sampling fraction ($fraction) must be on interval [0, 1]")
- override def clone = new BernoulliSampler[T](lb, ub, complement)
+ private val rng: Random = RandomSampler.newDefaultRNG
+
+ override def setSeed(seed: Long) = rng.setSeed(seed)
+
+ override def sample(items: Iterator[T]): Iterator[T] = {
+ if (fraction <= 0.0) {
+ Iterator.empty
+ } else if (fraction >= 1.0) {
+ items
+ } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+ new GapSamplingIterator(items, fraction, rng, RandomSampler.rngEpsilon)
+ } else {
+ items.filter { _ => rng.nextDouble() <= fraction }
+ }
+ }
+
+ override def clone = new BernoulliSampler[T](fraction)
}
+
/**
* :: DeveloperApi ::
- * A sampler based on values drawn from Poisson distribution.
+ * A sampler for sampling with replacement, based on values drawn from Poisson distribution.
*
- * @param mean Poisson mean
+ * @param fraction the sampling fraction (with replacement)
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T](mean: Double) extends RandomSampler[T, T] {
+class PoissonSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T] {
+
+ /** Epsilon slop to avoid failure from floating point jitter. */
+ require(
+ fraction >= (0.0 - RandomSampler.roundingEpsilon),
+ s"Sampling fraction ($fraction) must be >= 0")
- private[random] var rng = new Poisson(mean, new DRand)
+ // PoissonDistribution throws an exception when fraction <= 0
+ // If fraction is <= 0, Iterator.empty is used below, so we can use any placeholder value.
+ private val rng = new PoissonDistribution(if (fraction > 0.0) fraction else 1.0)
+ private val rngGap = RandomSampler.newDefaultRNG
override def setSeed(seed: Long) {
- rng = new Poisson(mean, new DRand(seed.toInt))
+ rng.reseedRandomGenerator(seed)
+ rngGap.setSeed(seed)
}
override def sample(items: Iterator[T]): Iterator[T] = {
- items.flatMap { item =>
- val count = rng.nextInt()
- if (count == 0) {
- Iterator.empty
- } else {
- Iterator.fill(count)(item)
- }
+ if (fraction <= 0.0) {
+ Iterator.empty
+ } else if (fraction <= RandomSampler.defaultMaxGapSamplingFraction) {
+ new GapSamplingReplacementIterator(items, fraction, rngGap, RandomSampler.rngEpsilon)
+ } else {
+ items.flatMap { item => {
+ val count = rng.sample()
+ if (count == 0) Iterator.empty else Iterator.fill(count)(item)
+ }}
+ }
+ }
+
+ override def clone = new PoissonSampler[T](fraction)
+}
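
In the non-gap-sampling branch above, sampling with replacement means each element is replicated according to an independent Poisson(fraction) draw. A minimal sketch of that branch on a plain Scala collection, using the same commons-math3 calls that the diff introduces:

```scala
import org.apache.commons.math3.distribution.PoissonDistribution

// Sampling with replacement via Poisson counts: each element appears 0, 1, or more times.
object PoissonReplacementDemo {
  def main(args: Array[String]): Unit = {
    val fraction = 0.8
    val rng = new PoissonDistribution(fraction)
    rng.reseedRandomGenerator(42L)

    val items = (1 to 20).toList
    val sampled = items.flatMap { item =>
      val count = rng.sample()       // Poisson(fraction) replication factor
      List.fill(count)(item)
    }
    println(sampled.mkString(", "))
  }
}
```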
+
+
+private[spark]
+class GapSamplingIterator[T: ClassTag](
+ var data: Iterator[T],
+ f: Double,
+ rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+
+ require(f > 0.0 && f < 1.0, s"Sampling fraction ($f) must reside on open interval (0, 1)")
+ require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
+
+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
+ private val iterDrop: Int => Unit = {
+ val arrayClass = Array.empty[T].iterator.getClass
+ val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
+ data.getClass match {
+ case `arrayClass` => ((n: Int) => { data = data.drop(n) })
+ case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
+ case _ => ((n: Int) => {
+ var j = 0
+ while (j < n && data.hasNext) {
+ data.next()
+ j += 1
+ }
+ })
+ }
+ }
+
+ override def hasNext: Boolean = data.hasNext
+
+ override def next(): T = {
+ val r = data.next()
+ advance
+ r
+ }
+
+ private val lnq = math.log1p(-f)
+
+ /** skip elements that won't be sampled, according to geometric dist P(k) = (f)(1-f)^k. */
+ private def advance: Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ val k = (math.log(u) / lnq).toInt
+ iterDrop(k)
+ }
+
+ /** advance to first sample as part of object construction. */
+ advance
+ // Invoking this earlier, alongside the other object initialization, caused failures
+ // in strange ways, so it is invoked last, which works reliably.
+}
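
The skip length computed in `advance`, k = floor(log(u) / log(1 - f)), is a geometric draw for the number of elements to pass over before the next sampled one, so on average one element in 1/f is kept. A small numerical check of that rule (illustration only, with `scala.util.Random`):

```scala
import scala.util.Random

// Empirically verify that the gap-sampling skip rule keeps a fraction f of the stream.
object GapSamplingCheck {
  def main(args: Array[String]): Unit = {
    val f = 0.3
    val epsilon = 5e-11              // guard against log(0), as in RandomSampler.rngEpsilon
    val lnq = math.log1p(-f)
    val rng = new Random(7L)

    var consumed = 0L
    var kept = 0L
    while (consumed < 1000000L) {
      val u = math.max(rng.nextDouble(), epsilon)
      val skip = (math.log(u) / lnq).toInt
      consumed += skip + 1           // skip `skip` elements, then keep one
      kept += 1
    }
    println(s"empirical fraction = ${kept.toDouble / consumed} (target $f)")
  }
}
```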
+
+private[spark]
+class GapSamplingReplacementIterator[T: ClassTag](
+ var data: Iterator[T],
+ f: Double,
+ rng: Random = RandomSampler.newDefaultRNG,
+ epsilon: Double = RandomSampler.rngEpsilon) extends Iterator[T] {
+
+ require(f > 0.0, s"Sampling fraction ($f) must be > 0")
+ require(epsilon > 0.0, s"epsilon ($epsilon) must be > 0")
+
+ /** implement efficient linear-sequence drop until Scala includes fix for jira SI-8835. */
+ private val iterDrop: Int => Unit = {
+ val arrayClass = Array.empty[T].iterator.getClass
+ val arrayBufferClass = ArrayBuffer.empty[T].iterator.getClass
+ data.getClass match {
+ case `arrayClass` => ((n: Int) => { data = data.drop(n) })
+ case `arrayBufferClass` => ((n: Int) => { data = data.drop(n) })
+ case _ => ((n: Int) => {
+ var j = 0
+ while (j < n && data.hasNext) {
+ data.next()
+ j += 1
+ }
+ })
+ }
+ }
+
+ /** current sampling value, and its replication factor, as we are sampling with replacement. */
+ private var v: T = _
+ private var rep: Int = 0
+
+ override def hasNext: Boolean = data.hasNext || rep > 0
+
+ override def next(): T = {
+ val r = v
+ rep -= 1
+ if (rep <= 0) advance
+ r
+ }
+
+ /**
+ * Skip elements with replication factor zero (i.e. elements that won't be sampled).
+ * Samples 'k' from the geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f); that is,
+ * q is the probability of Poisson(0; f).
+ */
+ private def advance: Unit = {
+ val u = math.max(rng.nextDouble(), epsilon)
+ val k = (math.log(u) / (-f)).toInt
+ iterDrop(k)
+ // set the value and replication factor for the next value
+ if (data.hasNext) {
+ v = data.next()
+ rep = poissonGE1
+ }
+ }
+
+ private val q = math.exp(-f)
+
+ /**
+ * Sample from Poisson distribution, conditioned such that the sampled value is >= 1.
+ * This is an adaptation of the algorithm for generating Poisson-distributed random variables:
+ * http://en.wikipedia.org/wiki/Poisson_distribution
+ */
+ private def poissonGE1: Int = {
+ // simulate that the standard poisson sampling
+ // gave us at least one iteration, for a sample of >= 1
+ var pp = q + ((1.0 - q) * rng.nextDouble())
+ var r = 1
+
+ // now continue with standard poisson sampling algorithm
+ pp *= rng.nextDouble()
+ while (pp > q) {
+ r += 1
+ pp *= rng.nextDouble()
}
+ r
}
- override def clone = new PoissonSampler[T](mean)
+ /** advance to first sample as part of object construction. */
+ advance
+ // Invoking this earlier, alongside the other object initialization, caused failures
+ // in strange ways, so it is invoked last, which works reliably.
}
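
A Poisson(f) variable conditioned on being >= 1 has mean f / (1 - e^(-f)), which gives a quick way to sanity-check the poissonGE1 construction. The sketch below reproduces the same product-of-uniforms loop outside Spark and compares the empirical mean with that value (illustration only):

```scala
import scala.util.Random

// Sanity check: Poisson(frac) conditioned on >= 1 should have mean frac / (1 - e^(-frac)).
object PoissonGE1Check {
  def main(args: Array[String]): Unit = {
    val frac = 0.4
    val q = math.exp(-frac)
    val rng = new Random(11L)

    def poissonGE1: Int = {
      var pp = q + (1.0 - q) * rng.nextDouble() // condition on at least one event
      var r = 1
      pp *= rng.nextDouble()
      while (pp > q) { r += 1; pp *= rng.nextDouble() }
      r
    }

    val n = 500000
    val mean = (1 to n).map(_ => poissonGE1).sum.toDouble / n
    println(f"empirical mean = $mean%.4f, expected = ${frac / (1.0 - q)}%.4f")
  }
}
```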
diff --git a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
index 8f95d7c6b799b..4fa357edd6f07 100644
--- a/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/StratifiedSamplingUtils.scala
@@ -22,8 +22,7 @@ import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.Logging
import org.apache.spark.SparkContext._
@@ -209,7 +208,7 @@ private[spark] object StratifiedSamplingUtils extends Logging {
samplingRateByKey = computeThresholdByKey(finalResult, fractions)
}
(idx: Int, iter: Iterator[(K, V)]) => {
- val rng = new RandomDataGenerator
+ val rng = new RandomDataGenerator()
rng.reSeed(seed + idx)
// Must use the same invoke pattern on the rng as in getSeqOp for without replacement
// in order to generate the same sequence of random numbers when creating the sample
@@ -245,9 +244,9 @@ private[spark] object StratifiedSamplingUtils extends Logging {
// Must use the same invoke pattern on the rng as in getSeqOp for with replacement
// in order to generate the same sequence of random numbers when creating the sample
val copiesAccepted = if (acceptBound == 0) 0L else rng.nextPoisson(acceptBound)
- val copiesWailisted = rng.nextPoisson(finalResult(key).waitListBound)
+ val copiesWaitlisted = rng.nextPoisson(finalResult(key).waitListBound)
val copiesInSample = copiesAccepted +
- (0 until copiesWailisted).count(i => rng.nextUniform() < thresholdByKey(key))
+ (0 until copiesWaitlisted).count(i => rng.nextUniform() < thresholdByKey(key))
if (copiesInSample > 0) {
Iterator.fill(copiesInSample.toInt)(item)
} else {
@@ -261,10 +260,10 @@ private[spark] object StratifiedSamplingUtils extends Logging {
rng.reSeed(seed + idx)
iter.flatMap { item =>
val count = rng.nextPoisson(fractions(item._1))
- if (count > 0) {
- Iterator.fill(count)(item)
- } else {
+ if (count == 0) {
Iterator.empty
+ } else {
+ Iterator.fill(count)(item)
}
}
}
@@ -274,15 +273,24 @@ private[spark] object StratifiedSamplingUtils extends Logging {
/** A random data generator that generates both uniform values and Poisson values. */
private class RandomDataGenerator {
val uniform = new XORShiftRandom()
- var poisson = new Poisson(1.0, new DRand)
+ // commons-math3 doesn't have a method to generate Poisson from an arbitrary mean;
+ // maintain a cache of Poisson(m) distributions for various m
+ val poissonCache = mutable.Map[Double, PoissonDistribution]()
+ var poissonSeed = 0L
- def reSeed(seed: Long) {
+ def reSeed(seed: Long): Unit = {
uniform.setSeed(seed)
- poisson = new Poisson(1.0, new DRand(seed.toInt))
+ poissonSeed = seed
+ poissonCache.clear()
}
def nextPoisson(mean: Double): Int = {
- poisson.nextInt(mean)
+ val poisson = poissonCache.getOrElseUpdate(mean, {
+ val newPoisson = new PoissonDistribution(mean)
+ newPoisson.reseedRandomGenerator(poissonSeed)
+ newPoisson
+ })
+ poisson.sample()
}
def nextUniform(): Double = {
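
Because commons-math3 fixes the mean when a PoissonDistribution is constructed, the replacement RandomDataGenerator builds one distribution per distinct mean, reseeds it, and then reuses it from the cache. A stripped-down sketch of that pattern outside Spark:

```scala
import scala.collection.mutable
import org.apache.commons.math3.distribution.PoissonDistribution

// Cache one PoissonDistribution per mean; reseeding happens when the entry is first built.
object PoissonCacheSketch {
  private val cache = mutable.Map[Double, PoissonDistribution]()
  private var seed = 0L

  def reSeed(s: Long): Unit = { seed = s; cache.clear() }

  def nextPoisson(mean: Double): Int = {
    val dist = cache.getOrElseUpdate(mean, {
      val d = new PoissonDistribution(mean)
      d.reseedRandomGenerator(seed)
      d
    })
    dist.sample()
  }

  def main(args: Array[String]): Unit = {
    reSeed(42L)
    println(Seq(0.5, 2.0, 0.5).map(nextPoisson)) // the second 0.5 reuses the cached instance
  }
}
```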
diff --git a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
index 55b5713706178..467b890fb4bb9 100644
--- a/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/XORShiftRandom.scala
@@ -96,13 +96,9 @@ private[spark] object XORShiftRandom {
xorRand.nextInt()
}
- val iters = timeIt(numIters)(_)
-
/* Return results as a map instead of just printing to screen
in case the user wants to do something with them */
- Map("javaTime" -> iters {javaRand.nextInt()},
- "xorTime" -> iters {xorRand.nextInt()})
-
+ Map("javaTime" -> timeIt(numIters) { javaRand.nextInt() },
+ "xorTime" -> timeIt(numIters) { xorRand.nextInt() })
}
-
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 3190148fb5f43..59c86eecac5e8 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -18,10 +18,13 @@
package org.apache.spark;
import java.io.*;
+import java.nio.channels.FileChannel;
+import java.nio.ByteBuffer;
import java.net.URI;
import java.util.*;
import java.util.concurrent.*;
+import org.apache.spark.input.PortableDataStream;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
@@ -140,11 +143,10 @@ public void intersection() {
public void sample() {
List<Integer> ints = Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
JavaRDD<Integer> rdd = sc.parallelize(ints);
- JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 11);
- // expected 2 but of course result varies randomly a bit
- Assert.assertEquals(3, sample20.count());
- JavaRDD<Integer> sample20NoReplacement = rdd.sample(false, 0.2, 11);
- Assert.assertEquals(2, sample20NoReplacement.count());
+ JavaRDD<Integer> sample20 = rdd.sample(true, 0.2, 3);
+ Assert.assertEquals(2, sample20.count());
+ JavaRDD<Integer> sample20WithoutReplacement = rdd.sample(false, 0.2, 5);
+ Assert.assertEquals(2, sample20WithoutReplacement.count());
}
@Test
@@ -864,6 +866,82 @@ public Tuple2 call(Tuple2 pair) {
Assert.assertEquals(pairs, readRDD.collect());
}
+ @Test
+ public void binaryFiles() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ channel1.close();
+ JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName, 3);
+ List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+ for (Tuple2<String, PortableDataStream> res : result) {
+ Assert.assertArrayEquals(content1, res._2().toArray());
+ }
+ }
+
+ @Test
+ public void binaryFilesCaching() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark is easy to use.\n".getBytes("utf-8");
+
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ channel1.close();
+
+ JavaPairRDD<String, PortableDataStream> readRDD = sc.binaryFiles(tempDirName).cache();
+ readRDD.foreach(new VoidFunction<Tuple2<String, PortableDataStream>>() {
+ @Override
+ public void call(Tuple2<String, PortableDataStream> pair) throws Exception {
+ pair._2().toArray(); // force the file to be read
+ }
+ });
+
+ List<Tuple2<String, PortableDataStream>> result = readRDD.collect();
+ for (Tuple2<String, PortableDataStream> res : result) {
+ Assert.assertArrayEquals(content1, res._2().toArray());
+ }
+ }
+
+ @Test
+ public void binaryRecords() throws Exception {
+ // Reusing the wholeTextFiles example
+ byte[] content1 = "spark isn't always easy to use.\n".getBytes("utf-8");
+ int numOfCopies = 10;
+ String tempDirName = tempDir.getAbsolutePath();
+ File file1 = new File(tempDirName + "/part-00000");
+
+ FileOutputStream fos1 = new FileOutputStream(file1);
+
+ FileChannel channel1 = fos1.getChannel();
+
+ for (int i = 0; i < numOfCopies; i++) {
+ ByteBuffer bbuf = java.nio.ByteBuffer.wrap(content1);
+ channel1.write(bbuf);
+ }
+ channel1.close();
+
+ JavaRDD<byte[]> readRDD = sc.binaryRecords(tempDirName, content1.length);
+ Assert.assertEquals(numOfCopies, readRDD.count());
+ List<byte[]> result = readRDD.collect();
+ for (byte[] res : result) {
+ Assert.assertArrayEquals(content1, res);
+ }
+ }
+
@SuppressWarnings("unchecked")
@Test
public void writeWithNewAPIHadoopFile() {
@@ -1418,4 +1496,16 @@ public Optional call(Integer i) {
}
}
+ static class Class1 {}
+ static class Class2 {}
+
+ @Test
+ public void testRegisterKryoClasses() {
+ SparkConf conf = new SparkConf();
+ conf.registerKryoClasses(new Class[]{ Class1.class, Class2.class });
+ Assert.assertEquals(
+ Class1.class.getName() + "," + Class2.class.getName(),
+ conf.get("spark.kryo.classesToRegister"));
+ }
+
}
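
The new testRegisterKryoClasses test shows that registerKryoClasses records the class names under `spark.kryo.classesToRegister` as a comma-separated list. A hedged Scala counterpart, with illustrative class names:

```scala
import org.apache.spark.SparkConf

// registerKryoClasses appends class names to spark.kryo.classesToRegister.
object KryoRegistrationSketch {
  case class Point(x: Double, y: Double)
  case class Label(name: String)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .registerKryoClasses(Array[Class[_]](classOf[Point], classOf[Label]))
    println(conf.get("spark.kryo.classesToRegister"))
  }
}
```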
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 81b64c36ddca1..429199f2075c6 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -202,7 +202,8 @@ class DistributedSuite extends FunSuite with Matchers with BeforeAndAfter
val blockManager = SparkEnv.get.blockManager
val blockTransfer = SparkEnv.get.blockTransferService
blockManager.master.getLocations(blockId).foreach { cmId =>
- val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, blockId.toString)
+ val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId,
+ blockId.toString)
val deserialized = blockManager.dataDeserialize(blockId, bytes.nioByteBuffer())
.asInstanceOf[Iterator[Int]].toList
assert(deserialized === (1 to 100).toList)
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
new file mode 100644
index 0000000000000..66cf60d25f6d1
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala
@@ -0,0 +1,662 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.{FunSuite, PrivateMethodTester}
+import org.apache.spark.executor.TaskMetrics
+import org.apache.spark.scheduler._
+import org.apache.spark.storage.BlockManagerId
+
+/**
+ * Test add and remove behavior of ExecutorAllocationManager.
+ */
+class ExecutorAllocationManagerSuite extends FunSuite with LocalSparkContext {
+ import ExecutorAllocationManager._
+ import ExecutorAllocationManagerSuite._
+
+ test("verify min/max executors") {
+ // No min or max
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test-executor-allocation-manager")
+ .set("spark.dynamicAllocation.enabled", "true")
+ intercept[SparkException] { new SparkContext(conf) }
+ SparkEnv.get.stop() // cleanup the created environment
+
+ // Only min
+ val conf1 = conf.clone().set("spark.dynamicAllocation.minExecutors", "1")
+ intercept[SparkException] { new SparkContext(conf1) }
+ SparkEnv.get.stop()
+
+ // Only max
+ val conf2 = conf.clone().set("spark.dynamicAllocation.maxExecutors", "2")
+ intercept[SparkException] { new SparkContext(conf2) }
+ SparkEnv.get.stop()
+
+ // Both min and max, but min > max
+ intercept[SparkException] { createSparkContext(2, 1) }
+ SparkEnv.get.stop()
+
+ // Both min and max, and min == max
+ val sc1 = createSparkContext(1, 1)
+ assert(sc1.executorAllocationManager.isDefined)
+ sc1.stop()
+
+ // Both min and max, and min < max
+ val sc2 = createSparkContext(1, 2)
+ assert(sc2.executorAllocationManager.isDefined)
+ sc2.stop()
+ }
+
+ test("starting state") {
+ sc = createSparkContext()
+ val manager = sc.executorAllocationManager.get
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ assert(executorIds(manager).isEmpty)
+ assert(addTime(manager) === ExecutorAllocationManager.NOT_SET)
+ assert(removeTimes(manager).isEmpty)
+ }
+
+ test("add executors") {
+ sc = createSparkContext(1, 10)
+ val manager = sc.executorAllocationManager.get
+
+ // Keep adding until the limit is reached
+ assert(numExecutorsPending(manager) === 0)
+ assert(numExecutorsToAdd(manager) === 1)
+ assert(addExecutors(manager) === 1)
+ assert(numExecutorsPending(manager) === 1)
+ assert(numExecutorsToAdd(manager) === 2)
+ assert(addExecutors(manager) === 2)
+ assert(numExecutorsPending(manager) === 3)
+ assert(numExecutorsToAdd(manager) === 4)
+ assert(addExecutors(manager) === 4)
+ assert(numExecutorsPending(manager) === 7)
+ assert(numExecutorsToAdd(manager) === 8)
+ assert(addExecutors(manager) === 3) // reached the limit of 10
+ assert(numExecutorsPending(manager) === 10)
+ assert(numExecutorsToAdd(manager) === 1)
+ assert(addExecutors(manager) === 0)
+ assert(numExecutorsPending(manager) === 10)
+ assert(numExecutorsToAdd(manager) === 1)
+
+ // Register previously requested executors
+ onExecutorAdded(manager, "first")
+ assert(numExecutorsPending(manager) === 9)
+ onExecutorAdded(manager, "second")
+ onExecutorAdded(manager, "third")
+ onExecutorAdded(manager, "fourth")
+ assert(numExecutorsPending(manager) === 6)
+ onExecutorAdded(manager, "first") // duplicates should not count
+ onExecutorAdded(manager, "second")
+ assert(numExecutorsPending(manager) === 6)
+
+ // Try adding again
+ // This should still fail because the number pending + running is still at the limit
+ assert(addExecutors(manager) === 0)
+ assert(numExecutorsPending(manager) === 6)
+ assert(numExecutorsToAdd(manager) === 1)
+ assert(addExecutors(manager) === 0)
+ assert(numExecutorsPending(manager) === 6)
+ assert(numExecutorsToAdd(manager) === 1)
+ }
+
+ test("remove executors") {
+ sc = createSparkContext(5, 10)
+ val manager = sc.executorAllocationManager.get
+ (1 to 10).map(_.toString).foreach { id => onExecutorAdded(manager, id) }
+
+ // Keep removing until the limit is reached
+ assert(executorsPendingToRemove(manager).isEmpty)
+ assert(removeExecutor(manager, "1"))
+ assert(executorsPendingToRemove(manager).size === 1)
+ assert(executorsPendingToRemove(manager).contains("1"))
+ assert(removeExecutor(manager, "2"))
+ assert(removeExecutor(manager, "3"))
+ assert(executorsPendingToRemove(manager).size === 3)
+ assert(executorsPendingToRemove(manager).contains("2"))
+ assert(executorsPendingToRemove(manager).contains("3"))
+ assert(!removeExecutor(manager, "100")) // remove non-existent executors
+ assert(!removeExecutor(manager, "101"))
+ assert(executorsPendingToRemove(manager).size === 3)
+ assert(removeExecutor(manager, "4"))
+ assert(removeExecutor(manager, "5"))
+ assert(!removeExecutor(manager, "6")) // reached the limit of 5
+ assert(executorsPendingToRemove(manager).size === 5)
+ assert(executorsPendingToRemove(manager).contains("4"))
+ assert(executorsPendingToRemove(manager).contains("5"))
+ assert(!executorsPendingToRemove(manager).contains("6"))
+
+ // Kill executors previously requested to remove
+ onExecutorRemoved(manager, "1")
+ assert(executorsPendingToRemove(manager).size === 4)
+ assert(!executorsPendingToRemove(manager).contains("1"))
+ onExecutorRemoved(manager, "2")
+ onExecutorRemoved(manager, "3")
+ assert(executorsPendingToRemove(manager).size === 2)
+ assert(!executorsPendingToRemove(manager).contains("2"))
+ assert(!executorsPendingToRemove(manager).contains("3"))
+ onExecutorRemoved(manager, "2") // duplicates should not count
+ onExecutorRemoved(manager, "3")
+ assert(executorsPendingToRemove(manager).size === 2)
+ onExecutorRemoved(manager, "4")
+ onExecutorRemoved(manager, "5")
+ assert(executorsPendingToRemove(manager).isEmpty)
+
+ // Try removing again
+ // This should still fail because the number pending + running is still at the limit
+ assert(!removeExecutor(manager, "7"))
+ assert(executorsPendingToRemove(manager).isEmpty)
+ assert(!removeExecutor(manager, "8"))
+ assert(executorsPendingToRemove(manager).isEmpty)
+ }
+
+ test ("interleaving add and remove") {
+ sc = createSparkContext(5, 10)
+ val manager = sc.executorAllocationManager.get
+
+ // Add a few executors
+ assert(addExecutors(manager) === 1)
+ assert(addExecutors(manager) === 2)
+ assert(addExecutors(manager) === 4)
+ onExecutorAdded(manager, "1")
+ onExecutorAdded(manager, "2")
+ onExecutorAdded(manager, "3")
+ onExecutorAdded(manager, "4")
+ onExecutorAdded(manager, "5")
+ onExecutorAdded(manager, "6")
+ onExecutorAdded(manager, "7")
+ assert(executorIds(manager).size === 7)
+
+ // Remove until limit
+ assert(removeExecutor(manager, "1"))
+ assert(removeExecutor(manager, "2"))
+ assert(!removeExecutor(manager, "3")) // lower limit reached
+ assert(!removeExecutor(manager, "4"))
+ onExecutorRemoved(manager, "1")
+ onExecutorRemoved(manager, "2")
+ assert(executorIds(manager).size === 5)
+
+ // Add until limit
+ assert(addExecutors(manager) === 5) // upper limit reached
+ assert(addExecutors(manager) === 0)
+ assert(!removeExecutor(manager, "3")) // still at lower limit
+ assert(!removeExecutor(manager, "4"))
+ onExecutorAdded(manager, "8")
+ onExecutorAdded(manager, "9")
+ onExecutorAdded(manager, "10")
+ onExecutorAdded(manager, "11")
+ onExecutorAdded(manager, "12")
+ assert(executorIds(manager).size === 10)
+
+ // Remove succeeds again, now that we are no longer at the lower limit
+ assert(removeExecutor(manager, "3"))
+ assert(removeExecutor(manager, "4"))
+ assert(removeExecutor(manager, "5"))
+ assert(removeExecutor(manager, "6"))
+ assert(executorIds(manager).size === 10)
+ assert(addExecutors(manager) === 0) // still at upper limit
+ onExecutorRemoved(manager, "3")
+ onExecutorRemoved(manager, "4")
+ assert(executorIds(manager).size === 8)
+
+ // Add succeeds again, now that we are no longer at the upper limit
+ // Number of executors added restarts at 1
+ assert(addExecutors(manager) === 1)
+ assert(addExecutors(manager) === 1) // upper limit reached again
+ assert(addExecutors(manager) === 0)
+ assert(executorIds(manager).size === 8)
+ onExecutorRemoved(manager, "5")
+ onExecutorRemoved(manager, "6")
+ onExecutorAdded(manager, "13")
+ onExecutorAdded(manager, "14")
+ assert(executorIds(manager).size === 8)
+ assert(addExecutors(manager) === 1)
+ assert(addExecutors(manager) === 1) // upper limit reached again
+ assert(addExecutors(manager) === 0)
+ onExecutorAdded(manager, "15")
+ onExecutorAdded(manager, "16")
+ assert(executorIds(manager).size === 10)
+ }
+
+ test("starting/canceling add timer") {
+ sc = createSparkContext(2, 10)
+ val clock = new TestClock(8888L)
+ val manager = sc.executorAllocationManager.get
+ manager.setClock(clock)
+
+ // Starting add timer is idempotent
+ assert(addTime(manager) === NOT_SET)
+ onSchedulerBacklogged(manager)
+ val firstAddTime = addTime(manager)
+ assert(firstAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000)
+ clock.tick(100L)
+ onSchedulerBacklogged(manager)
+ assert(addTime(manager) === firstAddTime) // timer is already started
+ clock.tick(200L)
+ onSchedulerBacklogged(manager)
+ assert(addTime(manager) === firstAddTime)
+ onSchedulerQueueEmpty(manager)
+
+ // Restart add timer
+ clock.tick(1000L)
+ assert(addTime(manager) === NOT_SET)
+ onSchedulerBacklogged(manager)
+ val secondAddTime = addTime(manager)
+ assert(secondAddTime === clock.getTimeMillis + schedulerBacklogTimeout * 1000)
+ clock.tick(100L)
+ onSchedulerBacklogged(manager)
+ assert(addTime(manager) === secondAddTime) // timer is already started
+ assert(addTime(manager) !== firstAddTime)
+ assert(firstAddTime !== secondAddTime)
+ }
+
+ test("starting/canceling remove timers") {
+ sc = createSparkContext(2, 10)
+ val clock = new TestClock(14444L)
+ val manager = sc.executorAllocationManager.get
+ manager.setClock(clock)
+
+ // Starting remove timer is idempotent for each executor
+ assert(removeTimes(manager).isEmpty)
+ onExecutorIdle(manager, "1")
+ assert(removeTimes(manager).size === 1)
+ assert(removeTimes(manager).contains("1"))
+ val firstRemoveTime = removeTimes(manager)("1")
+ assert(firstRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000)
+ clock.tick(100L)
+ onExecutorIdle(manager, "1")
+ assert(removeTimes(manager)("1") === firstRemoveTime) // timer is already started
+ clock.tick(200L)
+ onExecutorIdle(manager, "1")
+ assert(removeTimes(manager)("1") === firstRemoveTime)
+ clock.tick(300L)
+ onExecutorIdle(manager, "2")
+ assert(removeTimes(manager)("2") !== firstRemoveTime) // different executor
+ assert(removeTimes(manager)("2") === clock.getTimeMillis + executorIdleTimeout * 1000)
+ clock.tick(400L)
+ onExecutorIdle(manager, "3")
+ assert(removeTimes(manager)("3") !== firstRemoveTime)
+ assert(removeTimes(manager)("3") === clock.getTimeMillis + executorIdleTimeout * 1000)
+ assert(removeTimes(manager).size === 3)
+ assert(removeTimes(manager).contains("2"))
+ assert(removeTimes(manager).contains("3"))
+
+ // Restart remove timer
+ clock.tick(1000L)
+ onExecutorBusy(manager, "1")
+ assert(removeTimes(manager).size === 2)
+ onExecutorIdle(manager, "1")
+ assert(removeTimes(manager).size === 3)
+ assert(removeTimes(manager).contains("1"))
+ val secondRemoveTime = removeTimes(manager)("1")
+ assert(secondRemoveTime === clock.getTimeMillis + executorIdleTimeout * 1000)
+ assert(removeTimes(manager)("1") === secondRemoveTime) // timer is already started
+ assert(removeTimes(manager)("1") !== firstRemoveTime)
+ assert(firstRemoveTime !== secondRemoveTime)
+ }
+
+ test("mock polling loop with no events") {
+ sc = createSparkContext(1, 20)
+ val manager = sc.executorAllocationManager.get
+ val clock = new TestClock(2020L)
+ manager.setClock(clock)
+
+ // No events - we should not be adding or removing
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ clock.tick(100L)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ clock.tick(1000L)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ clock.tick(10000L)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 0)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ }
+
+ test("mock polling loop add behavior") {
+ sc = createSparkContext(1, 20)
+ val clock = new TestClock(2020L)
+ val manager = sc.executorAllocationManager.get
+ manager.setClock(clock)
+
+ // Scheduler queue backlogged
+ onSchedulerBacklogged(manager)
+ clock.tick(schedulerBacklogTimeout * 1000 / 2)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 0) // timer not exceeded yet
+ clock.tick(schedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 1) // first timer exceeded
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000 / 2)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 1) // second timer not exceeded yet
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 1 + 2) // second timer exceeded
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 1 + 2 + 4) // third timer exceeded
+
+ // Scheduler queue drained
+ onSchedulerQueueEmpty(manager)
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 7) // timer is canceled
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 7)
+
+ // Scheduler queue backlogged again
+ onSchedulerBacklogged(manager)
+ clock.tick(schedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 7 + 1) // timer restarted
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 7 + 1 + 2)
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 7 + 1 + 2 + 4)
+ clock.tick(sustainedSchedulerBacklogTimeout * 1000)
+ schedule(manager)
+ assert(numExecutorsPending(manager) === 20) // limit reached
+ }
+
+ test("mock polling loop remove behavior") {
+ sc = createSparkContext(1, 20)
+ val clock = new TestClock(2020L)
+ val manager = sc.executorAllocationManager.get
+ manager.setClock(clock)
+
+ // Remove idle executors on timeout
+ onExecutorAdded(manager, "executor-1")
+ onExecutorAdded(manager, "executor-2")
+ onExecutorAdded(manager, "executor-3")
+ assert(removeTimes(manager).size === 3)
+ assert(executorsPendingToRemove(manager).isEmpty)
+ clock.tick(executorIdleTimeout * 1000 / 2)
+ schedule(manager)
+ assert(removeTimes(manager).size === 3) // idle threshold not reached yet
+ assert(executorsPendingToRemove(manager).isEmpty)
+ clock.tick(executorIdleTimeout * 1000)
+ schedule(manager)
+ assert(removeTimes(manager).isEmpty) // idle threshold exceeded
+ assert(executorsPendingToRemove(manager).size === 2) // limit reached (1 executor remaining)
+
+ // Mark a subset as busy - only idle executors should be removed
+ onExecutorAdded(manager, "executor-4")
+ onExecutorAdded(manager, "executor-5")
+ onExecutorAdded(manager, "executor-6")
+ onExecutorAdded(manager, "executor-7")
+ assert(removeTimes(manager).size === 5) // 5 active executors
+ assert(executorsPendingToRemove(manager).size === 2) // 2 pending to be removed
+ onExecutorBusy(manager, "executor-4")
+ onExecutorBusy(manager, "executor-5")
+ onExecutorBusy(manager, "executor-6") // 3 busy and 2 idle (of the 5 active ones)
+ schedule(manager)
+ assert(removeTimes(manager).size === 2) // remove only idle executors
+ assert(!removeTimes(manager).contains("executor-4"))
+ assert(!removeTimes(manager).contains("executor-5"))
+ assert(!removeTimes(manager).contains("executor-6"))
+ assert(executorsPendingToRemove(manager).size === 2)
+ clock.tick(executorIdleTimeout * 1000)
+ schedule(manager)
+ assert(removeTimes(manager).isEmpty) // idle executors are removed
+ assert(executorsPendingToRemove(manager).size === 4)
+ assert(!executorsPendingToRemove(manager).contains("executor-4"))
+ assert(!executorsPendingToRemove(manager).contains("executor-5"))
+ assert(!executorsPendingToRemove(manager).contains("executor-6"))
+
+ // Busy executors are now idle and should be removed
+ onExecutorIdle(manager, "executor-4")
+ onExecutorIdle(manager, "executor-5")
+ onExecutorIdle(manager, "executor-6")
+ schedule(manager)
+ assert(removeTimes(manager).size === 3) // 0 busy and 3 idle
+ assert(removeTimes(manager).contains("executor-4"))
+ assert(removeTimes(manager).contains("executor-5"))
+ assert(removeTimes(manager).contains("executor-6"))
+ assert(executorsPendingToRemove(manager).size === 4)
+ clock.tick(executorIdleTimeout * 1000)
+ schedule(manager)
+ assert(removeTimes(manager).isEmpty)
+ assert(executorsPendingToRemove(manager).size === 6) // limit reached (1 executor remaining)
+ }
+
+ test("listeners trigger add executors correctly") {
+ sc = createSparkContext(2, 10)
+ val manager = sc.executorAllocationManager.get
+ assert(addTime(manager) === NOT_SET)
+
+ // Starting a stage should start the add timer
+ val numTasks = 10
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(0, numTasks)))
+ assert(addTime(manager) !== NOT_SET)
+
+ // Starting a subset of the tasks should not cancel the add timer
+ val taskInfos = (0 to numTasks - 1).map { i => createTaskInfo(i, i, "executor-1") }
+ taskInfos.tail.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, info)) }
+ assert(addTime(manager) !== NOT_SET)
+
+ // Starting all remaining tasks should cancel the add timer
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, taskInfos.head))
+ assert(addTime(manager) === NOT_SET)
+
+ // Start two different stages
+ // The add timer should be canceled only if all tasks in both stages start running
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(1, numTasks)))
+ sc.listenerBus.postToAll(SparkListenerStageSubmitted(createStageInfo(2, numTasks)))
+ assert(addTime(manager) !== NOT_SET)
+ taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(1, 0, info)) }
+ assert(addTime(manager) !== NOT_SET)
+ taskInfos.foreach { info => sc.listenerBus.postToAll(SparkListenerTaskStart(2, 0, info)) }
+ assert(addTime(manager) === NOT_SET)
+ }
+
+ test("listeners trigger remove executors correctly") {
+ sc = createSparkContext(2, 10)
+ val manager = sc.executorAllocationManager.get
+ assert(removeTimes(manager).isEmpty)
+
+ // Added executors should start the remove timers for each executor
+ (1 to 5).map("executor-" + _).foreach { id => onExecutorAdded(manager, id) }
+ assert(removeTimes(manager).size === 5)
+
+ // Starting a task cancels the remove timer for that executor
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(0, 0, "executor-1")))
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(1, 1, "executor-1")))
+ sc.listenerBus.postToAll(SparkListenerTaskStart(0, 0, createTaskInfo(2, 2, "executor-2")))
+ assert(removeTimes(manager).size === 3)
+ assert(!removeTimes(manager).contains("executor-1"))
+ assert(!removeTimes(manager).contains("executor-2"))
+
+ // Finishing all tasks running on an executor should start the remove timer for that executor
+ sc.listenerBus.postToAll(SparkListenerTaskEnd(
+ 0, 0, "task-type", Success, createTaskInfo(0, 0, "executor-1"), new TaskMetrics))
+ sc.listenerBus.postToAll(SparkListenerTaskEnd(
+ 0, 0, "task-type", Success, createTaskInfo(2, 2, "executor-2"), new TaskMetrics))
+ assert(removeTimes(manager).size === 4)
+ assert(!removeTimes(manager).contains("executor-1")) // executor-1 has not finished yet
+ assert(removeTimes(manager).contains("executor-2"))
+ sc.listenerBus.postToAll(SparkListenerTaskEnd(
+ 0, 0, "task-type", Success, createTaskInfo(1, 1, "executor-1"), new TaskMetrics))
+ assert(removeTimes(manager).size === 5)
+ assert(removeTimes(manager).contains("executor-1")) // executor-1 has now finished
+ }
+
+ test("listeners trigger add and remove executor callbacks correctly") {
+ sc = createSparkContext(2, 10)
+ val manager = sc.executorAllocationManager.get
+ assert(executorIds(manager).isEmpty)
+ assert(removeTimes(manager).isEmpty)
+
+ // New executors have registered
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-1", "host1", 1), 100L))
+ assert(executorIds(manager).size === 1)
+ assert(executorIds(manager).contains("executor-1"))
+ assert(removeTimes(manager).size === 1)
+ assert(removeTimes(manager).contains("executor-1"))
+ sc.listenerBus.postToAll(SparkListenerBlockManagerAdded(
+ 0L, BlockManagerId("executor-2", "host2", 1), 100L))
+ assert(executorIds(manager).size === 2)
+ assert(executorIds(manager).contains("executor-2"))
+ assert(removeTimes(manager).size === 2)
+ assert(removeTimes(manager).contains("executor-2"))
+
+ // Existing executors have disconnected
+ sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved(
+ 0L, BlockManagerId("executor-1", "host1", 1)))
+ assert(executorIds(manager).size === 1)
+ assert(!executorIds(manager).contains("executor-1"))
+ assert(removeTimes(manager).size === 1)
+ assert(!removeTimes(manager).contains("executor-1"))
+
+ // Unknown executor has disconnected
+ sc.listenerBus.postToAll(SparkListenerBlockManagerRemoved(
+ 0L, BlockManagerId("executor-3", "host3", 1)))
+ assert(executorIds(manager).size === 1)
+ assert(removeTimes(manager).size === 1)
+ }
+
+}
+
+/**
+ * Helper methods for testing ExecutorAllocationManager.
+ * This includes methods to access private methods and fields in ExecutorAllocationManager.
+ */
+private object ExecutorAllocationManagerSuite extends PrivateMethodTester {
+ private val schedulerBacklogTimeout = 1L
+ private val sustainedSchedulerBacklogTimeout = 2L
+ private val executorIdleTimeout = 3L
+
+ private def createSparkContext(minExecutors: Int = 1, maxExecutors: Int = 5): SparkContext = {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test-executor-allocation-manager")
+ .set("spark.dynamicAllocation.enabled", "true")
+ .set("spark.dynamicAllocation.minExecutors", minExecutors.toString)
+ .set("spark.dynamicAllocation.maxExecutors", maxExecutors.toString)
+ .set("spark.dynamicAllocation.schedulerBacklogTimeout", schedulerBacklogTimeout.toString)
+ .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout",
+ sustainedSchedulerBacklogTimeout.toString)
+ .set("spark.dynamicAllocation.executorIdleTimeout", executorIdleTimeout.toString)
+ .set("spark.dynamicAllocation.testing", "true")
+ new SparkContext(conf)
+ }
+
+ private def createStageInfo(stageId: Int, numTasks: Int): StageInfo = {
+ new StageInfo(stageId, 0, "name", numTasks, Seq.empty, "no details")
+ }
+
+ private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = {
+ new TaskInfo(taskId, taskIndex, 0, 0, executorId, "", TaskLocality.ANY, speculative = false)
+ }
+
+ /* ------------------------------------------------------- *
+ | Helper methods for accessing private methods and fields |
+ * ------------------------------------------------------- */
+
+ private val _numExecutorsToAdd = PrivateMethod[Int]('numExecutorsToAdd)
+ private val _numExecutorsPending = PrivateMethod[Int]('numExecutorsPending)
+ private val _executorsPendingToRemove =
+ PrivateMethod[collection.Set[String]]('executorsPendingToRemove)
+ private val _executorIds = PrivateMethod[collection.Set[String]]('executorIds)
+ private val _addTime = PrivateMethod[Long]('addTime)
+ private val _removeTimes = PrivateMethod[collection.Map[String, Long]]('removeTimes)
+ private val _schedule = PrivateMethod[Unit]('schedule)
+ private val _addExecutors = PrivateMethod[Int]('addExecutors)
+ private val _removeExecutor = PrivateMethod[Boolean]('removeExecutor)
+ private val _onExecutorAdded = PrivateMethod[Unit]('onExecutorAdded)
+ private val _onExecutorRemoved = PrivateMethod[Unit]('onExecutorRemoved)
+ private val _onSchedulerBacklogged = PrivateMethod[Unit]('onSchedulerBacklogged)
+ private val _onSchedulerQueueEmpty = PrivateMethod[Unit]('onSchedulerQueueEmpty)
+ private val _onExecutorIdle = PrivateMethod[Unit]('onExecutorIdle)
+ private val _onExecutorBusy = PrivateMethod[Unit]('onExecutorBusy)
+
+ private def numExecutorsToAdd(manager: ExecutorAllocationManager): Int = {
+ manager invokePrivate _numExecutorsToAdd()
+ }
+
+ private def numExecutorsPending(manager: ExecutorAllocationManager): Int = {
+ manager invokePrivate _numExecutorsPending()
+ }
+
+ private def executorsPendingToRemove(
+ manager: ExecutorAllocationManager): collection.Set[String] = {
+ manager invokePrivate _executorsPendingToRemove()
+ }
+
+ private def executorIds(manager: ExecutorAllocationManager): collection.Set[String] = {
+ manager invokePrivate _executorIds()
+ }
+
+ private def addTime(manager: ExecutorAllocationManager): Long = {
+ manager invokePrivate _addTime()
+ }
+
+ private def removeTimes(manager: ExecutorAllocationManager): collection.Map[String, Long] = {
+ manager invokePrivate _removeTimes()
+ }
+
+ private def schedule(manager: ExecutorAllocationManager): Unit = {
+ manager invokePrivate _schedule()
+ }
+
+ private def addExecutors(manager: ExecutorAllocationManager): Int = {
+ manager invokePrivate _addExecutors()
+ }
+
+ private def removeExecutor(manager: ExecutorAllocationManager, id: String): Boolean = {
+ manager invokePrivate _removeExecutor(id)
+ }
+
+ private def onExecutorAdded(manager: ExecutorAllocationManager, id: String): Unit = {
+ manager invokePrivate _onExecutorAdded(id)
+ }
+
+ private def onExecutorRemoved(manager: ExecutorAllocationManager, id: String): Unit = {
+ manager invokePrivate _onExecutorRemoved(id)
+ }
+
+ private def onSchedulerBacklogged(manager: ExecutorAllocationManager): Unit = {
+ manager invokePrivate _onSchedulerBacklogged()
+ }
+
+ private def onSchedulerQueueEmpty(manager: ExecutorAllocationManager): Unit = {
+ manager invokePrivate _onSchedulerQueueEmpty()
+ }
+
+ private def onExecutorIdle(manager: ExecutorAllocationManager, id: String): Unit = {
+ manager invokePrivate _onExecutorIdle(id)
+ }
+
+ private def onExecutorBusy(manager: ExecutorAllocationManager, id: String): Unit = {
+ manager invokePrivate _onExecutorBusy(id)
+ }
+}
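
The suite above drives ExecutorAllocationManager through the `spark.dynamicAllocation.*` settings; both a minimum and a maximum must be supplied, and executors are requested in exponentially growing batches (1, 2, 4, ...) until the maximum is reached. A hedged configuration sketch using the same keys (values are illustrative; the timeouts are interpreted as seconds here):

```scala
import org.apache.spark.SparkConf

// Illustrative dynamic allocation settings; both min and max executors are required.
object DynamicAllocationConfSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setAppName("dynamic-allocation-example")
      .set("spark.dynamicAllocation.enabled", "true")
      .set("spark.dynamicAllocation.minExecutors", "2")
      .set("spark.dynamicAllocation.maxExecutors", "20")
      .set("spark.dynamicAllocation.schedulerBacklogTimeout", "5")
      .set("spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", "5")
      .set("spark.dynamicAllocation.executorIdleTimeout", "600")
    conf.getAll.filter(_._1.startsWith("spark.dynamicAllocation")).foreach(println)
  }
}
```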
diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
new file mode 100644
index 0000000000000..792b9cd8b6ff2
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.{ExternalShuffleBlockHandler, ExternalShuffleClient}
+
+/**
+ * This suite creates an external shuffle server and routes all shuffle fetches through it.
+ * Note that failures in this suite may arise due to changes in Spark that invalidate expectations
+ * set up in [[ExternalShuffleBlockHandler]], such as changing the format of shuffle files or how
+ * we hash files into folders.
+ */
+class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll {
+ var server: TransportServer = _
+ var rpcHandler: ExternalShuffleBlockHandler = _
+
+ override def beforeAll() {
+ val transportConf = SparkTransportConf.fromSparkConf(conf)
+ rpcHandler = new ExternalShuffleBlockHandler()
+ val transportContext = new TransportContext(transportConf, rpcHandler)
+ server = transportContext.createServer()
+
+ conf.set("spark.shuffle.manager", "sort")
+ conf.set("spark.shuffle.service.enabled", "true")
+ conf.set("spark.shuffle.service.port", server.getPort.toString)
+ }
+
+ override def afterAll() {
+ server.close()
+ }
+
+ // This test ensures that the external shuffle service is actually in use for the other tests.
+ test("using external shuffle service") {
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
+ sc.env.blockManager.externalShuffleServiceEnabled should equal(true)
+ sc.env.blockManager.shuffleClient.getClass should equal(classOf[ExternalShuffleClient])
+
+ val rdd = sc.parallelize(0 until 1000, 10).map(i => (i, 1)).reduceByKey(_ + _)
+
+ rdd.count()
+ rdd.count()
+
+ // Invalidate the registered executors, disallowing access to their shuffle blocks.
+ rpcHandler.clearRegisteredExecutors()
+
+ // Now Spark will receive FetchFailed, and not retry the stage due to "spark.test.noStageRetry"
+ // being set.
+ val e = intercept[SparkException] {
+ rdd.count()
+ }
+ e.getMessage should include ("Fetch failure will not retry stage due to testing config")
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
index a8867020e457d..379c2a6ea4b55 100644
--- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import java.io._
import java.util.jar.{JarEntry, JarOutputStream}
+import com.google.common.io.ByteStreams
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
@@ -58,12 +59,7 @@ class FileServerSuite extends FunSuite with LocalSparkContext {
jar.putNextEntry(jarEntry)
val in = new FileInputStream(textFile)
- val buffer = new Array[Byte](10240)
- var nRead = 0
- while (nRead <= 0) {
- nRead = in.read(buffer, 0, buffer.length)
- jar.write(buffer, 0, nRead)
- }
+ ByteStreams.copy(in, jar)
in.close()
jar.close()
diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala
index a2b74c4419d46..5e24196101fbc 100644
--- a/core/src/test/scala/org/apache/spark/FileSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -19,6 +19,9 @@ package org.apache.spark
import java.io.{File, FileWriter}
+import org.apache.spark.input.PortableDataStream
+import org.apache.spark.storage.StorageLevel
+
import scala.io.Source
import org.apache.hadoop.io._
@@ -224,6 +227,187 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(output.map(_.toString).collect().toList === List("(1,a)", "(2,aa)", "(3,aaa)"))
}
+ test("binary file input as byte array") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName)
+ val (infile: String, indata: PortableDataStream) = inRdd.collect.head
+
+ // Make sure the name and array match
+ assert(infile.contains(outFileName)) // a prefix may get added
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream caching tests") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName).cache()
+ inRdd.foreach{
+ curData: (String, PortableDataStream) =>
+ curData._2.toArray() // force the file to be read
+ }
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head
+
+ // Try reading the output back as an object file
+
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream persist disk storage") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName).persist(StorageLevel.DISK_ONLY)
+ inRdd.foreach{
+ curData: (String, PortableDataStream) =>
+ curData._2.toArray() // force the file to be read
+ }
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val (infile: String, indata: PortableDataStream) = mappedRdd.collect.head
+
+ // Try reading the output back as an object file
+
+ assert(indata.toArray === testOutput)
+ }
+
+ test("portabledatastream flatmap tests") {
+ sc = new SparkContext("local", "test")
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val numOfCopies = 3
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ channel.write(bbuf)
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryFiles(outFileName)
+ val mappedRdd = inRdd.map {
+ curData: (String, PortableDataStream) =>
+ (curData._2.getPath(), curData._2)
+ }
+ val copyRdd = mappedRdd.flatMap {
+ curData: (String, PortableDataStream) =>
+ for(i <- 1 to numOfCopies) yield (i, curData._2)
+ }
+
+ val copyArr: Array[(Int, PortableDataStream)] = copyRdd.collect()
+
+ // Try reading the output back as an object file
+ assert(copyArr.length == numOfCopies)
+ copyArr.foreach{
+ cEntry: (Int, PortableDataStream) =>
+ assert(cEntry._2.toArray === testOutput)
+ }
+
+ }
+
+ test("fixed record length binary file as byte array") {
+ // a fixed length of 6 bytes
+
+ sc = new SparkContext("local", "test")
+
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val testOutputCopies = 10
+
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ for(i <- 1 to testOutputCopies) {
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ channel.write(bbuf)
+ }
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryRecords(outFileName, testOutput.length)
+ // make sure there are enough elements
+ assert(inRdd.count == testOutputCopies)
+
+ // now just compare the first one
+ val indata: Array[Byte] = inRdd.collect.head
+ assert(indata === testOutput)
+ }
+
+ test ("negative binary record length should raise an exception") {
+ // a fixed length of 6 bytes
+ sc = new SparkContext("local", "test")
+
+ val outFile = new File(tempDir, "record-bytestream-00000.bin")
+ val outFileName = outFile.getAbsolutePath()
+
+ // create file
+ val testOutput = Array[Byte](1, 2, 3, 4, 5, 6)
+ val testOutputCopies = 10
+
+ // write data to file
+ val file = new java.io.FileOutputStream(outFile)
+ val channel = file.getChannel
+ for(i <- 1 to testOutputCopies) {
+ val bbuf = java.nio.ByteBuffer.wrap(testOutput)
+ channel.write(bbuf)
+ }
+ channel.close()
+ file.close()
+
+ val inRdd = sc.binaryRecords(outFileName, -1)
+
+ intercept[SparkException] {
+ inRdd.count
+ }
+ }
+
test("file caching") {
sc = new SparkContext("local", "test")
val out = new FileWriter(tempDir + "/input")
diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
index 2acc02a54fa3d..19180e88ebe0a 100644
--- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala
@@ -24,10 +24,6 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with hash-based shuffle.
override def beforeAll() {
- System.setProperty("spark.shuffle.manager", "hash")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.manager")
+ conf.set("spark.shuffle.manager", "hash")
}
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index cbc0bd178d894..d27880f4bc32f 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.BlockManagerId
import org.apache.spark.util.AkkaUtils
-class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
+class MapOutputTrackerSuite extends FunSuite {
private val conf = new SparkConf
test("master start and stop") {
@@ -37,6 +37,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.trackerActor =
actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf)))
tracker.stop()
+ actorSystem.shutdown()
}
test("master register shuffle and fetch") {
@@ -56,6 +57,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000),
(BlockManagerId("b", "hostB", 1000), size10000)))
tracker.stop()
+ actorSystem.shutdown()
}
test("master register and unregister shuffle") {
@@ -74,6 +76,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
tracker.unregisterShuffle(10)
assert(!tracker.containsShuffle(10))
assert(tracker.getServerStatuses(10, 0).isEmpty)
+
+ tracker.stop()
+ actorSystem.shutdown()
}
test("master register shuffle and unregister map output and fetch") {
@@ -97,6 +102,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// this should cause it to fail, and the scheduler will ignore the failure due to the
// stage already being aborted.
intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) }
+
+ tracker.stop()
+ actorSystem.shutdown()
}
test("remote fetch") {
@@ -136,6 +144,11 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
// failure should be cached
intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+
+ masterTracker.stop()
+ slaveTracker.stop()
+ actorSystem.shutdown()
+ slaveSystem.shutdown()
}
test("remote fetch below akka frame size") {
@@ -154,6 +167,9 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
masterTracker.registerMapOutput(10, 0, MapStatus(
BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0)))
masterActor.receive(GetMapOutputStatuses(10))
+
+// masterTracker.stop() // this throws an exception
+ actorSystem.shutdown()
}
test("remote fetch exceeds akka frame size") {
@@ -176,5 +192,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0)))
}
intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) }
+
+// masterTracker.stop() // this throws an exception
+ actorSystem.shutdown()
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
index d7b2d2e1e330f..d78c99c2e1e06 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala
@@ -24,10 +24,6 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with Netty shuffle mode.
override def beforeAll() {
- System.setProperty("spark.shuffle.use.netty", "true")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.use.netty")
+ conf.set("spark.shuffle.blockTransferService", "netty")
}
}
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 15aa4d83800fa..cda942e15a704 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -30,10 +30,14 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
val conf = new SparkConf(loadDefaults = false)
+ // Ensure that the DAGScheduler doesn't retry stages whose fetches fail, so that we accurately
+ // test that the shuffle works (rather than retrying until all blocks are local to one Executor).
+ conf.set("spark.test.noStageRetry", "true")
+
test("groupByKey without compression") {
try {
System.setProperty("spark.shuffle.compress", "false")
- sc = new SparkContext("local", "test")
+ sc = new SparkContext("local", "test", conf)
val pairs = sc.parallelize(Array((1, 1), (1, 2), (1, 3), (2, 1)), 4)
val groups = pairs.groupByKey(4).collect()
assert(groups.size === 2)
@@ -47,7 +51,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
}
test("shuffle non-zero block size") {
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val NUM_BLOCKS = 3
val a = sc.parallelize(1 to 10, 2)
@@ -73,7 +77,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("shuffle serializer") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(x, new NonJavaSerializableClass(x * 2))
@@ -89,7 +93,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("zero sized blocks") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
// 10 partitions from 4 keys
val NUM_BLOCKS = 10
@@ -116,7 +120,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("zero sized blocks without kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
// 10 partitions from 4 keys
val NUM_BLOCKS = 10
@@ -141,7 +145,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("shuffle on mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -154,7 +158,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sorting on mutable pairs") {
// This is not in SortingSuite because of the local cluster setup.
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data = Array(p(1, 11), p(3, 33), p(100, 100), p(2, 22))
val pairs: RDD[MutablePair[Int, Int]] = sc.parallelize(data, 2)
@@ -168,7 +172,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("cogroup using mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"), p(3, "3"))
@@ -195,7 +199,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("subtract mutable pairs") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- sc = new SparkContext("local-cluster[2,1,512]", "test")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
def p[T1, T2](_1: T1, _2: T2) = MutablePair(_1, _2)
val data1 = Seq(p(1, 1), p(1, 2), p(1, 3), p(2, 1), p(3, 33))
val data2 = Seq(p(1, "11"), p(1, "12"), p(2, "22"))
@@ -209,11 +213,8 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sort with Java non serializable class - Kryo") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .setAppName("test")
- .setMaster("local-cluster[2,1,512]")
- sc = new SparkContext(conf)
+ val myConf = conf.clone().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
+ sc = new SparkContext("local-cluster[2,1,512]", "test", myConf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
@@ -226,10 +227,7 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
test("sort with Java non serializable class - Java") {
// Use a local cluster with 2 processes to make sure there are both local and remote blocks
- val conf = new SparkConf()
- .setAppName("test")
- .setMaster("local-cluster[2,1,512]")
- sc = new SparkContext(conf)
+ sc = new SparkContext("local-cluster[2,1,512]", "test", conf)
val a = sc.parallelize(1 to 10, 2)
val b = a.map { x =>
(new NonJavaSerializableClass(x), x)
@@ -242,6 +240,30 @@ abstract class ShuffleSuite extends FunSuite with Matchers with LocalSparkContex
assert(thrown.getClass === classOf[SparkException])
assert(thrown.getMessage.toLowerCase.contains("serializable"))
}
+
+ test("shuffle with different compression settings (SPARK-3426)") {
+ for (
+ shuffleSpillCompress <- Set(true, false);
+ shuffleCompress <- Set(true, false)
+ ) {
+ val conf = new SparkConf()
+ .setAppName("test")
+ .setMaster("local")
+ .set("spark.shuffle.spill.compress", shuffleSpillCompress.toString)
+ .set("spark.shuffle.compress", shuffleCompress.toString)
+ .set("spark.shuffle.memoryFraction", "0.001")
+ resetSparkContext()
+ sc = new SparkContext(conf)
+ try {
+ sc.parallelize(0 until 100000).map(i => (i / 4, i)).groupByKey().collect()
+ } catch {
+ case e: Exception =>
+ val errMsg = s"Failed with spark.shuffle.spill.compress=$shuffleSpillCompress," +
+ s" spark.shuffle.compress=$shuffleCompress"
+ throw new Exception(errMsg, e)
+ }
+ }
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 639e56c488db4..63358172ea1f4 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -24,10 +24,6 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
// This test suite should run all tests in ShuffleSuite with sort-based shuffle.
override def beforeAll() {
- System.setProperty("spark.shuffle.manager", "sort")
- }
-
- override def afterAll() {
- System.clearProperty("spark.shuffle.manager")
+ conf.set("spark.shuffle.manager", "sort")
}
}
diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
index 87e9012622456..5d018ea9868a7 100644
--- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala
@@ -18,6 +18,8 @@
package org.apache.spark
import org.scalatest.FunSuite
+import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer}
+import com.esotericsoftware.kryo.Kryo
class SparkConfSuite extends FunSuite with LocalSparkContext {
test("loading from system properties") {
@@ -133,4 +135,64 @@ class SparkConfSuite extends FunSuite with LocalSparkContext {
System.clearProperty("spark.test.a.b.c")
}
}
+
+ test("register kryo classes through registerKryoClasses") {
+ val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+ conf.registerKryoClasses(Array(classOf[Class1], classOf[Class2]))
+ assert(conf.get("spark.kryo.classesToRegister") ===
+ classOf[Class1].getName + "," + classOf[Class2].getName)
+
+ conf.registerKryoClasses(Array(classOf[Class3]))
+ assert(conf.get("spark.kryo.classesToRegister") ===
+ classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+ conf.registerKryoClasses(Array(classOf[Class2]))
+ assert(conf.get("spark.kryo.classesToRegister") ===
+ classOf[Class1].getName + "," + classOf[Class2].getName + "," + classOf[Class3].getName)
+
+ // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+ // blow up.
+ val serializer = new KryoSerializer(conf)
+ serializer.newInstance().serialize(new Class1())
+ serializer.newInstance().serialize(new Class2())
+ serializer.newInstance().serialize(new Class3())
+ }
+
+ test("register kryo classes through registerKryoClasses and custom registrator") {
+ val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+
+ conf.registerKryoClasses(Array(classOf[Class1]))
+ assert(conf.get("spark.kryo.classesToRegister") === classOf[Class1].getName)
+
+ conf.set("spark.kryo.registrator", classOf[CustomRegistrator].getName)
+
+ // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+ // blow up.
+ val serializer = new KryoSerializer(conf)
+ serializer.newInstance().serialize(new Class1())
+ serializer.newInstance().serialize(new Class2())
+ }
+
+ test("register kryo classes through conf") {
+ val conf = new SparkConf().set("spark.kryo.registrationRequired", "true")
+ conf.set("spark.kryo.classesToRegister", "java.lang.StringBuffer")
+ conf.set("spark.serializer", classOf[KryoSerializer].getName)
+
+ // Kryo doesn't expose a way to discover registered classes, but at least make sure this doesn't
+ // blow up.
+ val serializer = new KryoSerializer(conf)
+ serializer.newInstance().serialize(new StringBuffer())
+ }
+
+}
+
+class Class1 {}
+class Class2 {}
+class Class3 {}
+
+class CustomRegistrator extends KryoRegistrator {
+ def registerClasses(kryo: Kryo) {
+ kryo.register(classOf[Class2])
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
index 495a0d48633a4..0390a2e4f1dbb 100644
--- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -17,22 +17,23 @@
package org.apache.spark
-import org.scalatest.{BeforeAndAfterEach, FunSuite, PrivateMethodTester}
+import org.scalatest.{FunSuite, PrivateMethodTester}
-import org.apache.spark.scheduler.{TaskScheduler, TaskSchedulerImpl}
+import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl}
import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend}
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
class SparkContextSchedulerCreationSuite
- extends FunSuite with PrivateMethodTester with Logging with BeforeAndAfterEach {
+ extends FunSuite with LocalSparkContext with PrivateMethodTester with Logging {
def createTaskScheduler(master: String): TaskSchedulerImpl = {
// Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
// real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
- val sc = new SparkContext("local", "test")
- val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
- val sched = SparkContext invokePrivate createTaskSchedulerMethod(sc, master)
+ sc = new SparkContext("local", "test")
+ val createTaskSchedulerMethod =
+ PrivateMethod[Tuple2[SchedulerBackend, TaskScheduler]]('createTaskScheduler)
+ val (_, sched) = SparkContext invokePrivate createTaskSchedulerMethod(sc, master)
sched.asInstanceOf[TaskSchedulerImpl]
}
diff --git a/core/src/test/scala/org/apache/spark/StatusAPISuite.scala b/core/src/test/scala/org/apache/spark/StatusAPISuite.scala
new file mode 100644
index 0000000000000..4468fba8c1dff
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/StatusAPISuite.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import scala.concurrent.duration._
+import scala.language.implicitConversions
+import scala.language.postfixOps
+
+import org.scalatest.{Matchers, FunSuite}
+import org.scalatest.concurrent.Eventually._
+
+import org.apache.spark.JobExecutionStatus._
+import org.apache.spark.SparkContext._
+
+class StatusAPISuite extends FunSuite with Matchers with SharedSparkContext {
+
+ test("basic status API usage") {
+ val jobFuture = sc.parallelize(1 to 10000, 2).map(identity).groupBy(identity).collectAsync()
+ val jobId: Int = eventually(timeout(10 seconds)) {
+ val jobIds = jobFuture.jobIds
+ jobIds.size should be(1)
+ jobIds.head
+ }
+ val jobInfo = eventually(timeout(10 seconds)) {
+ sc.getJobInfo(jobId).get
+ }
+ jobInfo.status() should not be FAILED
+ val stageIds = jobInfo.stageIds()
+ stageIds.size should be(2)
+
+ val firstStageInfo = eventually(timeout(10 seconds)) {
+ sc.getStageInfo(stageIds(0)).get
+ }
+ firstStageInfo.stageId() should be(stageIds(0))
+ firstStageInfo.currentAttemptId() should be(0)
+ firstStageInfo.numTasks() should be(2)
+ eventually(timeout(10 seconds)) {
+ val updatedFirstStageInfo = sc.getStageInfo(stageIds(0)).get
+ updatedFirstStageInfo.numCompletedTasks() should be(2)
+ updatedFirstStageInfo.numActiveTasks() should be(0)
+ updatedFirstStageInfo.numFailedTasks() should be(0)
+ }
+ }
+
+ test("getJobIdsForGroup()") {
+ sc.setJobGroup("my-job-group", "description")
+ sc.getJobIdsForGroup("my-job-group") should be (Seq.empty)
+ val firstJobFuture = sc.parallelize(1 to 1000).countAsync()
+ val firstJobId = eventually(timeout(10 seconds)) {
+ firstJobFuture.jobIds.head
+ }
+ eventually(timeout(10 seconds)) {
+ sc.getJobIdsForGroup("my-job-group") should be (Seq(firstJobId))
+ }
+ val secondJobFuture = sc.parallelize(1 to 1000).countAsync()
+ val secondJobId = eventually(timeout(10 seconds)) {
+ secondJobFuture.jobIds.head
+ }
+ eventually(timeout(10 seconds)) {
+ sc.getJobIdsForGroup("my-job-group").toSet should be (Set(firstJobId, secondJobId))
+ }
+ }
+}
\ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index acaf321de52fb..b0a70f012f1f3 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,11 +17,31 @@
package org.apache.spark.broadcast
-import org.scalatest.FunSuite
+import scala.util.Random
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.scalatest.{Assertions, FunSuite}
+
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkEnv}
+import org.apache.spark.io.SnappyCompressionCodec
+import org.apache.spark.rdd.RDD
+import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage._
+// Dummy class that creates a broadcast variable but doesn't use it
+class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable {
+ @transient val list = List(1, 2, 3, 4)
+ val broadcast = rdd.context.broadcast(list)
+ val bid = broadcast.id
+
+ def doSomething() = {
+ rdd.map { x =>
+ val bm = SparkEnv.get.blockManager
+ // Check if broadcast block was fetched
+ val isFound = bm.getLocal(BroadcastBlockId(bid)).isDefined
+ (x, isFound)
+ }.collect().toSet
+ }
+}
class BroadcastSuite extends FunSuite with LocalSparkContext {
@@ -84,6 +104,35 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet)
}
+ test("TorrentBroadcast's blockifyObject and unblockifyObject are inverses") {
+ import org.apache.spark.broadcast.TorrentBroadcast._
+ val blockSize = 1024
+ val conf = new SparkConf()
+ val compressionCodec = Some(new SnappyCompressionCodec(conf))
+ val serializer = new JavaSerializer(conf)
+ val seed = 42
+ val rand = new Random(seed)
+ for (trial <- 1 to 100) {
+ val size = 1 + rand.nextInt(1024 * 10)
+ val data: Array[Byte] = new Array[Byte](size)
+ rand.nextBytes(data)
+ val blocks = blockifyObject(data, blockSize, serializer, compressionCodec)
+ val unblockified = unBlockifyObject[Array[Byte]](blocks, serializer, compressionCodec)
+ assert(unblockified === data)
+ }
+ }
+
+ test("Test Lazy Broadcast variables with TorrentBroadcast") {
+ val numSlaves = 2
+ val conf = torrentConf.clone
+ sc = new SparkContext("local-cluster[%d, 1, 512]".format(numSlaves), "test", conf)
+ val rdd = sc.parallelize(1 to numSlaves)
+
+ val results = new DummyBroadcastClass(rdd).doSomething()
+
+ assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet)
+ }
+
test("Unpersisting HttpBroadcast on executors only in local mode") {
testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false)
}
@@ -115,6 +164,12 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
test("Unpersisting TorrentBroadcast on executors and driver in distributed mode") {
testUnpersistTorrentBroadcast(distributed = true, removeFromDriver = true)
}
+
+ test("Using broadcast after destroy prints callsite") {
+ sc = new SparkContext("local", "test")
+ testPackage.runCallSiteTest(sc)
+ }
+
/**
* Verify the persistence of state associated with an HttpBroadcast in either local mode or
* local-cluster mode (when distributed = true).
@@ -193,26 +248,17 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
blockId = BroadcastBlockId(broadcastId, "piece0")
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- assert(statuses.size === (if (distributed) 1 else 0))
+ assert(statuses.size === 1)
}
// Verify that blocks are persisted in both the executors and the driver
def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
var blockId = BroadcastBlockId(broadcastId)
- var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (distributed) {
- assert(statuses.size === numSlaves + 1)
- } else {
- assert(statuses.size === 1)
- }
+ val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+ assert(statuses.size === numSlaves + 1)
blockId = BroadcastBlockId(broadcastId, "piece0")
- statuses = bmm.getBlockStatus(blockId, askSlaves = true)
- if (distributed) {
- assert(statuses.size === numSlaves + 1)
- } else {
- assert(statuses.size === 0)
- }
+ assert(statuses.size === numSlaves + 1)
}
// Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
@@ -224,7 +270,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
assert(statuses.size === expectedNumBlocks)
blockId = BroadcastBlockId(broadcastId, "piece0")
- expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+ expectedNumBlocks = if (removeFromDriver) 0 else 1
statuses = bmm.getBlockStatus(blockId, askSlaves = true)
assert(statuses.size === expectedNumBlocks)
}
@@ -299,3 +345,15 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
conf
}
}
+
+package object testPackage extends Assertions {
+
+ def runCallSiteTest(sc: SparkContext) {
+ val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2)
+ val broadcast = sc.broadcast(rdd)
+ broadcast.destroy()
+ val thrown = intercept[SparkException] { broadcast.value }
+ assert(thrown.getMessage.contains("BroadcastSuite.scala"))
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
index 4161aede1d1d0..94a2bdd74e744 100644
--- a/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/ClientSuite.scala
@@ -29,6 +29,12 @@ class ClientSuite extends FunSuite with Matchers {
ClientArguments.isValidJarUrl("hdfs://someHost:1234/foo") should be (false)
ClientArguments.isValidJarUrl("/missing/a/protocol/jarfile.jar") should be (false)
ClientArguments.isValidJarUrl("not-even-a-path.jar") should be (false)
+
+ // No authority
+ ClientArguments.isValidJarUrl("hdfs:someHost:1234/jarfile.jar") should be (false)
+
+ // Invalid syntax
+ ClientArguments.isValidJarUrl("hdfs:") should be (false)
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala
new file mode 100644
index 0000000000000..7915ee75d8778
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/CommandUtilsSuite.scala
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy
+
+import org.apache.spark.deploy.worker.CommandUtils
+import org.apache.spark.util.Utils
+
+import org.scalatest.{FunSuite, Matchers}
+
+class CommandUtilsSuite extends FunSuite with Matchers {
+
+ test("set libraryPath correctly") {
+ val appId = "12345-worker321-9876"
+ val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
+ val cmd = new Command("mainClass", Seq(), Map(), Seq(), Seq("libraryPathToB"), Seq())
+ val builder = CommandUtils.buildProcessBuilder(cmd, 512, sparkHome, t => t)
+ val libraryPath = Utils.libraryPathEnvName
+ val env = builder.environment
+ env.keySet should contain(libraryPath)
+ assert(env.get(libraryPath).startsWith("libraryPathToB"))
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
index 1cdf50d5c08c7..d8cd0ff2c9026 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala
@@ -292,7 +292,7 @@ class SparkSubmitSuite extends FunSuite with Matchers {
runSparkSubmit(args)
}
- test("spark submit includes jars passed in through --jar") {
+ test("includes jars passed in through --jars") {
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA"))
val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB"))
@@ -306,6 +306,110 @@ class SparkSubmitSuite extends FunSuite with Matchers {
runSparkSubmit(args)
}
+ test("resolves command line argument paths correctly") {
+ val jars = "/jar1,/jar2" // --jars
+ val files = "hdfs:/file1,file2" // --files
+ val archives = "file:/archive1,archive2" // --archives
+ val pyFiles = "py-file1,py-file2" // --py-files
+
+ // Test jars and files
+ val clArgs = Seq(
+ "--master", "local",
+ "--class", "org.SomeClass",
+ "--jars", jars,
+ "--files", files,
+ "thejar.jar")
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3
+ appArgs.jars should be (Utils.resolveURIs(jars))
+ appArgs.files should be (Utils.resolveURIs(files))
+ sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar"))
+ sysProps("spark.files") should be (Utils.resolveURIs(files))
+
+ // Test files and archives (Yarn)
+ val clArgs2 = Seq(
+ "--master", "yarn-client",
+ "--class", "org.SomeClass",
+ "--files", files,
+ "--archives", archives,
+ "thejar.jar"
+ )
+ val appArgs2 = new SparkSubmitArguments(clArgs2)
+ val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3
+ appArgs2.files should be (Utils.resolveURIs(files))
+ appArgs2.archives should be (Utils.resolveURIs(archives))
+ sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files))
+ sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives))
+
+ // Test python files
+ val clArgs3 = Seq(
+ "--master", "local",
+ "--py-files", pyFiles,
+ "mister.py"
+ )
+ val appArgs3 = new SparkSubmitArguments(clArgs3)
+ val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3
+ appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles))
+ sysProps3("spark.submit.pyFiles") should be (
+ PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
+ }
+
+ test("resolves config paths correctly") {
+ val jars = "/jar1,/jar2" // spark.jars
+ val files = "hdfs:/file1,file2" // spark.files / spark.yarn.dist.files
+ val archives = "file:/archive1,archive2" // spark.yarn.dist.archives
+ val pyFiles = "py-file1,py-file2" // spark.submit.pyFiles
+
+ // Test jars and files
+ val f1 = File.createTempFile("test-submit-jars-files", "")
+ val writer1 = new PrintWriter(f1)
+ writer1.println("spark.jars " + jars)
+ writer1.println("spark.files " + files)
+ writer1.close()
+ val clArgs = Seq(
+ "--master", "local",
+ "--class", "org.SomeClass",
+ "--properties-file", f1.getPath,
+ "thejar.jar"
+ )
+ val appArgs = new SparkSubmitArguments(clArgs)
+ val sysProps = SparkSubmit.createLaunchEnv(appArgs)._3
+ sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar"))
+ sysProps("spark.files") should be(Utils.resolveURIs(files))
+
+ // Test files and archives (Yarn)
+ val f2 = File.createTempFile("test-submit-files-archives", "")
+ val writer2 = new PrintWriter(f2)
+ writer2.println("spark.yarn.dist.files " + files)
+ writer2.println("spark.yarn.dist.archives " + archives)
+ writer2.close()
+ val clArgs2 = Seq(
+ "--master", "yarn-client",
+ "--class", "org.SomeClass",
+ "--properties-file", f2.getPath,
+ "thejar.jar"
+ )
+ val appArgs2 = new SparkSubmitArguments(clArgs2)
+ val sysProps2 = SparkSubmit.createLaunchEnv(appArgs2)._3
+ sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files))
+ sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives))
+
+ // Test python files
+ val f3 = File.createTempFile("test-submit-python-files", "")
+ val writer3 = new PrintWriter(f3)
+ writer3.println("spark.submit.pyFiles " + pyFiles)
+ writer3.close()
+ val clArgs3 = Seq(
+ "--master", "local",
+ "--properties-file", f3.getPath,
+ "mister.py"
+ )
+ val appArgs3 = new SparkSubmitArguments(clArgs3)
+ val sysProps3 = SparkSubmit.createLaunchEnv(appArgs3)._3
+ sysProps3("spark.submit.pyFiles") should be(
+ PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(","))
+ }
+
test("SPARK_CONF_DIR overrides spark-defaults.conf") {
forConfDir(Map("spark.executor.memory" -> "2.3g")) { path =>
val unusedJar = TestUtils.createJarWithClasses(Seq.empty)
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 5e2592e8d2e8d..196217062991e 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.deploy.worker
import java.io.File
+import scala.collection.JavaConversions._
+
import org.scalatest.FunSuite
import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState}
@@ -32,6 +34,7 @@ class ExecutorRunnerTest extends FunSuite {
Command("foo", Seq(appId), Map(), Seq(), Seq(), Seq()), "appUiUrl")
val er = new ExecutorRunner(appId, 1, appDesc, 8, 500, null, "blah", "worker321",
new File(sparkHome), new File("ooga"), "blah", new SparkConf, ExecutorState.RUNNING)
- assert(er.getCommandSeq.last === appId)
+ val builder = CommandUtils.buildProcessBuilder(appDesc.command, 512, sparkHome, er.substituteVariables)
+ assert(builder.command().last === appId)
}
}
diff --git a/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
new file mode 100644
index 0000000000000..48c386ba04311
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/metrics/InputMetricsSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SharedSparkContext
+import org.apache.spark.scheduler.{SparkListenerTaskEnd, SparkListener}
+
+import scala.collection.mutable.ArrayBuffer
+
+import java.io.{FileWriter, PrintWriter, File}
+
+class InputMetricsSuite extends FunSuite with SharedSparkContext {
+ test("input metrics when reading text file with single split") {
+ val file = new File(getClass.getSimpleName + ".txt")
+ val pw = new PrintWriter(new FileWriter(file))
+ pw.println("some stuff")
+ pw.println("some other stuff")
+ pw.println("yet more stuff")
+ pw.println("too much stuff")
+ pw.close()
+ file.deleteOnExit()
+
+ val taskBytesRead = new ArrayBuffer[Long]()
+ sc.addSparkListener(new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
+ }
+ })
+ sc.textFile("file://" + file.getAbsolutePath, 2).count()
+
+ // Wait for task end events to come in
+ sc.listenerBus.waitUntilEmpty(500)
+ assert(taskBytesRead.length == 2)
+ assert(taskBytesRead.sum >= file.length())
+ }
+
+ test("input metrics when reading text file with multiple splits") {
+ val file = new File(getClass.getSimpleName + ".txt")
+ val pw = new PrintWriter(new FileWriter(file))
+ for (i <- 0 until 10000) {
+ pw.println("some stuff")
+ }
+ pw.close()
+ file.deleteOnExit()
+
+ val taskBytesRead = new ArrayBuffer[Long]()
+ sc.addSparkListener(new SparkListener() {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd) {
+ taskBytesRead += taskEnd.taskMetrics.inputMetrics.get.bytesRead
+ }
+ })
+ sc.textFile("file://" + file.getAbsolutePath, 2).count()
+
+ // Wait for task end events to come in
+ sc.listenerBus.waitUntilEmpty(500)
+ assert(taskBytesRead.length == 2)
+ assert(taskBytesRead.sum >= file.length())
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
index 3925f0ccbdbf0..bbdc9568a6ddb 100644
--- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
+++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala
@@ -121,7 +121,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
}
val appId = "testId"
- val executorId = "executor.1"
+ val executorId = "1"
conf.set("spark.app.id", appId)
conf.set("spark.executor.id", executorId)
@@ -138,7 +138,7 @@ class MetricsSystemSuite extends FunSuite with BeforeAndAfter with PrivateMethod
override val metricRegistry = new MetricRegistry()
}
- val executorId = "executor.1"
+ val executorId = "1"
conf.set("spark.executor.id", executorId)
val instanceName = "executor"
diff --git a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
deleted file mode 100644
index 02d0ffc86f58f..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/ServerClientIntegrationSuite.scala
+++ /dev/null
@@ -1,161 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-import java.util.{Collections, HashSet}
-import java.util.concurrent.{TimeUnit, Semaphore}
-
-import scala.collection.JavaConversions._
-
-import io.netty.buffer.{ByteBufUtil, Unpooled}
-
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
-
-import org.apache.spark.SparkConf
-import org.apache.spark.network.netty.client.{BlockClientListener, ReferenceCountedBuffer, BlockFetchingClientFactory}
-import org.apache.spark.network.netty.server.BlockServer
-import org.apache.spark.storage.{FileSegment, BlockDataProvider}
-
-
-/**
- * Test suite that makes sure the server and the client implementations share the same protocol.
- */
-class ServerClientIntegrationSuite extends FunSuite with BeforeAndAfterAll {
-
- val bufSize = 100000
- var buf: ByteBuffer = _
- var testFile: File = _
- var server: BlockServer = _
- var clientFactory: BlockFetchingClientFactory = _
-
- val bufferBlockId = "buffer_block"
- val fileBlockId = "file_block"
-
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
-
- override def beforeAll() = {
- buf = ByteBuffer.allocate(bufSize)
- for (i <- 1 to bufSize) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- server = new BlockServer(new SparkConf, new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- if (blockId == bufferBlockId) {
- Right(buf)
- } else if (blockId == fileBlockId) {
- Left(new FileSegment(testFile, 10, testFile.length - 25))
- } else {
- throw new Exception("Unknown block id " + blockId)
- }
- }
- })
-
- clientFactory = new BlockFetchingClientFactory(new SparkConf)
- }
-
- override def afterAll() = {
- server.stop()
- clientFactory.stop()
- }
-
- /** A ByteBuf for buffer_block */
- lazy val byteBufferBlockReference = Unpooled.wrappedBuffer(buf)
-
- /** A ByteBuf for file_block */
- lazy val fileBlockReference = Unpooled.wrappedBuffer(fileContent, 10, fileContent.length - 25)
-
- def fetchBlocks(blockIds: Seq[String]): (Set[String], Set[ReferenceCountedBuffer], Set[String]) =
- {
- val client = clientFactory.createClient(server.hostName, server.port)
- val sem = new Semaphore(0)
- val receivedBlockIds = Collections.synchronizedSet(new HashSet[String])
- val errorBlockIds = Collections.synchronizedSet(new HashSet[String])
- val receivedBuffers = Collections.synchronizedSet(new HashSet[ReferenceCountedBuffer])
-
- client.fetchBlocks(
- blockIds,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = {
- errorBlockIds.add(blockId)
- sem.release()
- }
-
- override def onFetchSuccess(blockId: String, data: ReferenceCountedBuffer): Unit = {
- receivedBlockIds.add(blockId)
- data.retain()
- receivedBuffers.add(data)
- sem.release()
- }
- }
- )
- if (!sem.tryAcquire(blockIds.size, 30, TimeUnit.SECONDS)) {
- fail("Timeout getting response from the server")
- }
- client.close()
- (receivedBlockIds.toSet, receivedBuffers.toSet, errorBlockIds.toSet)
- }
-
- test("fetch a ByteBuffer block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a FileSegment block via zero-copy send") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(fileBlockId))
- assert(blockIds === Set(fileBlockId))
- assert(buffers.map(_.underlying) === Set(fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq("random-block"))
- assert(blockIds.isEmpty)
- assert(buffers.isEmpty)
- assert(failBlockIds === Set("random-block"))
- }
-
- test("fetch both ByteBuffer block and FileSegment block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, fileBlockId))
- assert(blockIds === Set(bufferBlockId, fileBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference, fileBlockReference))
- assert(failBlockIds.isEmpty)
- buffers.foreach(_.release())
- }
-
- test("fetch both ByteBuffer block and a non-existent block") {
- val (blockIds, buffers, failBlockIds) = fetchBlocks(Seq(bufferBlockId, "random-block"))
- assert(blockIds === Set(bufferBlockId))
- assert(buffers.map(_.underlying) === Set(byteBufferBlockReference))
- assert(failBlockIds === Set("random-block"))
- buffers.foreach(_.release())
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
deleted file mode 100644
index 903ab09ae4322..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/client/BlockFetchingClientHandlerSuite.scala
+++ /dev/null
@@ -1,105 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.client
-
-import java.nio.ByteBuffer
-
-import io.netty.buffer.Unpooled
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.{PrivateMethodTester, FunSuite}
-
-
-class BlockFetchingClientHandlerSuite extends FunSuite with PrivateMethodTester {
-
- test("handling block data (successful fetch)") {
- val blockId = "test_block"
- val blockData = "blahblahblahblahblah"
- val totalLength = 4 + blockId.length + blockData.length
-
- var parsedBlockId: String = ""
- var parsedBlockData: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId,
- new BlockClientListener {
- override def onFetchFailure(blockId: String, errorMsg: String): Unit = ???
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer): Unit = {
- parsedBlockId = bid
- val bytes = new Array[Byte](refCntBuf.byteBuffer().remaining)
- refCntBuf.byteBuffer().get(bytes)
- parsedBlockData = new String(bytes)
- }
- }
- )
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(blockId.length)
- buf.put(blockId.getBytes)
- buf.put(blockData.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedBlockData === blockData)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-
- test("handling error message (failed fetch)") {
- val blockId = "test_block"
- val errorMsg = "error erro5r error err4or error3 error6 error erro1r"
- val totalLength = 4 + blockId.length + errorMsg.length
-
- var parsedBlockId: String = ""
- var parsedErrorMsg: String = ""
- val handler = new BlockFetchingClientHandler
- handler.addRequest(blockId, new BlockClientListener {
- override def onFetchFailure(bid: String, msg: String) ={
- parsedBlockId = bid
- parsedErrorMsg = msg
- }
- override def onFetchSuccess(bid: String, refCntBuf: ReferenceCountedBuffer) = ???
- })
-
- val outstandingRequests = PrivateMethod[java.util.Map[_, _]]('outstandingRequests)
- assert(handler.invokePrivate(outstandingRequests()).size === 1)
-
- val channel = new EmbeddedChannel(handler)
- val buf = ByteBuffer.allocate(totalLength + 4) // 4 bytes for the length field itself
- buf.putInt(totalLength)
- buf.putInt(-blockId.length)
- buf.put(blockId.getBytes)
- buf.put(errorMsg.getBytes)
- buf.flip()
-
- channel.writeInbound(Unpooled.wrappedBuffer(buf))
- assert(parsedBlockId === blockId)
- assert(parsedErrorMsg === errorMsg)
-
- assert(handler.invokePrivate(outstandingRequests()).size === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
deleted file mode 100644
index 3ee281cb1350b..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockHeaderEncoderSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import io.netty.buffer.ByteBuf
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-
-class BlockHeaderEncoderSuite extends FunSuite {
-
- test("encode normal block data") {
- val blockId = "test_block"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, None))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + 17)
- assert(out.readInt() === blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes) === blockId)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-
- test("encode error message") {
- val blockId = "error_block"
- val errorMsg = "error encountered"
- val channel = new EmbeddedChannel(new BlockHeaderEncoder)
- channel.writeOutbound(new BlockHeader(17, blockId, Some(errorMsg)))
- val out = channel.readOutbound().asInstanceOf[ByteBuf]
- assert(out.readInt() === 4 + blockId.length + errorMsg.length)
- assert(out.readInt() === -blockId.length)
-
- val blockIdBytes = new Array[Byte](blockId.length)
- out.readBytes(blockIdBytes)
- assert(new String(blockIdBytes) === blockId)
-
- val errorMsgBytes = new Array[Byte](errorMsg.length)
- out.readBytes(errorMsgBytes)
- assert(new String(errorMsgBytes) === errorMsg)
- assert(out.readableBytes() === 0)
-
- channel.close()
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
deleted file mode 100644
index 3239c710f1639..0000000000000
--- a/core/src/test/scala/org/apache/spark/network/netty/server/BlockServerHandlerSuite.scala
+++ /dev/null
@@ -1,107 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.network.netty.server
-
-import java.io.{RandomAccessFile, File}
-import java.nio.ByteBuffer
-
-import io.netty.buffer.{Unpooled, ByteBuf}
-import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler, DefaultFileRegion}
-import io.netty.channel.embedded.EmbeddedChannel
-
-import org.scalatest.FunSuite
-
-import org.apache.spark.storage.{BlockDataProvider, FileSegment}
-
-
-class BlockServerHandlerSuite extends FunSuite {
-
- test("ByteBuffer block") {
- val expectedBlockId = "test_bytebuffer_block"
- val buf = ByteBuffer.allocate(10000)
- for (i <- 1 to 10000) {
- buf.put(i.toByte)
- }
- buf.flip()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = Right(buf)
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[ByteBuf]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === buf.remaining)
- assert(out1.error === None)
-
- assert(out2.equals(Unpooled.wrappedBuffer(buf)))
-
- channel.close()
- }
-
- test("FileSegment block via zero-copy") {
- val expectedBlockId = "test_file_block"
-
- // Create random file data
- val fileContent = new Array[Byte](1024)
- scala.util.Random.nextBytes(fileContent)
- val testFile = File.createTempFile("netty-test-file", "txt")
- val fp = new RandomAccessFile(testFile, "rw")
- fp.write(fileContent)
- fp.close()
-
- val channel = new EmbeddedChannel(new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = {
- Left(new FileSegment(testFile, 15, testFile.length - 25))
- }
- }))
-
- channel.writeInbound(expectedBlockId)
- assert(channel.outboundMessages().size === 2)
-
- val out1 = channel.readOutbound().asInstanceOf[BlockHeader]
- val out2 = channel.readOutbound().asInstanceOf[DefaultFileRegion]
-
- assert(out1.blockId === expectedBlockId)
- assert(out1.blockSize === testFile.length - 25)
- assert(out1.error === None)
-
- assert(out2.count === testFile.length - 25)
- assert(out2.position === 15)
- }
-
- test("pipeline exception propagation") {
- val blockServerHandler = new BlockServerHandler(new BlockDataProvider {
- override def getBlockData(blockId: String): Either[FileSegment, ByteBuffer] = ???
- })
- val exceptionHandler = new SimpleChannelInboundHandler[String]() {
- override def channelRead0(ctx: ChannelHandlerContext, msg: String): Unit = {
- throw new Exception("this is an error")
- }
- }
-
- val channel = new EmbeddedChannel(exceptionHandler, blockServerHandler)
- assert(channel.isOpen)
- channel.writeInbound("a message to trigger the error")
- assert(!channel.isOpen)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 465c1a8a43a79..6d2e696dc2fc4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -459,6 +459,11 @@ class RDDSuite extends FunSuite with SharedSparkContext {
for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
}
+ test("collect large number of empty partitions") {
+ // Regression test for SPARK-4019
+ assert(sc.makeRDD(0 until 10, 1000).repartition(2001).collect().toSet === (0 until 10).toSet)
+ }
+
test("take") {
var nums = sc.makeRDD(Range(1, 1000), 1)
assert(nums.take(0).size === 0)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index a2e4f712db55b..819f95634bcdc 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -431,7 +431,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// the 2nd ResultTask failed
complete(taskSets(1), Seq(
(Success, 42),
- (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"), null)))
// this will get called
// blockManagerMaster.removeExecutor("exec-hostA")
// ask the scheduler to try it again
@@ -461,7 +461,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// The first result task fails, with a fetch failure for the output from the first mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0, "ignored"),
null,
Map[Long, Any](),
null,
@@ -472,7 +472,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
// The second ResultTask fails, with a fetch failure for the output from the second mapper.
runEvent(CompletionEvent(
taskSets(1).tasks(0),
- FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1),
+ FetchFailed(makeBlockManagerId("hostA"), shuffleId, 1, 1, "ignored"),
null,
Map[Long, Any](),
null,
@@ -624,7 +624,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
(Success, makeMapStatus("hostC", 1))))
// fail the third stage because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// have DAGScheduler try again
@@ -655,7 +655,7 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
(Success, makeMapStatus("hostB", 1))))
// pretend stage 0 failed because hostA went down
complete(taskSets(2), Seq(
- (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+ (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0, "ignored"), null)))
// TODO assert this:
// blockManagerMaster.removeExecutor("exec-hostA")
// DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
index 79e04f046e4c4..950c6dc58e332 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala
@@ -23,6 +23,7 @@ import org.scalatest.FunSuite
import org.apache.spark.SparkConf
import org.apache.spark.serializer.JavaSerializer
+import scala.util.Random
class MapStatusSuite extends FunSuite {
@@ -46,6 +47,26 @@ class MapStatusSuite extends FunSuite {
}
}
+ test("MapStatus should never report non-empty blocks' sizes as 0") {
+ import Math._
+ for (
+ numSizes <- Seq(1, 10, 100, 1000, 10000);
+ mean <- Seq(0L, 100L, 10000L, Int.MaxValue.toLong);
+ stddev <- Seq(0.0, 0.01, 0.5, 1.0)
+ ) {
+ val sizes = Array.fill[Long](numSizes)(abs(round(Random.nextGaussian() * stddev)) + mean)
+ val status = MapStatus(BlockManagerId("a", "b", 10), sizes)
+ val status1 = compressAndDecompressMapStatus(status)
+ for (i <- 0 until numSizes) {
+ if (sizes(i) != 0) {
+ val failureMessage = s"Failed with $numSizes sizes with mean=$mean, stddev=$stddev"
+ assert(status.getSizeForBlock(i) !== 0, failureMessage)
+ assert(status1.getSizeForBlock(i) !== 0, failureMessage)
+ }
+ }
+ }
+ }
+
test("large tasks should use " + classOf[HighlyCompressedMapStatus].getName) {
val sizes = Array.fill[Long](2001)(150L)
val status = MapStatus(null, sizes)
@@ -56,37 +77,25 @@ class MapStatusSuite extends FunSuite {
assert(status.getSizeForBlock(2000) === 150L)
}
- test(classOf[HighlyCompressedMapStatus].getName + ": estimated size is within 10%") {
- val sizes = Array.tabulate[Long](50) { i => i.toLong }
+ test("HighlyCompressedMapStatus: estimated size should be the average non-empty block size") {
+ val sizes = Array.tabulate[Long](3000) { i => i.toLong }
+ val avg = sizes.sum / sizes.filter(_ != 0).length
val loc = BlockManagerId("a", "b", 10)
val status = MapStatus(loc, sizes)
- val ser = new JavaSerializer(new SparkConf)
- val buf = ser.newInstance().serialize(status)
- val status1 = ser.newInstance().deserialize[MapStatus](buf)
+ val status1 = compressAndDecompressMapStatus(status)
+ assert(status1.isInstanceOf[HighlyCompressedMapStatus])
assert(status1.location == loc)
- for (i <- 0 until sizes.length) {
- // make sure the estimated size is within 10% of the input; note that we skip the very small
- // sizes because the compression is very lossy there.
+ for (i <- 0 until 3000) {
val estimate = status1.getSizeForBlock(i)
- if (estimate > 100) {
- assert(math.abs(estimate - sizes(i)) * 10 <= sizes(i),
- s"incorrect estimated size $estimate, original was ${sizes(i)}")
+ if (sizes(i) > 0) {
+ assert(estimate === avg)
}
}
}
- test(classOf[HighlyCompressedMapStatus].getName + ": estimated size should be the average size") {
- val sizes = Array.tabulate[Long](3000) { i => i.toLong }
- val avg = sizes.sum / sizes.length
- val loc = BlockManagerId("a", "b", 10)
- val status = MapStatus(loc, sizes)
+ def compressAndDecompressMapStatus(status: MapStatus): MapStatus = {
val ser = new JavaSerializer(new SparkConf)
val buf = ser.newInstance().serialize(status)
- val status1 = ser.newInstance().deserialize[MapStatus](buf)
- assert(status1.location == loc)
- for (i <- 0 until 3000) {
- val estimate = status1.getSizeForBlock(i)
- assert(estimate === avg)
- }
+ ser.newInstance().deserialize[MapStatus](buf)
}
}
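
Note: the new compressAndDecompressMapStatus helper round-trips a MapStatus through the JavaSerializer, and the tests pin down two properties of the compressed representation: a non-empty block must never be reported as size 0, and HighlyCompressedMapStatus reports the average non-empty block size. The sketch below illustrates the kind of one-byte log-scale size encoding this implies; the base 1.1 and the clamping constants are assumptions for illustration, not a copy of the internal implementation:

object SizeEncodingSketch {
  // Encode a block size into one byte on a log scale (assumed base 1.1), so that
  // non-zero sizes never round down to zero -- the property the test above checks.
  def compress(size: Long): Byte = {
    if (size == 0) 0
    else if (size <= 1L) 1
    else math.min(255, math.ceil(math.log(size) / math.log(1.1)).toInt).toByte
  }

  def decompress(compressed: Byte): Long = {
    if (compressed == 0) 0 else math.pow(1.1, compressed & 0xFF).toLong
  }

  def main(args: Array[String]): Unit = {
    for (size <- Seq(1L, 100L, 10000L, Int.MaxValue.toLong)) {
      val roundTripped = decompress(compress(size))
      assert(roundTripped > 0, s"non-empty size $size must not decode to 0")
      println(s"$size -> ${compress(size)} -> $roundTripped")
    }
  }
}
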
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index c4e7a4bb7d385..5768a3a733f00 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -40,7 +40,7 @@ class ResultDeletingTaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedule
// Only remove the result once, since we'd like to test the case where the task eventually
// succeeds.
serializer.get().deserialize[TaskResult[_]](serializedData) match {
- case IndirectTaskResult(blockId) =>
+ case IndirectTaskResult(blockId, size) =>
sparkEnv.blockManager.master.removeBlock(blockId)
case directResult: DirectTaskResult[_] =>
taskSetManager.abort("Internal error: expect only indirect results")
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c0b07649eb6dd..1809b5396d53e 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -563,6 +563,31 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(manager.emittedTaskSizeWarning)
}
+ test("abort the job if total size of results is too large") {
+ val conf = new SparkConf().set("spark.driver.maxResultSize", "2m")
+ sc = new SparkContext("local", "test", conf)
+
+ def genBytes(size: Int) = { (x: Int) =>
+ val bytes = Array.ofDim[Byte](size)
+ scala.util.Random.nextBytes(bytes)
+ bytes
+ }
+
+ // multiple 1k results
+ val r = sc.makeRDD(0 until 10, 10).map(genBytes(1024)).collect()
+ assert(10 === r.size )
+
+ // single 10M result
+ val thrown = intercept[SparkException] {sc.makeRDD(genBytes(10 << 20)(0), 1).collect()}
+ assert(thrown.getMessage().contains("bigger than maxResultSize"))
+
+ // multiple 1M results
+ val thrown2 = intercept[SparkException] {
+ sc.makeRDD(0 until 10, 10).map(genBytes(1 << 20)).collect()
+ }
+ assert(thrown2.getMessage().contains("bigger than maxResultSize"))
+ }
+
test("speculative and noPref task should be scheduled after node-local") {
sc = new SparkContext("local", "test")
val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2"), ("execC", "host3"))
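
Note: the new test exercises spark.driver.maxResultSize: once the serialized results collected at the driver exceed the configured cap, the job is aborted with a SparkException whose message mentions maxResultSize. A minimal standalone sketch of the same behaviour outside the test harness; the 1m limit, app name, and payload sizes are arbitrary:

import org.apache.spark.{SparkConf, SparkContext, SparkException}

object MaxResultSizeSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("max-result-size-sketch")
      .set("spark.driver.maxResultSize", "1m") // cap on the total size of results at the driver
    val sc = new SparkContext(conf)
    try {
      // Ten tasks returning ~1 MB each blow past the 1 MB cap, so collect() fails on the driver.
      sc.makeRDD(0 until 10, 10).map { _ =>
        val bytes = new Array[Byte](1 << 20)
        scala.util.Random.nextBytes(bytes)
        bytes
      }.collect()
    } catch {
      case e: SparkException => println("job aborted as expected: " + e.getMessage)
    } finally {
      sc.stop()
    }
  }
}
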
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index e1e35b688d581..a70f67af2e62e 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -201,22 +201,27 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
assert(control.sum === result)
}
- // TODO: this still doesn't work
- ignore("kryo with fold") {
+ test("kryo with fold") {
val control = 1 :: 2 :: Nil
+ // zeroValue must not be a ClassWithoutNoArgConstructor instance because it will be
+ // serialized by spark.closure.serializer but spark.closure.serializer only supports
+ // the default Java serializer.
val result = sc.parallelize(control, 2).map(new ClassWithoutNoArgConstructor(_))
- .fold(new ClassWithoutNoArgConstructor(10))((t1, t2) => new ClassWithoutNoArgConstructor(t1.x + t2.x)).x
- assert(10 + control.sum === result)
+ .fold(null)((t1, t2) => {
+ val t1x = if (t1 == null) 0 else t1.x
+ new ClassWithoutNoArgConstructor(t1x + t2.x)
+ }).x
+ assert(control.sum === result)
}
test("kryo with nonexistent custom registrator should fail") {
- import org.apache.spark.{SparkConf, SparkException}
+ import org.apache.spark.SparkException
val conf = new SparkConf(false)
conf.set("spark.kryo.registrator", "this.class.does.not.exist")
-
+
val thrown = intercept[SparkException](new KryoSerializer(conf).newInstance())
- assert(thrown.getMessage.contains("Failed to invoke this.class.does.not.exist"))
+ assert(thrown.getMessage.contains("Failed to register classes with Kryo"))
}
test("default class loader can be set by a different thread") {
diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
new file mode 100644
index 0000000000000..0ade1bab18d7e
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import java.io.{EOFException, OutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import scala.reflect.ClassTag
+
+
+/**
+ * A serializer implementation that always returns a single element in a deserialization stream.
+ */
+class TestSerializer extends Serializer {
+ override def newInstance() = new TestSerializerInstance
+}
+
+
+class TestSerializerInstance extends SerializerInstance {
+ override def serialize[T: ClassTag](t: T): ByteBuffer = ???
+
+ override def serializeStream(s: OutputStream): SerializationStream = ???
+
+ override def deserializeStream(s: InputStream) = new TestDeserializationStream
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T = ???
+
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = ???
+}
+
+
+class TestDeserializationStream extends DeserializationStream {
+
+ private var count = 0
+
+ override def readObject[T: ClassTag](): T = {
+ count += 1
+ if (count == 2) {
+ throw new EOFException
+ }
+ new Object().asInstanceOf[T]
+ }
+
+ override def close(): Unit = {}
+}
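
Note: a quick usage sketch for the new stub (assuming the test classes above are on the classpath): the first readObject succeeds and the second signals end of stream, which is why consumers that read until EOFException see exactly one element per stream:

import java.io.{ByteArrayInputStream, EOFException}

import org.apache.spark.serializer.TestSerializer

object TestSerializerSketch {
  def main(args: Array[String]): Unit = {
    val stream = new TestSerializer().newInstance()
      .deserializeStream(new ByteArrayInputStream(new Array[Byte](0)))
    println(stream.readObject[AnyRef]())   // the single element
    try {
      stream.readObject[AnyRef]()
    } catch {
      case _: EOFException => println("end of stream after one element")
    }
    stream.close()
  }
}
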
diff --git a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
index ba47fe5e25b9b..6790388f96603 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala
@@ -25,7 +25,7 @@ import org.scalatest.FunSuite
import org.apache.spark.{SparkEnv, SparkContext, LocalSparkContext, SparkConf}
import org.apache.spark.executor.ShuffleWriteMetrics
-import org.apache.spark.network.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.shuffle.FileShuffleBlockManager
import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
@@ -36,9 +36,9 @@ class HashShuffleManagerSuite extends FunSuite with LocalSparkContext {
private def checkSegments(expected: FileSegment, buffer: ManagedBuffer) {
assert(buffer.isInstanceOf[FileSegmentManagedBuffer])
val segment = buffer.asInstanceOf[FileSegmentManagedBuffer]
- assert(expected.file.getCanonicalPath === segment.file.getCanonicalPath)
- assert(expected.offset === segment.offset)
- assert(expected.length === segment.length)
+ assert(expected.file.getCanonicalPath === segment.getFile.getCanonicalPath)
+ assert(expected.offset === segment.getOffset)
+ assert(expected.length === segment.getLength)
}
test("consolidated shuffle can write to shuffle group without messing existing offsets/lengths") {
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 1f1d53a1ee3b0..c6d7105592096 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -27,7 +27,7 @@ import org.mockito.Mockito.{mock, when}
import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
-import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
@@ -57,7 +57,9 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
// Implicitly convert strings to BlockIds for test clarity.
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
- private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = {
+ private def makeBlockManager(
+ maxMem: Long,
+ name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
@@ -108,7 +110,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd
storeIds.filterNot { _ == stores(2).blockManagerId })
// Add driver store and test whether it is filtered out
- val driverStore = makeBlockManager(1000, "")
+ val driverStore = makeBlockManager(1000, SparkContext.DRIVER_IDENTIFIER)
assert(master.getPeers(stores(0).blockManagerId).forall(!_.isDriver))
assert(master.getPeers(stores(1).blockManagerId).forall(!_.isDriver))
assert(master.getPeers(stores(2).blockManagerId).forall(!_.isDriver))
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 9d96202a3e7ac..715b740b857b2 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -37,7 +37,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester}
import org.scalatest.concurrent.Eventually._
import org.scalatest.concurrent.Timeouts._
-import org.apache.spark.{MapOutputTrackerMaster, SecurityManager, SparkConf}
+import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager}
import org.apache.spark.executor.DataReadMethod
import org.apache.spark.network.nio.NioBlockTransferService
import org.apache.spark.scheduler.LiveListenerBus
@@ -69,7 +69,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value)
def rdd(rddId: Int, splitId: Int) = RDDBlockId(rddId, splitId)
- private def makeBlockManager(maxMem: Long, name: String = ""): BlockManager = {
+ private def makeBlockManager(
+ maxMem: Long,
+ name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = {
val transfer = new NioBlockTransferService(conf, securityMgr)
new BlockManager(name, actorSystem, master, serializer, maxMem, conf,
mapOutputTracker, shuffleManager, transfer)
@@ -790,8 +792,8 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfter
test("block store put failure") {
// Use Java serializer so we can create an unserializable error.
val transfer = new NioBlockTransferService(conf, securityMgr)
- store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf,
- mapOutputTracker, shuffleManager, transfer)
+ store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master,
+ new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer)
// The put should fail since a1 is not serializable.
class UnserializableClass
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index a8c049d749015..1eaabb93adbed 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -17,48 +17,75 @@
package org.apache.spark.storage
-import org.apache.spark.{TaskContextImpl, TaskContext}
-import org.apache.spark.network.{BlockFetchingListener, BlockTransferService}
+import java.util.concurrent.Semaphore
+
+import scala.concurrent.future
+import scala.concurrent.ExecutionContext.Implicits.global
-import org.mockito.Mockito._
import org.mockito.Matchers.{any, eq => meq}
+import org.mockito.Mockito._
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer
-
import org.scalatest.FunSuite
+import org.apache.spark.{SparkConf, TaskContextImpl}
+import org.apache.spark.network._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.serializer.TestSerializer
class ShuffleBlockFetcherIteratorSuite extends FunSuite {
+ // Some of the tests are quite tricky because we are testing the cleanup behavior
+ // in the presence of faults.
- test("handle local read failures in BlockManager") {
+ /** Creates a mock [[BlockTransferService]] that returns data from the given map. */
+ private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = {
val transfer = mock(classOf[BlockTransferService])
- val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
-
- val optItr = mock(classOf[Option[Iterator[Any]]])
- val answer = new Answer[Option[Iterator[Any]]] {
- override def answer(invocation: InvocationOnMock) = Option[Iterator[Any]] {
- throw new Exception
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+
+ for (blockId <- blocks) {
+ if (data.contains(BlockId(blockId))) {
+ listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+ } else {
+ listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId))
+ }
+ }
}
+ })
+ transfer
+ }
+
+ private val conf = new SparkConf
+
+ test("successful 3 local reads + 2 remote reads") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure blockManager.getBlockData would return the blocks
+ val localBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+ localBlocks.foreach { case (blockId, buf) =>
+ doReturn(buf).when(blockManager).getBlockData(meq(blockId))
}
- // 3rd block is going to fail
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doAnswer(answer).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+ // Make sure fetches of the remote blocks would return the mocked buffers
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val remoteBlocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ val transfer = createMockTransfer(remoteBlocks)
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
+ (localBmId, localBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq),
+ (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq)
)
val iterator = new ShuffleBlockFetcherIterator(
@@ -66,118 +93,145 @@ class ShuffleBlockFetcherIteratorSuite extends FunSuite {
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call
- // getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- // the 2nd element of the tuple returned by iterator.next should be defined when
- // fetching successfully
- assert(iterator.next()._2.isDefined,
- "1st element should be defined but is not actually defined")
- verify(blockManager, times(1)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "2nd element should be defined but is not actually defined")
- verify(blockManager, times(2)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- // 3rd fetch should be failed
- intercept[Exception] {
- iterator.next()
+ // 3 local blocks fetched in initialization
+ verify(blockManager, times(3)).getBlockData(any())
+
+ for (i <- 0 until 5) {
+ assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
+ val (blockId, subIterator) = iterator.next()
+ assert(subIterator.isSuccess,
+ s"iterator should have 5 elements defined but actually has $i elements")
+
+ // Make sure we release the buffer once the iterator is exhausted.
+ val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
+ verify(mockBuf, times(0)).release()
+ subIterator.get.foreach(_ => Unit) // exhaust the iterator
+ verify(mockBuf, times(1)).release()
}
- verify(blockManager, times(3)).getLocalShuffleFromDisk(any(), any())
+
+ // 3 local blocks, and 2 remote blocks
+ // (but from the same block manager so one call to fetchBlocks)
+ verify(blockManager, times(3)).getBlockData(any())
+ verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any())
}
- test("handle local read successes") {
- val transfer = mock(classOf[BlockTransferService])
+ test("release current unexhausted buffer in case the task completes early") {
val blockManager = mock(classOf[BlockManager])
- doReturn(BlockManagerId("test-client", "test-client", 1)).when(blockManager).blockManagerId
-
- val blIds = Array[BlockId](
- ShuffleBlockId(0,0,0),
- ShuffleBlockId(0,1,0),
- ShuffleBlockId(0,2,0),
- ShuffleBlockId(0,3,0),
- ShuffleBlockId(0,4,0))
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure fetches of the remote blocks would return the mocked buffers
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
- val optItr = mock(classOf[Option[Iterator[Any]]])
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
- // All blocks should be fetched successfully
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(0)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(1)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(2)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(3)), any())
- doReturn(optItr).when(blockManager).getLocalShuffleFromDisk(meq(blIds(4)), any())
+ val transfer = mock(classOf[BlockTransferService])
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ override def answer(invocation: InvocationOnMock): Unit = {
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ future {
+ // Return the first two blocks, and wait till task completion before returning the 3rd one
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+ sem.acquire()
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+ }
+ }
+ })
- val bmId = BlockManagerId("test-client", "test-client", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
- )
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- // Without exhausting the iterator, the iterator should be lazy and not call getLocalShuffleFromDisk.
- verify(blockManager, times(0)).getLocalShuffleFromDisk(any(), any())
-
- assert(iterator.hasNext, "iterator should have 5 elements but actually has no elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 1st element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 1 element")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 2nd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 2 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 3rd element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 3 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 4th element is not actually defined")
- assert(iterator.hasNext, "iterator should have 5 elements but actually has 4 elements")
- assert(iterator.next()._2.isDefined,
- "All elements should be defined but 5th element is not actually defined")
-
- verify(blockManager, times(5)).getLocalShuffleFromDisk(any(), any())
+ // Exhaust the first block, and then it should be released.
+ iterator.next()._2.get.foreach(_ => Unit)
+ verify(blocks(ShuffleBlockId(0, 0, 0)), times(1)).release()
+
+ // Get the 2nd block but do not exhaust the iterator
+ val subIter = iterator.next()._2.get
+
+ // Complete the task; then the 2nd block buffer should be exhausted
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(0)).release()
+ taskContext.markTaskCompleted()
+ verify(blocks(ShuffleBlockId(0, 1, 0)), times(1)).release()
+
+ // The 3rd block should not be retained because the iterator is already in zombie state
+ sem.release()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).retain()
+ verify(blocks(ShuffleBlockId(0, 2, 0)), times(0)).release()
}
- test("handle remote fetch failures in BlockTransferService") {
+ test("fail all blocks if any of the remote request fails") {
+ val blockManager = mock(classOf[BlockManager])
+ val localBmId = BlockManagerId("test-client", "test-client", 1)
+ doReturn(localBmId).when(blockManager).blockManagerId
+
+ // Make sure fetches of the remote blocks would return the mocked buffers
+ val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+ val blocks = Map[BlockId, ManagedBuffer](
+ ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
+ ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer])
+ )
+
+ // Semaphore to coordinate event sequence in two different threads.
+ val sem = new Semaphore(0)
+
val transfer = mock(classOf[BlockTransferService])
- when(transfer.fetchBlocks(any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
+ when(transfer.fetchBlocks(any(), any(), any(), any(), any())).thenAnswer(new Answer[Unit] {
override def answer(invocation: InvocationOnMock): Unit = {
- val listener = invocation.getArguments()(3).asInstanceOf[BlockFetchingListener]
- listener.onBlockFetchFailure(new Exception("blah"))
+ val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+ future {
+ // Return the first block, and then fail.
+ listener.onBlockFetchSuccess(
+ ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
+ listener.onBlockFetchFailure(
+ ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah"))
+ sem.release()
+ }
}
})
- val blockManager = mock(classOf[BlockManager])
-
- when(blockManager.blockManagerId).thenReturn(BlockManagerId("test-client", "test-client", 1))
-
- val blId1 = ShuffleBlockId(0, 0, 0)
- val blId2 = ShuffleBlockId(0, 1, 0)
- val bmId = BlockManagerId("test-server", "test-server", 1)
val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
- (bmId, Seq((blId1, 1L), (blId2, 1L))))
+ (remoteBmId, blocks.keys.map(blockId => (blockId, 1.asInstanceOf[Long])).toSeq))
+ val taskContext = new TaskContextImpl(0, 0, 0)
val iterator = new ShuffleBlockFetcherIterator(
- new TaskContextImpl(0, 0, 0),
+ taskContext,
transfer,
blockManager,
blocksByAddress,
- null,
+ new TestSerializer,
48 * 1024 * 1024)
- iterator.foreach { case (_, iterOption) =>
- assert(!iterOption.isDefined)
- }
+ // Continue only after the mock calls onBlockFetchFailure
+ sem.acquire()
+
+ // The first block should be defined, and the last two are not defined (due to failure)
+ assert(iterator.next()._2.isSuccess)
+ assert(iterator.next()._2.isFailure)
+ assert(iterator.next()._2.isFailure)
}
}
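
Note: the rewritten suite repeatedly uses one Mockito idiom: stub the void fetchBlocks call with an Answer that drives the caller's BlockFetchingListener, optionally from another thread gated by a semaphore. The sketch below isolates that idiom with hypothetical Transfer and Listener types so it stands alone; it shows the pattern, not Spark's API:

import org.mockito.Matchers.any
import org.mockito.Mockito.{mock, when}
import org.mockito.invocation.InvocationOnMock
import org.mockito.stubbing.Answer

// Hypothetical interfaces standing in for the transfer service and its callback.
trait Listener {
  def onSuccess(id: String): Unit
  def onFailure(id: String, e: Throwable): Unit
}

trait Transfer {
  def fetch(ids: Array[String], listener: Listener): Unit
}

object CallbackStubSketch {
  def main(args: Array[String]): Unit = {
    val transfer = mock(classOf[Transfer])
    // Instead of returning a value, the stub invokes the caller's listener, just like the
    // mocked fetchBlocks calls in the suite above.
    when(transfer.fetch(any(), any())).thenAnswer(new Answer[Unit] {
      override def answer(invocation: InvocationOnMock): Unit = {
        val ids = invocation.getArguments()(0).asInstanceOf[Array[String]]
        val listener = invocation.getArguments()(1).asInstanceOf[Listener]
        ids.foreach(id => listener.onSuccess(id))
      }
    })

    transfer.fetch(Array("block-0", "block-1"), new Listener {
      override def onSuccess(id: String): Unit = println(s"fetched $id")
      override def onFailure(id: String, e: Throwable): Unit = println(s"failed $id: $e")
    })
  }
}
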
diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
new file mode 100644
index 0000000000000..bacf6a16fc233
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import org.apache.spark.api.java.StorageLevels
+import org.apache.spark.{SparkException, SparkConf, SparkContext}
+import org.openqa.selenium.WebDriver
+import org.openqa.selenium.htmlunit.HtmlUnitDriver
+import org.scalatest._
+import org.scalatest.concurrent.Eventually._
+import org.scalatest.selenium.WebBrowser
+import org.scalatest.time.SpanSugar._
+
+import org.apache.spark.LocalSparkContext._
+
+/**
+ * Selenium tests for the Spark Web UI. These tests are not run by default
+ * because they're slow.
+ */
+@DoNotDiscover
+class UISeleniumSuite extends FunSuite with WebBrowser with Matchers {
+ implicit val webDriver: WebDriver = new HtmlUnitDriver
+
+ /**
+ * Create a test SparkContext with the SparkUI enabled.
+ * It is safe to `get` the SparkUI directly from the SparkContext returned here.
+ */
+ private def newSparkContext(): SparkContext = {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName("test")
+ .set("spark.ui.enabled", "true")
+ val sc = new SparkContext(conf)
+ assert(sc.ui.isDefined)
+ sc
+ }
+
+ test("effects of unpersist() / persist() should be reflected") {
+ // Regression test for SPARK-2527
+ withSpark(newSparkContext()) { sc =>
+ val ui = sc.ui.get
+ val rdd = sc.parallelize(Seq(1, 2, 3))
+ rdd.persist(StorageLevels.DISK_ONLY).count()
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to (ui.appUIAddress.stripSuffix("/") + "/storage")
+ val tableRowText = findAll(cssSelector("#storage-by-rdd-table td")).map(_.text).toSeq
+ tableRowText should contain (StorageLevels.DISK_ONLY.description)
+ }
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to (ui.appUIAddress.stripSuffix("/") + "/storage/rdd/?id=0")
+ val tableRowText = findAll(cssSelector("#rdd-storage-by-block-table td")).map(_.text).toSeq
+ tableRowText should contain (StorageLevels.DISK_ONLY.description)
+ }
+
+ rdd.unpersist()
+ rdd.persist(StorageLevels.MEMORY_ONLY).count()
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to (ui.appUIAddress.stripSuffix("/") + "/storage")
+ val tableRowText = findAll(cssSelector("#storage-by-rdd-table td")).map(_.text).toSeq
+ tableRowText should contain (StorageLevels.MEMORY_ONLY.description)
+ }
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to (ui.appUIAddress.stripSuffix("/") + "/storage/rdd/?id=0")
+ val tableRowText = findAll(cssSelector("#rdd-storage-by-block-table td")).map(_.text).toSeq
+ tableRowText should contain (StorageLevels.MEMORY_ONLY.description)
+ }
+ }
+ }
+
+ test("failed stages should not appear to be active") {
+ withSpark(newSparkContext()) { sc =>
+ // Regression test for SPARK-3021
+ intercept[SparkException] {
+ sc.parallelize(1 to 10).map { x => throw new Exception()}.collect()
+ }
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to sc.ui.get.appUIAddress
+ find(id("active")).get.text should be("Active Stages (0)")
+ find(id("failed")).get.text should be("Failed Stages (1)")
+ }
+
+ // Regression test for SPARK-2105
+ class NotSerializable
+ val unserializableObject = new NotSerializable
+ intercept[SparkException] {
+ sc.parallelize(1 to 10).map { x => unserializableObject}.collect()
+ }
+ eventually(timeout(5 seconds), interval(50 milliseconds)) {
+ go to sc.ui.get.appUIAddress
+ find(id("active")).get.text should be("Active Stages (0)")
+ // The failure occurs before the stage becomes active, hence we should still show only one
+ // failed stage, not two:
+ find(id("failed")).get.text should be("Failed Stages (1)")
+ }
+ }
+ }
+}
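
Note: the UI does not reflect storage and stage changes instantly, so every assertion above is wrapped in ScalaTest's eventually, which re-runs the block at the given interval until it passes or the timeout expires. A self-contained sketch of that retry behaviour, with an arbitrary counter standing in for the UI state:

import java.util.concurrent.atomic.AtomicInteger

import org.scalatest.concurrent.Eventually._
import org.scalatest.time.SpanSugar._

object EventuallySketch {
  def main(args: Array[String]): Unit = {
    val state = new AtomicInteger(0)
    new Thread(new Runnable {
      override def run(): Unit = { Thread.sleep(200); state.set(42) } // state changes later
    }).start()

    // Retries every 50 ms for up to 5 seconds until the block stops throwing.
    eventually(timeout(5 seconds), interval(50 milliseconds)) {
      require(state.get() == 42, "state not updated yet")
    }
    println("condition reached")
  }
}
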
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index 3370dd4156c3f..2efbae689771a 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -115,11 +115,11 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
// Go through all the failure cases to make sure we are counting them as failures.
val taskFailedReasons = Seq(
Resubmitted,
- new FetchFailed(null, 0, 0, 0),
+ new FetchFailed(null, 0, 0, 0, "ignored"),
new ExceptionFailure("Exception", "description", null, None),
TaskResultLost,
TaskKilled,
- ExecutorLostFailure,
+ ExecutorLostFailure("0"),
UnknownReason)
var failCount = 0
for (reason <- taskFailedReasons) {
diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
index d2bee448d4d3b..4dc5b6103db74 100644
--- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala
@@ -18,13 +18,13 @@
package org.apache.spark.util
import java.io._
-import java.nio.charset.Charset
import scala.collection.mutable.HashSet
import scala.reflect._
import org.scalatest.{BeforeAndAfter, FunSuite}
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.apache.spark.{Logging, SparkConf}
@@ -44,11 +44,11 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
test("basic file appender") {
val testString = (1 to 1000).mkString(", ")
- val inputStream = new ByteArrayInputStream(testString.getBytes(Charset.forName("UTF-8")))
+ val inputStream = new ByteArrayInputStream(testString.getBytes(UTF_8))
val appender = new FileAppender(inputStream, testFile)
inputStream.close()
appender.awaitTermination()
- assert(Files.toString(testFile, Charset.forName("UTF-8")) === testString)
+ assert(Files.toString(testFile, UTF_8) === testString)
}
test("rolling file appender - time-based rolling") {
@@ -96,7 +96,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
val allGeneratedFiles = new HashSet[String]()
val items = (1 to 10).map { _.toString * 10000 }
for (i <- 0 until items.size) {
- testOutputStream.write(items(i).getBytes(Charset.forName("UTF-8")))
+ testOutputStream.write(items(i).getBytes(UTF_8))
testOutputStream.flush()
allGeneratedFiles ++= RollingFileAppender.getSortedRolledOverFiles(
testFile.getParentFile.toString, testFile.getName).map(_.toString)
@@ -199,7 +199,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
// send data to appender through the input stream, and wait for the data to be written
val expectedText = textToAppend.mkString("")
for (i <- 0 until textToAppend.size) {
- outputStream.write(textToAppend(i).getBytes(Charset.forName("UTF-8")))
+ outputStream.write(textToAppend(i).getBytes(UTF_8))
outputStream.flush()
Thread.sleep(sleepTimeBetweenTexts)
}
@@ -214,7 +214,7 @@ class FileAppenderSuite extends FunSuite with BeforeAndAfter with Logging {
logInfo("Filtered files: \n" + generatedFiles.mkString("\n"))
assert(generatedFiles.size > 1)
val allText = generatedFiles.map { file =>
- Files.toString(file, Charset.forName("UTF-8"))
+ Files.toString(file, UTF_8)
}.mkString("")
assert(allText === expectedText)
generatedFiles
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index f1f88c5fd3634..aec1e409db95c 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -107,7 +107,8 @@ class JsonProtocolSuite extends FunSuite {
testJobResult(jobFailed)
// TaskEndReason
- val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19)
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "Some exception")
val exceptionFailure = ExceptionFailure("To be", "or not to be", stackTrace, None)
testTaskEndReason(Success)
testTaskEndReason(Resubmitted)
@@ -115,7 +116,7 @@ class JsonProtocolSuite extends FunSuite {
testTaskEndReason(exceptionFailure)
testTaskEndReason(TaskResultLost)
testTaskEndReason(TaskKilled)
- testTaskEndReason(ExecutorLostFailure)
+ testTaskEndReason(ExecutorLostFailure("100"))
testTaskEndReason(UnknownReason)
// BlockId
@@ -176,6 +177,17 @@ class JsonProtocolSuite extends FunSuite {
deserializedBmRemoved)
}
+ test("FetchFailed backwards compatibility") {
+ // FetchFailed in Spark 1.1.0 does not have a "Message" property.
+ val fetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "ignored")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(fetchFailed)
+ .removeField({ _._1 == "Message" })
+ val expectedFetchFailed = FetchFailed(BlockManagerId("With or", "without you", 15), 17, 18, 19,
+ "Unknown reason")
+ assert(expectedFetchFailed === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
test("SparkListenerApplicationStart backwards compatibility") {
// SparkListenerApplicationStart in Spark 1.0.0 do not have an "appId" property.
val applicationStart = SparkListenerApplicationStart("test", None, 1L, "user")
@@ -184,6 +196,15 @@ class JsonProtocolSuite extends FunSuite {
assert(applicationStart === JsonProtocol.applicationStartFromJson(oldEvent))
}
+ test("ExecutorLostFailure backward compatibility") {
+ // ExecutorLostFailure in Spark 1.1.0 does not have an "Executor ID" property.
+ val executorLostFailure = ExecutorLostFailure("100")
+ val oldEvent = JsonProtocol.taskEndReasonToJson(executorLostFailure)
+ .removeField({ _._1 == "Executor ID" })
+ val expectedExecutorLostFailure = ExecutorLostFailure("Unknown")
+ assert(expectedExecutorLostFailure === JsonProtocol.taskEndReasonFromJson(oldEvent))
+ }
+
/** -------------------------- *
| Helper test running methods |
* --------------------------- */
@@ -396,6 +417,7 @@ class JsonProtocolSuite extends FunSuite {
assert(r1.mapId === r2.mapId)
assert(r1.reduceId === r2.reduceId)
assertEquals(r1.bmAddress, r2.bmAddress)
+ assert(r1.message === r2.message)
case (r1: ExceptionFailure, r2: ExceptionFailure) =>
assert(r1.className === r2.className)
assert(r1.description === r2.description)
@@ -403,7 +425,8 @@ class JsonProtocolSuite extends FunSuite {
assertOptionEquals(r1.metrics, r2.metrics, assertTaskMetricsEquals)
case (TaskResultLost, TaskResultLost) =>
case (TaskKilled, TaskKilled) =>
- case (ExecutorLostFailure, ExecutorLostFailure) =>
+ case (ExecutorLostFailure(execId1), ExecutorLostFailure(execId2)) =>
+ assert(execId1 === execId2)
case (UnknownReason, UnknownReason) =>
case _ => fail("Task end reasons don't match in types!")
}
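
Note: both backwards-compatibility tests rely on the same pattern in JsonProtocol: when an older writer omitted a field, read it as an Option and substitute a default ("Unknown reason" for the fetch-failure message, "Unknown" for the executor ID). A minimal json4s sketch of that pattern; the surrounding JSON shape here is illustrative rather than a verbatim copy of the event format:

import org.json4s._
import org.json4s.jackson.JsonMethods._

object BackwardCompatSketch {
  implicit val formats = DefaultFormats

  def main(args: Array[String]): Unit = {
    // An "old" event that predates the "Executor ID" field.
    val oldEvent = parse("""{"Reason": "ExecutorLostFailure"}""")
    // Missing field -> None -> fall back to a default, as the tests above expect.
    val executorId = (oldEvent \ "Executor ID").extractOpt[String].getOrElse("Unknown")
    println(executorId) // Unknown
  }
}
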
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index ea7ef0524d1e1..8ffe3e2b139c3 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -23,7 +23,7 @@ import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStr
import java.net.{BindException, ServerSocket, URI}
import java.nio.{ByteBuffer, ByteOrder}
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.Files
import org.scalatest.FunSuite
@@ -118,7 +118,7 @@ class UtilsSuite extends FunSuite {
tmpDir2.deleteOnExit()
val f1Path = tmpDir2 + "/f1"
val f1 = new FileOutputStream(f1Path)
- f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(Charsets.UTF_8))
+ f1.write("1\n2\n3\n4\n5\n6\n7\n8\n9\n".getBytes(UTF_8))
f1.close()
// Read first few bytes
@@ -146,9 +146,9 @@ class UtilsSuite extends FunSuite {
val tmpDir = Utils.createTempDir()
tmpDir.deleteOnExit()
val files = (1 to 3).map(i => new File(tmpDir, i.toString))
- Files.write("0123456789", files(0), Charsets.UTF_8)
- Files.write("abcdefghij", files(1), Charsets.UTF_8)
- Files.write("ABCDEFGHIJ", files(2), Charsets.UTF_8)
+ Files.write("0123456789", files(0), UTF_8)
+ Files.write("abcdefghij", files(1), UTF_8)
+ Files.write("ABCDEFGHIJ", files(2), UTF_8)
// Read first few bytes in the 1st file
assert(Utils.offsetBytes(files, 0, 5) === "01234")
@@ -217,9 +217,14 @@ class UtilsSuite extends FunSuite {
test("resolveURI") {
def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = {
- assume(before.split(",").length == 1)
- assert(Utils.resolveURI(before, testWindows) === new URI(after))
- assert(Utils.resolveURI(after, testWindows) === new URI(after))
+ // This should test only single paths
+ assume(before.split(",").length === 1)
+ // Repeated invocations of resolveURI should yield the same result
+ def resolve(uri: String): String = Utils.resolveURI(uri, testWindows).toString
+ assert(resolve(after) === after)
+ assert(resolve(resolve(after)) === after)
+ assert(resolve(resolve(resolve(after))) === after)
+ // Also test resolveURIs with single paths
assert(new URI(Utils.resolveURIs(before, testWindows)) === new URI(after))
assert(new URI(Utils.resolveURIs(after, testWindows)) === new URI(after))
}
@@ -235,16 +240,27 @@ class UtilsSuite extends FunSuite {
assertResolves("file:/C:/file.txt#alias.txt", "file:/C:/file.txt#alias.txt", testWindows = true)
intercept[IllegalArgumentException] { Utils.resolveURI("file:foo") }
intercept[IllegalArgumentException] { Utils.resolveURI("file:foo:baby") }
+ }
- // Test resolving comma-delimited paths
- assert(Utils.resolveURIs("jar1,jar2") === s"file:$cwd/jar1,file:$cwd/jar2")
- assert(Utils.resolveURIs("file:/jar1,file:/jar2") === "file:/jar1,file:/jar2")
- assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3") ===
- s"hdfs:/jar1,file:/jar2,file:$cwd/jar3")
- assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3,jar4#jar5") ===
+ test("resolveURIs with multiple paths") {
+ def assertResolves(before: String, after: String, testWindows: Boolean = false): Unit = {
+ assume(before.split(",").length > 1)
+ assert(Utils.resolveURIs(before, testWindows) === after)
+ assert(Utils.resolveURIs(after, testWindows) === after)
+ // Repeated invocations of resolveURIs should yield the same result
+ def resolve(uri: String): String = Utils.resolveURIs(uri, testWindows)
+ assert(resolve(after) === after)
+ assert(resolve(resolve(after)) === after)
+ assert(resolve(resolve(resolve(after))) === after)
+ }
+ val cwd = System.getProperty("user.dir")
+ assertResolves("jar1,jar2", s"file:$cwd/jar1,file:$cwd/jar2")
+ assertResolves("file:/jar1,file:/jar2", "file:/jar1,file:/jar2")
+ assertResolves("hdfs:/jar1,file:/jar2,jar3", s"hdfs:/jar1,file:/jar2,file:$cwd/jar3")
+ assertResolves("hdfs:/jar1,file:/jar2,jar3,jar4#jar5",
s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:$cwd/jar4#jar5")
- assert(Utils.resolveURIs("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi", testWindows = true) ===
- s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi")
+ assertResolves("hdfs:/jar1,file:/jar2,jar3,C:\\pi.py#py.pi",
+ s"hdfs:/jar1,file:/jar2,file:$cwd/jar3,file:/C:/pi.py#py.pi", testWindows = true)
}
test("nonLocalPaths") {
@@ -339,7 +355,7 @@ class UtilsSuite extends FunSuite {
try {
System.setProperty("spark.test.fileNameLoadB", "2")
Files.write("spark.test.fileNameLoadA true\n" +
- "spark.test.fileNameLoadB 1\n", outFile, Charsets.UTF_8)
+ "spark.test.fileNameLoadB 1\n", outFile, UTF_8)
val properties = Utils.getPropertiesFromFile(outFile.getAbsolutePath)
properties
.filter { case (k, v) => k.startsWith("spark.")}
@@ -351,4 +367,15 @@ class UtilsSuite extends FunSuite {
outFile.delete()
}
}
+
+ test("timeIt with prepare") {
+ var cnt = 0
+ val prepare = () => {
+ cnt += 1
+ Thread.sleep(1000)
+ }
+ val time = Utils.timeIt(2)({}, Some(prepare))
+ require(cnt === 2, "prepare should be called twice")
+ require(time < 500, "preparation time should not count")
+ }
}
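
Note: the "timeIt with prepare" test fixes the shape of the helper used by the sorter benchmarks below: Utils.timeIt(numIters)(f, Some(prepare)) runs the prepare callback before every iteration and excludes it from the measured time. Utils itself is Spark-internal, so the sketch below only mirrors that contract for illustration:

object TimeItSketch {
  // Mirrors the contract pinned down by the test: `prepare` runs before each iteration and
  // its cost is excluded from the measured time.
  def timeIt(numIters: Int)(f: => Unit, prepare: Option[() => Unit] = None): Long = {
    if (prepare.isEmpty) {
      val start = System.currentTimeMillis()
      var i = 0
      while (i < numIters) { f; i += 1 }
      System.currentTimeMillis() - start
    } else {
      var elapsed = 0L
      var i = 0
      while (i < numIters) {
        prepare.get.apply() // untimed preparation step
        val start = System.currentTimeMillis()
        f
        elapsed += System.currentTimeMillis() - start
        i += 1
      }
      elapsed
    }
  }

  def main(args: Array[String]): Unit = {
    var prepared = 0
    val time = timeIt(2)({ Thread.sleep(10) }, Some(() => { prepared += 1; Thread.sleep(100) }))
    println(s"prepare ran $prepared times; measured $time ms (preparation excluded)")
  }
}
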
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
index 6fe1079c2719a..0cb1ed7397655 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.util.collection
-import java.lang.{Float => JFloat}
+import java.lang.{Float => JFloat, Integer => JInteger}
import java.util.{Arrays, Comparator}
import org.scalatest.FunSuite
@@ -30,11 +30,15 @@ class SorterSuite extends FunSuite {
val rand = new XORShiftRandom(123)
val data0 = Array.tabulate[Int](10000) { i => rand.nextInt() }
val data1 = data0.clone()
+ val data2 = data0.clone()
Arrays.sort(data0)
new Sorter(new IntArraySortDataFormat).sort(data1, 0, data1.length, Ordering.Int)
+ new Sorter(new KeyReuseIntArraySortDataFormat)
+ .sort(data2, 0, data2.length, Ordering[IntWrapper])
- data0.zip(data1).foreach { case (x, y) => assert(x === y) }
+ assert(data0.view === data1.view)
+ assert(data0.view === data2.view)
}
test("KVArraySorter") {
@@ -61,10 +65,33 @@ class SorterSuite extends FunSuite {
}
}
+ /** Runs an experiment several times. */
+ def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
+ if (skip) {
+ println(s"Skipped experiment $name.")
+ return
+ }
+
+ val firstTry = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare))
+ System.gc()
+
+ var i = 0
+ var next10: Long = 0
+ while (i < 10) {
+ val time = org.apache.spark.util.Utils.timeIt(1)(f, Some(prepare))
+ next10 += time
+ println(s"$name: Took $time ms")
+ i += 1
+ }
+
+ println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
+ }
+
/**
* This provides a simple benchmark for comparing the Sorter with Java internal sorting.
* Ideally these would be executed one at a time, each in their own JVM, so their listing
- * here is mainly to have the code.
+ * here is mainly to have the code. Running multiple tests within the same JVM session would
+ * prevent JIT inlining overridden methods and hence hurt the performance.
*
* The goal of this code is to sort an array of key-value pairs, where the array physically
* has the keys and values alternating. The basic Java sorts work only on the keys, so the
@@ -72,96 +99,167 @@ class SorterSuite extends FunSuite {
* those, while the Sorter approach can work directly on the input data format.
*
* Note that the Java implementation varies tremendously between Java 6 and Java 7, when
- * the Java sort changed from merge sort to Timsort.
+ * the Java sort changed from merge sort to TimSort.
*/
- ignore("Sorter benchmark") {
-
- /** Runs an experiment several times. */
- def runExperiment(name: String)(f: => Unit): Unit = {
- val firstTry = org.apache.spark.util.Utils.timeIt(1)(f)
- System.gc()
-
- var i = 0
- var next10: Long = 0
- while (i < 10) {
- val time = org.apache.spark.util.Utils.timeIt(1)(f)
- next10 += time
- println(s"$name: Took $time ms")
- i += 1
- }
-
- println(s"$name: ($firstTry ms first try, ${next10 / 10} ms average)")
- }
-
+ ignore("Sorter benchmark for key-value pairs") {
val numElements = 25000000 // 25 mil
val rand = new XORShiftRandom(123)
- val keys = Array.tabulate[JFloat](numElements) { i =>
- new JFloat(rand.nextFloat())
+ // Test our key-value pairs where each element is a Tuple2[Float, Integer].
+
+ val kvTuples = Array.tabulate(numElements) { i =>
+ (new JFloat(rand.nextFloat()), new JInteger(i))
}
- // Test our key-value pairs where each element is a Tuple2[Float, Integer)
- val kvTupleArray = Array.tabulate[AnyRef](numElements) { i =>
- (keys(i / 2): Float, i / 2: Int)
+ val kvTupleArray = new Array[AnyRef](numElements)
+ val prepareKvTupleArray = () => {
+ System.arraycopy(kvTuples, 0, kvTupleArray, 0, numElements)
}
- runExperiment("Tuple-sort using Arrays.sort()") {
+ runExperiment("Tuple-sort using Arrays.sort()")({
Arrays.sort(kvTupleArray, new Comparator[AnyRef] {
override def compare(x: AnyRef, y: AnyRef): Int =
- Ordering.Float.compare(x.asInstanceOf[(Float, _)]._1, y.asInstanceOf[(Float, _)]._1)
+ x.asInstanceOf[(JFloat, _)]._1.compareTo(y.asInstanceOf[(JFloat, _)]._1)
})
- }
+ }, prepareKvTupleArray)
// Test our Sorter where each element alternates between Float and Integer, non-primitive
- val keyValueArray = Array.tabulate[AnyRef](numElements * 2) { i =>
- if (i % 2 == 0) keys(i / 2) else new Integer(i / 2)
+
+ val keyValues = {
+ val data = new Array[AnyRef](numElements * 2)
+ var i = 0
+ while (i < numElements) {
+ data(2 * i) = kvTuples(i)._1
+ data(2 * i + 1) = kvTuples(i)._2
+ i += 1
+ }
+ data
}
+
+ val keyValueArray = new Array[AnyRef](numElements * 2)
+ val prepareKeyValueArray = () => {
+ System.arraycopy(keyValues, 0, keyValueArray, 0, numElements * 2)
+ }
+
val sorter = new Sorter(new KVArraySortDataFormat[JFloat, AnyRef])
- runExperiment("KV-sort using Sorter") {
- sorter.sort(keyValueArray, 0, keys.length, new Comparator[JFloat] {
- override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y)
+ runExperiment("KV-sort using Sorter")({
+ sorter.sort(keyValueArray, 0, numElements, new Comparator[JFloat] {
+ override def compare(x: JFloat, y: JFloat): Int = x.compareTo(y)
})
+ }, prepareKeyValueArray)
+ }
+
+ /**
+ * Tests for sorting with primitive keys with/without key reuse. Java's Arrays.sort is used as
+ * reference, which is expected to be faster but it can only sort a single array. Sorter can be
+ * used to sort parallel arrays.
+ *
+ * Ideally these would be executed one at a time, each in their own JVM, so their listing
+ * here is mainly to have the code. Running multiple tests within the same JVM session would
+ * prevent JIT inlining overridden methods and hence hurt the performance.
+ */
+ ignore("Sorter benchmark for primitive int array") {
+ val numElements = 25000000 // 25 mil
+ val rand = new XORShiftRandom(123)
+
+ val ints = Array.fill(numElements)(rand.nextInt())
+ val intObjects = {
+ val data = new Array[JInteger](numElements)
+ var i = 0
+ while (i < numElements) {
+ data(i) = new JInteger(ints(i))
+ i += 1
+ }
+ data
}
- // Test non-primitive sort on float array
- runExperiment("Java Arrays.sort()") {
- Arrays.sort(keys, new Comparator[JFloat] {
- override def compare(x: JFloat, y: JFloat): Int = Ordering.Float.compare(x, y)
- })
+ val intObjectArray = new Array[JInteger](numElements)
+ val prepareIntObjectArray = () => {
+ System.arraycopy(intObjects, 0, intObjectArray, 0, numElements)
}
- // Test primitive sort on float array
- val primitiveKeys = Array.tabulate[Float](numElements) { i => rand.nextFloat() }
- runExperiment("Java Arrays.sort() on primitive keys") {
- Arrays.sort(primitiveKeys)
+ runExperiment("Java Arrays.sort() on non-primitive int array")({
+ Arrays.sort(intObjectArray, new Comparator[JInteger] {
+ override def compare(x: JInteger, y: JInteger): Int = x.compareTo(y)
+ })
+ }, prepareIntObjectArray)
+
+ val intPrimitiveArray = new Array[Int](numElements)
+ val prepareIntPrimitiveArray = () => {
+ System.arraycopy(ints, 0, intPrimitiveArray, 0, numElements)
}
- }
-}
+ runExperiment("Java Arrays.sort() on primitive int array")({
+ Arrays.sort(intPrimitiveArray)
+ }, prepareIntPrimitiveArray)
-/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
-class IntArraySortDataFormat extends SortDataFormat[Int, Array[Int]] {
- override protected def getKey(data: Array[Int], pos: Int): Int = {
- data(pos)
+ val sorterWithoutKeyReuse = new Sorter(new IntArraySortDataFormat)
+ runExperiment("Sorter without key reuse on primitive int array")({
+ sorterWithoutKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[Int])
+ }, prepareIntPrimitiveArray)
+
+ val sorterWithKeyReuse = new Sorter(new KeyReuseIntArraySortDataFormat)
+ runExperiment("Sorter with key reuse on primitive int array")({
+ sorterWithKeyReuse.sort(intPrimitiveArray, 0, numElements, Ordering[IntWrapper])
+ }, prepareIntPrimitiveArray)
}
+}
- override protected def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
+abstract class AbstractIntArraySortDataFormat[K] extends SortDataFormat[K, Array[Int]] {
+
+ override def swap(data: Array[Int], pos0: Int, pos1: Int): Unit = {
val tmp = data(pos0)
data(pos0) = data(pos1)
data(pos1) = tmp
}
- override protected def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) {
+ override def copyElement(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int) {
dst(dstPos) = src(srcPos)
}
/** Copy a range of elements starting at src(srcPos) to dest, starting at destPos. */
- override protected def copyRange(src: Array[Int], srcPos: Int,
- dst: Array[Int], dstPos: Int, length: Int) {
+ override def copyRange(src: Array[Int], srcPos: Int, dst: Array[Int], dstPos: Int, length: Int) {
System.arraycopy(src, srcPos, dst, dstPos, length)
}
/** Allocates a new structure that can hold up to 'length' elements. */
- override protected def allocate(length: Int): Array[Int] = {
+ override def allocate(length: Int): Array[Int] = {
new Array[Int](length)
}
}
+
+/** Format to sort a simple Array[Int]. Could be easily generified and specialized. */
+class IntArraySortDataFormat extends AbstractIntArraySortDataFormat[Int] {
+
+ override protected def getKey(data: Array[Int], pos: Int): Int = {
+ data(pos)
+ }
+}
+
+/** Wrapper of Int for key reuse. */
+class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] {
+
+ override def compare(that: IntWrapper): Int = {
+ Ordering.Int.compare(key, that.key)
+ }
+}
+
+/** SortDataFormat for Array[Int] with reused keys. */
+class KeyReuseIntArraySortDataFormat extends AbstractIntArraySortDataFormat[IntWrapper] {
+
+ override def newKey(): IntWrapper = {
+ new IntWrapper()
+ }
+
+ override def getKey(data: Array[Int], pos: Int, reuse: IntWrapper): IntWrapper = {
+ if (reuse == null) {
+ new IntWrapper(data(pos))
+ } else {
+ reuse.key = data(pos)
+ reuse
+ }
+ }
+
+ override protected def getKey(data: Array[Int], pos: Int): IntWrapper = {
+ getKey(data, pos, null)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
index 36877476e708e..20944b62473c5 100644
--- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala
@@ -18,96 +18,523 @@
package org.apache.spark.util.random
import java.util.Random
+import scala.collection.mutable.ArrayBuffer
+import org.apache.commons.math3.distribution.PoissonDistribution
-import cern.jet.random.Poisson
-import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.scalatest.mock.EasyMockSugar
-
-class RandomSamplerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar {
-
- val a = List(1, 2, 3, 4, 5, 6, 7, 8, 9)
-
- var random: Random = _
- var poisson: Poisson = _
-
- before {
- random = mock[Random]
- poisson = mock[Poisson]
- }
-
- test("BernoulliSamplerWithRange") {
- expecting {
- for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
- random.nextDouble().andReturn(x)
- }
- }
- whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55)
- sampler.rng = random
- assert(sampler.sample(a.iterator).toList == List(3, 4, 5))
- }
- }
-
- test("BernoulliSamplerWithRangeInverse") {
- expecting {
- for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
- random.nextDouble().andReturn(x)
- }
- }
- whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
- sampler.rng = random
- assert(sampler.sample(a.iterator).toList === List(1, 2, 6, 7, 8, 9))
- }
- }
-
- test("BernoulliSamplerWithRatio") {
- expecting {
- for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
- random.nextDouble().andReturn(x)
- }
- }
- whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.35)
- sampler.rng = random
- assert(sampler.sample(a.iterator).toList == List(1, 2, 3))
- }
- }
-
- test("BernoulliSamplerWithComplement") {
- expecting {
- for(x <- Seq(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)) {
- random.nextDouble().andReturn(x)
- }
- }
- whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.25, 0.55, true)
- sampler.rng = random
- assert(sampler.sample(a.iterator).toList == List(1, 2, 6, 7, 8, 9))
- }
- }
-
- test("BernoulliSamplerSetSeed") {
- expecting {
- random.setSeed(10L)
- }
- whenExecuting(random) {
- val sampler = new BernoulliSampler[Int](0.2)
- sampler.rng = random
- sampler.setSeed(10L)
- }
- }
-
- test("PoissonSampler") {
- expecting {
- for(x <- Seq(0, 1, 2, 0, 1, 1, 0, 0, 0)) {
- poisson.nextInt().andReturn(x)
- }
- }
- whenExecuting(poisson) {
- val sampler = new PoissonSampler[Int](0.2)
- sampler.rng = poisson
- assert(sampler.sample(a.iterator).toList == List(2, 3, 3, 5, 6))
- }
+import org.scalatest.{FunSuite, Matchers}
+
+class RandomSamplerSuite extends FunSuite with Matchers {
+ /**
+ * My statistical testing methodology is to run a Kolmogorov-Smirnov (KS) test
+ * between the random samplers and simple reference samplers (known to work correctly).
+ * The sampling gap sizes between chosen samples should show up as having the same
+ * distributions between test and reference, if things are working properly. That is,
+ * the KS test will fail to strongly reject the null hypothesis that the distributions of
+ * sampling gaps are the same.
+ * There are no actual KS tests implemented for scala (that I can find) - and so what I
+ * have done here is pre-compute "D" - the KS statistic - that corresponds to a "weak"
+ * p-value for a particular sample size. I can then test that my measured KS stats
+ * are less than D. Computing D-values is easy, and implemented below.
+ *
+ * I used the scipy 'kstwobign' distribution to pre-compute my D value:
+ *
+ * def ksdval(q=0.1, n=1000):
+ * en = np.sqrt(float(n) / 2.0)
+ * return stats.kstwobign.isf(float(q)) / (en + 0.12 + 0.11 / en)
+ *
+ * When comparing KS stats I take the median of a small number of independent test runs
+ * to compensate for the issue that any sampled statistic will show a "false positive" with
+ * some probability. Even when two distributions are the same, they will register as
+ * different 10% of the time at a p-value of 0.1.
+ */
+
+ // This D value is the precomputed KS statistic for p-value 0.1, sample size 1000:
+ val sampleSize = 1000
+ val D = 0.0544280747619
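+  // For reference, this is what the formula above yields for q = 0.1 and n = 1000
+  // (assuming scipy's kstwobign, whose isf(0.1) is approximately 1.2239):
+  //   en = sqrt(1000 / 2.0) ~= 22.36
+  //   D  ~= 1.2239 / (22.36 + 0.12 + 0.11 / 22.36) ~= 0.0544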
+
+ // I'm not a big fan of fixing seeds, but unit testing based on running statistical tests
+ // will always fail with some nonzero probability, so I'll fix the seed to prevent these
+ // tests from generating random failure noise in CI testing, etc.
+ val rngSeed: Random = RandomSampler.newDefaultRNG
+ rngSeed.setSeed(235711)
+
+ // Reference implementation of sampling without replacement (bernoulli)
+ def sample[T](data: Iterator[T], f: Double): Iterator[T] = {
+ val rng: Random = RandomSampler.newDefaultRNG
+ rng.setSeed(rngSeed.nextLong)
+ data.filter(_ => (rng.nextDouble <= f))
+ }
+
+ // Reference implementation of sampling with replacement
+ def sampleWR[T](data: Iterator[T], f: Double): Iterator[T] = {
+ val rng = new PoissonDistribution(f)
+ rng.reseedRandomGenerator(rngSeed.nextLong)
+ data.flatMap { v => {
+ val rep = rng.sample()
+ if (rep == 0) Iterator.empty else Iterator.fill(rep)(v)
+ }}
+ }
+
+ // Returns iterator over gap lengths between samples.
+ // This function assumes input data is integers sampled from the sequence of
+ // increasing integers: {0, 1, 2, ...}. This works because that is how I generate them,
+ // and the samplers preserve their input order
+ def gaps(data: Iterator[Int]): Iterator[Int] = {
+ data.sliding(2).withPartial(false).map { x => x(1) - x(0) }
+ }
+
+ // Returns the cumulative distribution from a histogram
+ def cumulativeDist(hist: Array[Int]): Array[Double] = {
+ val n = hist.sum.toDouble
+ assert(n > 0.0)
+ hist.scanLeft(0)(_ + _).drop(1).map { _.toDouble / n }
+ }
+
+ // Returns aligned cumulative distributions from two arrays of data
+ def cumulants(d1: Array[Int], d2: Array[Int],
+ ss: Int = sampleSize): (Array[Double], Array[Double]) = {
+ assert(math.min(d1.length, d2.length) > 0)
+ assert(math.min(d1.min, d2.min) >= 0)
+ val m = 1 + math.max(d1.max, d2.max)
+ val h1 = Array.fill[Int](m)(0)
+ val h2 = Array.fill[Int](m)(0)
+ for (v <- d1) { h1(v) += 1 }
+ for (v <- d2) { h2(v) += 1 }
+ assert(h1.sum == h2.sum)
+ assert(h1.sum == ss)
+ (cumulativeDist(h1), cumulativeDist(h2))
+ }
+
+ // Computes the Kolmogorov-Smirnov 'D' statistic from two cumulative distributions
+ def KSD(cdf1: Array[Double], cdf2: Array[Double]): Double = {
+ assert(cdf1.length == cdf2.length)
+ val n = cdf1.length
+ assert(n > 0)
+ assert(cdf1(n-1) == 1.0)
+ assert(cdf2(n-1) == 1.0)
+ cdf1.zip(cdf2).map { x => Math.abs(x._1 - x._2) }.max
+ }
+
+ // Returns the median KS 'D' statistic between two samples, over (m) sampling trials
+ def medianKSD(data1: => Iterator[Int], data2: => Iterator[Int], m: Int = 5): Double = {
+ val t = Array.fill[Double](m) {
+ val (c1, c2) = cumulants(data1.take(sampleSize).toArray,
+ data2.take(sampleSize).toArray)
+ KSD(c1, c2)
+ }.sorted
+ // return the median KS statistic
+ t(m / 2)
+ }
+
+ test("utilities") {
+ val s1 = Array(0, 1, 1, 0, 2)
+ val s2 = Array(1, 0, 3, 2, 1)
+ val (c1, c2) = cumulants(s1, s2, ss = 5)
+ c1 should be (Array(0.4, 0.8, 1.0, 1.0))
+ c2 should be (Array(0.2, 0.6, 0.8, 1.0))
+ KSD(c1, c2) should be (0.2 +- 0.000001)
+ KSD(c2, c1) should be (KSD(c1, c2))
+ gaps(List(0, 1, 1, 2, 4, 11).iterator).toArray should be (Array(1, 0, 1, 2, 7))
+ }
+
+ test("sanity check medianKSD against references") {
+ var d: Double = 0.0
+
+ // should be statistically same, i.e. fail to reject null hypothesis strongly
+ d = medianKSD(gaps(sample(Iterator.from(0), 0.5)), gaps(sample(Iterator.from(0), 0.5)))
+ d should be < D
+
+ // should be statistically different - null hypothesis will have high D value,
+ // corresponding to low p-value that rejects the null hypothesis
+ d = medianKSD(gaps(sample(Iterator.from(0), 0.4)), gaps(sample(Iterator.from(0), 0.5)))
+ d should be > D
+
+ // same!
+ d = medianKSD(gaps(sampleWR(Iterator.from(0), 0.5)), gaps(sampleWR(Iterator.from(0), 0.5)))
+ d should be < D
+
+ // different!
+ d = medianKSD(gaps(sampleWR(Iterator.from(0), 0.5)), gaps(sampleWR(Iterator.from(0), 0.6)))
+ d should be > D
+ }
+
+ test("bernoulli sampling") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.5)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.7)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.7)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.9)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new BernoulliSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.6)))
+ d should be > D
+ }
+
+ test("bernoulli sampling with gap sampling optimization") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ var sampler: RandomSampler[Int, Int] = new BernoulliSampler[Int](0.01)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.01)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.3)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new BernoulliSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.4)))
+ d should be > D
+ }
+
+ test("bernoulli boundary cases") {
+ val data = (1 to 100).toArray
+
+ var sampler = new BernoulliSampler[Int](0.0)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliSampler[Int](1.0)
+ sampler.sample(data.iterator).toArray should be (data)
+
+ sampler = new BernoulliSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0))
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliSampler[Int](1.0 + (RandomSampler.roundingEpsilon / 2.0))
+ sampler.sample(data.iterator).toArray should be (data)
+ }
+
+ test("bernoulli data types") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+ var sampler = new BernoulliSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+
+ // Array iterator (indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toArray.iterator)),
+ gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ // ArrayBuffer iterator (indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).to[ArrayBuffer].iterator)),
+ gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ // List iterator (non-indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toList.iterator)),
+ gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+ }
+
+ test("bernoulli clone") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d = 0.0
+ var sampler = new BernoulliSampler[Int](0.1).clone
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new BernoulliSampler[Int](0.9).clone
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9)))
+ d should be < D
+ }
+
+ test("bernoulli set seed") {
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+ var sampler1 = new BernoulliSampler[Int](0.2)
+ var sampler2 = new BernoulliSampler[Int](0.2)
+
+ // distributions should be identical if seeds are set same
+ sampler1.setSeed(73)
+ sampler2.setSeed(73)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be (0.0)
+
+ // should be different for different seeds
+ sampler1.setSeed(73)
+ sampler2.setSeed(37)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be > 0.0
+ d should be < D
+
+ sampler1 = new BernoulliSampler[Int](0.8)
+ sampler2 = new BernoulliSampler[Int](0.8)
+
+ // distributions should be identical if seeds are set same
+ sampler1.setSeed(73)
+ sampler2.setSeed(73)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be (0.0)
+
+ // should be different for different seeds
+ sampler1.setSeed(73)
+ sampler2.setSeed(37)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be > 0.0
+ d should be < D
+ }
+
+ test("replacement sampling") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ var sampler = new PoissonSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.5)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.7)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.7)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.9)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.9)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new PoissonSampler[Int](0.5)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.6)))
+ d should be > D
+ }
+
+ test("replacement sampling with gap sampling") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+
+ var sampler = new PoissonSampler[Int](0.01)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.01)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.3)))
+ d should be < D
+
+ // sampling at different frequencies should show up as statistically different:
+ sampler = new PoissonSampler[Int](0.3)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.4)))
+ d should be > D
+ }
+
+ test("replacement boundary cases") {
+ val data = (1 to 100).toArray
+
+ var sampler = new PoissonSampler[Int](0.0)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new PoissonSampler[Int](0.0 - (RandomSampler.roundingEpsilon / 2.0))
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ // sampling with replacement has no upper bound on sampling fraction
+ sampler = new PoissonSampler[Int](2.0)
+ sampler.sample(data.iterator).length should be > (data.length)
+ }
+
+ test("replacement data types") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+ var sampler = new PoissonSampler[Int](0.1)
+ sampler.setSeed(rngSeed.nextLong)
+
+ // Array iterator (indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toArray.iterator)),
+ gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+
+ // ArrayBuffer iterator (indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).to[ArrayBuffer].iterator)),
+ gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+
+ // List iterator (non-indexable type)
+ d = medianKSD(
+ gaps(sampler.sample(Iterator.from(0).take(20*sampleSize).toList.iterator)),
+ gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+ }
+
+ test("replacement clone") {
+ // Tests expect maximum gap sampling fraction to be this value
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d = 0.0
+ var sampler = new PoissonSampler[Int](0.1).clone
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new PoissonSampler[Int](0.9).clone
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sampleWR(Iterator.from(0), 0.9)))
+ d should be < D
+ }
+
+ test("replacement set seed") {
+ RandomSampler.defaultMaxGapSamplingFraction should be (0.4)
+
+ var d: Double = 0.0
+ var sampler1 = new PoissonSampler[Int](0.2)
+ var sampler2 = new PoissonSampler[Int](0.2)
+
+ // distributions should be identical if seeds are set same
+ sampler1.setSeed(73)
+ sampler2.setSeed(73)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be (0.0)
+
+ // should be different for different seeds
+ sampler1.setSeed(73)
+ sampler2.setSeed(37)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be > 0.0
+ d should be < D
+
+ sampler1 = new PoissonSampler[Int](0.8)
+ sampler2 = new PoissonSampler[Int](0.8)
+
+ // distributions should be identical if seeds are set same
+ sampler1.setSeed(73)
+ sampler2.setSeed(73)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be (0.0)
+
+ // should be different for different seeds
+ sampler1.setSeed(73)
+ sampler2.setSeed(37)
+ d = medianKSD(gaps(sampler1.sample(Iterator.from(0))), gaps(sampler2.sample(Iterator.from(0))))
+ d should be > 0.0
+ d should be < D
+ }
+
+ test("bernoulli partitioning sampling") {
+ var d: Double = 0.0
+
+ var sampler = new BernoulliCellSampler[Int](0.1, 0.2)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.1)))
+ d should be < D
+
+ sampler = new BernoulliCellSampler[Int](0.1, 0.2, true)
+ sampler.setSeed(rngSeed.nextLong)
+ d = medianKSD(gaps(sampler.sample(Iterator.from(0))), gaps(sample(Iterator.from(0), 0.9)))
+ d should be < D
+ }
+
+ test("bernoulli partitioning boundary cases") {
+ val data = (1 to 100).toArray
+ val d = RandomSampler.roundingEpsilon / 2.0
+
+ var sampler = new BernoulliCellSampler[Int](0.0, 0.0)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](1.0, 1.0)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](0.0, 1.0)
+ sampler.sample(data.iterator).toArray should be (data)
+
+ sampler = new BernoulliCellSampler[Int](0.0 - d, 1.0 + d)
+ sampler.sample(data.iterator).toArray should be (data)
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5 - d)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+ }
+
+ test("bernoulli partitioning data") {
+ val seed = rngSeed.nextLong
+ val data = (1 to 100).toArray
+
+ var sampler = new BernoulliCellSampler[Int](0.4, 0.6)
+ sampler.setSeed(seed)
+ val s1 = sampler.sample(data.iterator).toArray
+ s1.length should be > 0
+
+ sampler = new BernoulliCellSampler[Int](0.4, 0.6, true)
+ sampler.setSeed(seed)
+ val s2 = sampler.sample(data.iterator).toArray
+ s2.length should be > 0
+
+ (s1 ++ s2).sorted should be (data)
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5)
+ sampler.sample(data.iterator).toArray should be (Array.empty[Int])
+
+ sampler = new BernoulliCellSampler[Int](0.5, 0.5, true)
+ sampler.sample(data.iterator).toArray should be (data)
+ }
+
+ test("bernoulli partitioning clone") {
+ val seed = rngSeed.nextLong
+ val data = (1 to 100).toArray
+ val base = new BernoulliCellSampler[Int](0.35, 0.65)
+
+ var sampler = base.clone
+ sampler.setSeed(seed)
+ val s1 = sampler.sample(data.iterator).toArray
+ s1.length should be > 0
+
+ sampler = base.cloneComplement
+ sampler.setSeed(seed)
+ val s2 = sampler.sample(data.iterator).toArray
+ s2.length should be > 0
+
+ (s1 ++ s2).sorted should be (data)
}
}
diff --git a/dev/run-tests b/dev/run-tests
index f47fcf66ff7e7..de607e4344453 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -140,19 +140,26 @@ CURRENT_BLOCK=$BLOCK_BUILD
{
# We always build with Hive because the PySpark Spark SQL tests need it.
- BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+ BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0"
- echo "[info] Building Spark with these arguments: $BUILD_MVN_PROFILE_ARGS"
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
#+ (either resolution or compilation) prompts the user for input either q, r, etc
#+ to quit or retry. This echo is there to make it not block.
- # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
+ # NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
#+ single argument!
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
+  # First build with Hive 0.12 to ensure that patches do not break the Hive 0.12 build
+  echo "[info] Compile with Hive 0.12"
echo -e "q\n" \
- | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean package assembly/assembly \
+ | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \
+ | grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
+
+  # Then build with the default version (0.13.1) because the tests are based on this version
+ echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive"
+ echo -e "q\n" \
+ | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -173,7 +180,7 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
#+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "hive-thriftserver/test")
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
diff --git a/dev/run-tests-jenkins b/dev/run-tests-jenkins
index 451f3b771cc76..87c6715153da7 100755
--- a/dev/run-tests-jenkins
+++ b/dev/run-tests-jenkins
@@ -53,9 +53,9 @@ function post_message () {
local message=$1
local data="{\"body\": \"$message\"}"
local HTTP_CODE_HEADER="HTTP Response Code: "
-
+
echo "Attempting to post to Github..."
-
+
local curl_output=$(
curl `#--dump-header -` \
--silent \
@@ -75,12 +75,12 @@ function post_message () {
echo " > data: ${data}" >&2
# exit $curl_status
fi
-
+
local api_response=$(
echo "${curl_output}" \
| grep -v -e "^${HTTP_CODE_HEADER}"
)
-
+
local http_code=$(
echo "${curl_output}" \
| grep -e "^${HTTP_CODE_HEADER}" \
@@ -92,12 +92,39 @@ function post_message () {
echo " > api_response: ${api_response}" >&2
echo " > data: ${data}" >&2
fi
-
+
if [ "$curl_status" -eq 0 ] && [ "$http_code" -eq "201" ]; then
echo " > Post successful."
fi
}
+function send_archived_logs () {
+ echo "Archiving unit tests logs..."
+
+ local log_files=$(find . -name "unit-tests.log")
+
+ if [ -z "$log_files" ]; then
+ echo "> No log files found." >&2
+ else
+ local log_archive="unit-tests-logs.tar.gz"
+ echo "$log_files" | xargs tar czf ${log_archive}
+
+ local jenkins_build_dir=${JENKINS_HOME}/jobs/${JOB_NAME}/builds/${BUILD_NUMBER}
+ local scp_output=$(scp ${log_archive} amp-jenkins-master:${jenkins_build_dir}/${log_archive})
+ local scp_status="$?"
+
+ if [ "$scp_status" -ne 0 ]; then
+ echo "Failed to send archived unit tests logs to Jenkins master." >&2
+ echo "> scp_status: ${scp_status}" >&2
+ echo "> scp_output: ${scp_output}" >&2
+ else
+ echo "> Send successful."
+ fi
+
+ rm -f ${log_archive}
+ fi
+}
+
# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR
#+ and not anything else added to master since the PR was branched.
@@ -109,7 +136,7 @@ function post_message () {
else
merge_note=" * This patch merges cleanly."
fi
-
+
source_files=$(
git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \
| grep -v -e "\/test" `# ignore files in test directories` \
@@ -144,12 +171,12 @@ function post_message () {
# post start message
{
start_message="\
- [QA tests have started](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has started](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
-
+
start_message="${start_message}\n${merge_note}"
# start_message="${start_message}\n${public_classes_note}"
-
+
post_message "$start_message"
}
@@ -159,7 +186,7 @@ function post_message () {
test_result="$?"
if [ "$test_result" -eq "124" ]; then
- fail_message="**[Tests timed out](${BUILD_URL}consoleFull)** \
+ fail_message="**[Test build ${BUILD_DISPLAY_NAME} timed out](${BUILD_URL}consoleFull)** \
for PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL}) \
after a configured wait of \`${TESTS_TIMEOUT}\`."
@@ -187,15 +214,17 @@ function post_message () {
else
failing_test="some tests"
fi
-
+
test_result_note=" * This patch **fails $failing_test**."
fi
+
+ send_archived_logs
}
# post end message
{
result_message="\
- [QA tests have finished](${BUILD_URL}consoleFull) for \
+ [Test build ${BUILD_DISPLAY_NAME} has finished](${BUILD_URL}consoleFull) for \
PR $ghprbPullId at commit [\`${SHORT_COMMIT_HASH}\`](${COMMIT_URL})."
result_message="${result_message}\n${test_result_note}"
diff --git a/dev/scalastyle b/dev/scalastyle
index c3b356bcb3c06..ed1b6b730af6e 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -25,7 +25,7 @@ echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-
echo -e "q\n" | sbt/sbt -Pyarn -Phadoop-2.2 -Dhadoop.version=2.2.0 yarn/scalastyle \
>> scalastyle.txt
-ERRORS=$(cat scalastyle.txt | grep -e "\<error\>")
+ERRORS=$(cat scalastyle.txt | awk '{if($1~/error/)print}')
rm scalastyle.txt
if test ! -z "$ERRORS"; then
diff --git a/docs/_config.yml b/docs/_config.yml
index f4bf242ac191b..cdea02fcffbc5 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -11,10 +11,10 @@ kramdown:
include:
- _static
-# These allow the documentation to be updated with nerw releases
+# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.0.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.0.0
+SPARK_VERSION: 1.2.0-SNAPSHOT
+SPARK_VERSION_SHORT: 1.2.0
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.18.1
diff --git a/docs/building-spark.md b/docs/building-spark.md
index b2940ee4029e8..238ddae15545e 100644
--- a/docs/building-spark.md
+++ b/docs/building-spark.md
@@ -67,11 +67,13 @@ For Apache Hadoop 2.x, 0.23.x, Cloudera CDH, and other Hadoop versions with YARN
   <tr><th>YARN version</th><th>Profile required</th></tr>
-  <tr><td>0.23.x to 2.1.x</td><td>yarn-alpha</td></tr>
+  <tr><td>0.23.x to 2.1.x</td><td>yarn-alpha (Deprecated.)</td></tr>
   <tr><td>2.2.x and later</td><td>yarn</td></tr>
+
+Note: Support for YARN-alpha APIs will be removed in Spark 1.3 (see SPARK-3445).
+
Examples:
{% highlight bash %}
@@ -99,10 +101,15 @@ mvn -Pyarn-alpha -Phadoop-2.3 -Dhadoop.version=2.3.0 -Dyarn.version=0.23.7 -Dski
# Building With Hive and JDBC Support
To enable Hive integration for Spark SQL along with its JDBC server and CLI,
-add the `-Phive` profile to your existing build options.
+add the `-Phive` profile to your existing build options. By default Spark
+will build with Hive 0.13.1 bindings. You can also build for Hive 0.12.0 using
+the `-Phive-0.12.0` profile.
{% highlight bash %}
-# Apache Hadoop 2.4.X with Hive support
+# Apache Hadoop 2.4.X with Hive 13 support
mvn -Pyarn -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
+
+# Apache Hadoop 2.4.X with Hive 12 support
+mvn -Pyarn -Phive-0.12.0 -Phadoop-2.4 -Dhadoop.version=2.4.0 -Phive -DskipTests clean package
{% endhighlight %}
# Spark Tests in Maven
@@ -192,4 +199,4 @@ To run test suites of a specific sub project as follows:
compiler. When run locally as a background process, it speeds up builds of Scala-based projects
like Spark. Developers who regularly recompile Spark with Maven will be the most interested in
Zinc. The project site gives instructions for building and running `zinc`; OS X users can
-install it using `brew install zinc`.
\ No newline at end of file
+install it using `brew install zinc`.
diff --git a/docs/configuration.md b/docs/configuration.md
index 96fa1377ec399..685101ea5c9c9 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -111,6 +111,18 @@ of the most common options to set are:
(e.g. 512m, 2g).
+<tr>
+  <td><code>spark.driver.maxResultSize</code></td>
+  <td>1g</td>
+  <td>
+    Limit of total size of serialized results of all partitions for each Spark action (e.g. collect).
+    Should be at least 1M, or 0 for unlimited. Jobs will be aborted if the total size
+    is above this limit.
+    Having a high limit may cause out-of-memory errors in driver (depends on spark.driver.memory
+    and memory overhead of objects in JVM). Setting a proper limit can protect the driver from
+    out-of-memory errors.
+  </td>
+</tr>
 <tr>
   <td><code>spark.serializer</code></td>
   <td>org.apache.spark.serializer.<br />JavaSerializer</td>
@@ -124,12 +136,23 @@ of the most common options to set are:
org.apache.spark.Serializer.
+<tr>
+  <td><code>spark.kryo.classesToRegister</code></td>
+  <td>(none)</td>
+  <td>
+    If you use Kryo serialization, give a comma-separated list of custom class names to register
+    with Kryo.
+    See the tuning guide for more details.
+  </td>
+</tr>
 <tr>
   <td><code>spark.kryo.registrator</code></td>
   <td>(none)</td>
   <td>
-    If you use Kryo serialization, set this class to register your custom classes with Kryo.
-    It should be set to a class that extends
+    If you use Kryo serialization, set this class to register your custom classes with Kryo. This
+    property is useful if you need to register your classes in a custom way, e.g. to specify a custom
+    field serializer. Otherwise <code>spark.kryo.classesToRegister</code> is simpler. It should be
+    set to a class that extends
     KryoRegistrator.
     See the tuning guide for more details.
@@ -348,6 +371,16 @@ Apart from these, the following properties are also available, and may be useful
map-side aggregation and there are at most this many reduce partitions.
+<tr>
+  <td><code>spark.shuffle.blockTransferService</code></td>
+  <td>netty</td>
+  <td>
+    Implementation to use for transferring shuffle and cached blocks between executors. There
+    are two implementations available: <code>netty</code> and <code>nio</code>. Netty-based
+    block transfer is intended to be simpler but equally efficient and is the default option
+    starting in 1.2.
+  </td>
+</tr>
#### Spark UI
@@ -364,7 +397,16 @@ Apart from these, the following properties are also available, and may be useful
   <td><code>spark.ui.retainedStages</code></td>
   <td>1000</td>
-  <td>How many stages the Spark UI remembers before garbage collecting.</td>
+  <td>
+    How many stages the Spark UI and status APIs remember before garbage
+    collecting.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.retainedJobs</code></td>
+  <td>1000</td>
+  <td>
+    How many jobs the Spark UI and status APIs remember before garbage
+    collecting.
+  </td>
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 7978e934fb36b..c696ae9c8e8c8 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -34,7 +34,7 @@ a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
* *epsilon* determines the distance threshold within which we consider k-means to have converged.
-## Examples
+### Examples
@@ -153,3 +153,97 @@ provided in the [Self-Contained Applications](quick-start.html#self-contained-ap
section of the Spark
Quick Start guide. Be sure to also include *spark-mllib* to your build file as
a dependency.
+
+## Streaming clustering
+
+When data arrive in a stream, we may want to estimate clusters dynamically,
+updating them as new data arrive. MLlib provides support for streaming k-means clustering,
+with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
+uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
+all points to their nearest cluster, compute new cluster centers, then update each cluster using:
+
+`\begin{equation}
+ c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
+\end{equation}`
+`\begin{equation}
+ n_{t+1} = n_t + m_t
+\end{equation}`
+
+Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
+to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
+is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
+can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
+with `$\alpha$=0` only the most recent data will be used. This is analogous to an
+exponentially-weighted moving average.
+
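+For example, for a single one-dimensional cluster with `$c_t = 1.0$`, `$n_t = 10$`, a new batch
+center `$x_t = 3.0$` computed from `$m_t = 5$` points, and `$\alpha = 1$`, the update gives
+`$c_{t+1} = (1.0 \cdot 10 + 3.0 \cdot 5) / (10 + 5) \approx 1.67$` and `$n_{t+1} = 15$`.
+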
+The decay can be specified using a `halfLife` parameter, which determines the
+correct decay factor `a` such that, for data acquired
+at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
+The unit of time can be specified either as `batches` or `points` and the update rule
+will be adjusted accordingly.
+
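+For instance, a model that forgets with a half life of 5 batches could be configured as follows
+(a sketch assuming the `setHalfLife(halfLife, timeUnit)` setter described above):
+
+{% highlight scala %}
+
+val decayingModel = new StreamingKMeans()
+  .setK(2)
+  .setHalfLife(5.0, "batches") // a point's contribution is halved after 5 batches
+  .setRandomCenters(3, 0.0)
+
+{% endhighlight %}
+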
+### Examples
+
+This example shows how to estimate clusters on streaming data.
+
+
+
+
+
+First we import the necessary classes.
+
+{% highlight scala %}
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+
+{% endhighlight %}
+
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has been created; see the
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+
+{% highlight scala %}
+
+val trainingData = ssc.textFileStream("/training/data/dir").map(Vectors.parse)
+val testData = ssc.textFileStream("/testing/data/dir").map(LabeledPoint.parse)
+
+{% endhighlight %}
+
+We create a model with random clusters and specify the number of clusters to find.
+
+{% highlight scala %}
+
+val numDimensions = 3
+val numClusters = 2
+val model = new StreamingKMeans()
+ .setK(numClusters)
+ .setDecayFactor(1.0)
+ .setRandomCenters(numDimensions, 0.0)
+
+{% endhighlight %}
+
+Now register the streams for training and testing and start the job, printing
+the predicted cluster assignments on new data points as they arrive.
+
+{% highlight scala %}
+
+model.trainOn(trainingData)
+model.predictOnValues(testData).print()
+
+ssc.start()
+ssc.awaitTermination()
+
+{% endhighlight %}
+
+As you add new text files with data, the cluster centers will update. Each training
+point should be formatted as `[x1, x2, x3]`, and each test data point
+should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
+(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`
+the model will update, and anytime a text file is placed in `/testing/data/dir`
+you will see predictions. With new data, the cluster centers will change!
+
+
+
+
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index 1511ae6dda4ed..197bc77d506c6 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -83,7 +83,7 @@ val idf = new IDF().fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
-MLLib's IDF implementation provides an option for ignoring terms which occur in less than a
+MLlib's IDF implementation provides an option for ignoring terms which occur in less than a
minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
can be used by passing the `minDocFreq` value to the IDF constructor.
@@ -95,8 +95,49 @@ tf.cache()
val idf = new IDF(minDocFreq = 2).fit(tf)
val tfidf: RDD[Vector] = idf.transform(tf)
{% endhighlight %}
+
+
+
+TF and IDF are implemented in [HashingTF](api/python/pyspark.mllib.html#pyspark.mllib.feature.HashingTF)
+and [IDF](api/python/pyspark.mllib.html#pyspark.mllib.feature.IDF).
+`HashingTF` takes an RDD of lists as its input.
+Each record could be an iterable of strings or other types.
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.mllib.feature import HashingTF
+
+sc = SparkContext()
+# Load documents (one per line).
+documents = sc.textFile("...").map(lambda line: line.split(" "))
+
+hashingTF = HashingTF()
+tf = hashingTF.transform(documents)
+{% endhighlight %}
+
+While applying `HashingTF` only needs a single pass over the data, applying `IDF` needs two passes:
+first to compute the IDF vector and second to scale the term frequencies by IDF.
+{% highlight python %}
+from pyspark.mllib.feature import IDF
+
+# ... continue from the previous example
+tf.cache()
+idf = IDF().fit(tf)
+tfidf = idf.transform(tf)
+{% endhighlight %}
+
+MLlib's IDF implementation provides an option for ignoring terms which occur in less than a
+minimum number of documents. In such cases, the IDF for these terms is set to 0. This feature
+can be used by passing the `minDocFreq` value to the IDF constructor.
+
+{% highlight python %}
+# ... continue from the previous example
+tf.cache()
+idf = IDF(minDocFreq=2).fit(tf)
+tfidf = idf.transform(tf)
+{% endhighlight %}
+
+The example below uses the `Normalizer` feature transformer to normalize each sample, first with
+the $L^2$ norm and then with the $L^\infty$ norm:
+
+{% highlight python %}
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.feature import Normalizer
+
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+labels = data.map(lambda x: x.label)
+features = data.map(lambda x: x.features)
+
+normalizer1 = Normalizer()
+normalizer2 = Normalizer(p=float("inf"))
+
+# Each sample in data1 will be normalized using $L^2$ norm.
+data1 = labels.zip(normalizer1.transform(features))
+
+# Each sample in data2 will be normalized using $L^\infty$ norm.
+data2 = labels.zip(normalizer2.transform(features))
+{% endhighlight %}
+
diff --git a/docs/mllib-statistics.md b/docs/mllib-statistics.md
index c4632413991f1..10a5131c07414 100644
--- a/docs/mllib-statistics.md
+++ b/docs/mllib-statistics.md
@@ -197,7 +197,7 @@ print Statistics.corr(data, method="pearson")
## Stratified sampling
-Unlike the other statistics functions, which reside in MLLib, stratified sampling methods,
+Unlike the other statistics functions, which reside in MLlib, stratified sampling methods,
`sampleByKey` and `sampleByKeyExact`, can be performed on RDD's of key-value pairs. For stratified
sampling, the keys can be thought of as a label and the value as a specific attribute. For example
the key can be man or woman, or document ids, and the respective values can be the list of ages
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 368c3d0008b07..e399fecbbc78c 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -582,19 +582,27 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
   <td><code>spark.sql.parquet.cacheMetadata</code></td>
-  <td>false</td>
+  <td>true</td>
   <td>
     Turns on caching of Parquet schema metadata. Can speed up querying of static data.
   </td>
   <td><code>spark.sql.parquet.compression.codec</code></td>
-  <td>snappy</td>
+  <td>gzip</td>
   <td>
     Sets the compression codec use when writing Parquet files. Acceptable values include:
     uncompressed, snappy, gzip, lzo.
   </td>
 </tr>
+<tr>
+  <td><code>spark.sql.hive.convertMetastoreParquet</code></td>
+  <td>true</td>
+  <td>
+    When set to false, Spark SQL will use the Hive SerDe for parquet tables instead of the built in
+    support.
+  </td>
+</tr>
## JSON Datasets
@@ -815,7 +823,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
Property Name
Default
Meaning
   <td><code>spark.sql.inMemoryColumnarStorage.compressed</code></td>
-  <td>false</td>
+  <td>true</td>
   <td>
     When set to true Spark SQL will automatically select a compression codec for each column based
     on statistics of the data.
   </td>
@@ -823,7 +831,7 @@ Configuration of in-memory caching can be done using the `setConf` method on SQL
   <td><code>spark.sql.inMemoryColumnarStorage.batchSize</code></td>
-  <td>1000</td>
+  <td>10000</td>
   <td>
     Controls the size of batches for columnar caching. Larger batch sizes can improve memory utilization
     and compression, but risk OOMs when caching data.
   </td>
@@ -841,7 +849,7 @@ that these options will be deprecated in future release as more optimizations ar
Property Name
Default
Meaning
   <td><code>spark.sql.autoBroadcastJoinThreshold</code></td>
-  <td>10000</td>
+  <td>10485760 (10 MB)</td>
   <td>
     Configures the maximum size in bytes for a table that will be broadcast to all worker nodes when
     performing a join. By setting this value to -1 broadcasting can be disabled. Note that currently
@@ -1215,7 +1223,7 @@ import org.apache.spark.sql._
   <td>DecimalType</td>
-  <td>scala.math.sql.BigDecimal</td>
+  <td>scala.math.BigDecimal</td>
   <td>
     DecimalType
diff --git a/docs/tuning.md b/docs/tuning.md
index 8fb2a0433b1a8..9b5c9adac6a4f 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -47,24 +47,11 @@ registration requirement, but we recommend trying it in any network-intensive ap
Spark automatically includes Kryo serializers for the many commonly-used core Scala classes covered
in the AllScalaRegistrar from the [Twitter chill](https://github.com/twitter/chill) library.
-To register your own custom classes with Kryo, create a public class that extends
-[`org.apache.spark.serializer.KryoRegistrator`](api/scala/index.html#org.apache.spark.serializer.KryoRegistrator) and set the
-`spark.kryo.registrator` config property to point to it, as follows:
+To register your own custom classes with Kryo, use the `registerKryoClasses` method.
{% highlight scala %}
-import com.esotericsoftware.kryo.Kryo
-import org.apache.spark.serializer.KryoRegistrator
-
-class MyRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[MyClass1])
- kryo.register(classOf[MyClass2])
- }
-}
-
val conf = new SparkConf().setMaster(...).setAppName(...)
-conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
-conf.set("spark.kryo.registrator", "mypackage.MyRegistrator")
+conf.registerKryoClasses(Seq(classOf[MyClass1], classOf[MyClass2]))
val sc = new SparkContext(conf)
{% endhighlight %}
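+
+If you would rather not modify code, the same classes can be listed through configuration
+instead (a sketch assuming the `spark.kryo.classesToRegister` property documented on the
+configuration page):
+
+{% highlight scala %}
+val conf = new SparkConf().setMaster(...).setAppName(...)
+// Comma-separated, fully qualified class names; equivalent to the registerKryoClasses call above.
+conf.set("spark.kryo.classesToRegister", "mypackage.MyClass1,mypackage.MyClass2")
+val sc = new SparkContext(conf)
+{% endhighlight %}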
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 0d6b82b4944f3..50f88f735650e 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -41,8 +41,9 @@
DEFAULT_SPARK_VERSION = "1.1.0"
+MESOS_SPARK_EC2_BRANCH = "v4"
# A URL prefix from which to fetch AMI information
-AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/v2/ami-list"
+AMI_PREFIX = "https://raw.github.com/mesos/spark-ec2/{b}/ami-list".format(b=MESOS_SPARK_EC2_BRANCH)
class UsageError(Exception):
@@ -583,7 +584,13 @@ def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key):
# NOTE: We should clone the repository before running deploy_files to
# prevent ec2-variables.sh from being overwritten
- ssh(master, opts, "rm -rf spark-ec2 && git clone https://github.com/mesos/spark-ec2.git -b v4")
+ ssh(
+ host=master,
+ opts=opts,
+ command="rm -rf spark-ec2"
+ + " && "
+ + "git clone https://github.com/mesos/spark-ec2.git -b {b}".format(b=MESOS_SPARK_EC2_BRANCH)
+ )
print "Deploying files to master..."
deploy_files(conn, "deploy.generic", opts, master_nodes, slave_nodes, modules)
diff --git a/examples/pom.xml b/examples/pom.xml
index eb49a0e5af22d..bc3291803c324 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -156,6 +156,10 @@
       <artifactId>algebird-core_${scala.binary.version}</artifactId>
       <version>0.1.11</version>
     </dependency>
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+    </dependency>
     <dependency>
       <groupId>org.scalatest</groupId>
       <artifactId>scalatest_${scala.binary.version}</artifactId>
@@ -268,6 +272,10 @@
           <include>com.google.common.base.Optional**</include>
+        <relocation>
+          <pattern>org.apache.commons.math3</pattern>
+          <shadedPattern>org.spark-project.commons.math3</shadedPattern>
+        </relocation>
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java b/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java
new file mode 100644
index 0000000000000..430e96ab14d9d
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/JavaStatusAPIDemo.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.SparkJobInfo;
+import org.apache.spark.SparkStageInfo;
+import org.apache.spark.api.java.JavaFutureAction;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Example of using Spark's status APIs from Java.
+ */
+public final class JavaStatusAPIDemo {
+
+ public static final String APP_NAME = "JavaStatusAPIDemo";
+
+  public static final class IdentityWithDelay<T> implements Function<T, T> {
+ @Override
+ public T call(T x) throws Exception {
+ Thread.sleep(2 * 1000); // 2 seconds
+ return x;
+ }
+ }
+
+ public static void main(String[] args) throws Exception {
+ SparkConf sparkConf = new SparkConf().setAppName(APP_NAME);
+ final JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ // Example of implementing a progress reporter for a simple job.
+    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 5).map(
+        new IdentityWithDelay<Integer>());
+    JavaFutureAction<List<Integer>> jobFuture = rdd.collectAsync();
+ while (!jobFuture.isDone()) {
+ Thread.sleep(1000); // 1 second
+      List<Integer> jobIds = jobFuture.jobIds();
+ if (jobIds.isEmpty()) {
+ continue;
+ }
+ int currentJobId = jobIds.get(jobIds.size() - 1);
+ SparkJobInfo jobInfo = sc.getJobInfo(currentJobId);
+ SparkStageInfo stageInfo = sc.getStageInfo(jobInfo.stageIds()[0]);
+ System.out.println(stageInfo.numTasks() + " tasks total: " + stageInfo.numActiveTasks() +
+ " active, " + stageInfo.numCompletedTasks() + " complete");
+ }
+
+ System.out.println("Job results are: " + jobFuture.get());
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
index 8d381d4e0a943..95a430f1da234 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaALS.java
@@ -32,7 +32,7 @@
import scala.Tuple2;
/**
- * Example using MLLib ALS from Java.
+ * Example using MLlib ALS from Java.
*/
public final class JavaALS {
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
index f796123a25727..e575eedeb465c 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaKMeans.java
@@ -30,7 +30,7 @@
import org.apache.spark.mllib.linalg.Vectors;
/**
- * Example using MLLib KMeans from Java.
+ * Example using MLlib KMeans from Java.
*/
public final class JavaKMeans {
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
index 5622df5ce03ff..981bc4f0613a9 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaCustomReceiver.java
@@ -57,7 +57,7 @@ public class JavaCustomReceiver extends Receiver {
public static void main(String[] args) {
if (args.length < 2) {
-      System.err.println("Usage: JavaNetworkWordCount <hostname> <port>");
+      System.err.println("Usage: JavaCustomReceiver <hostname> <port>");
System.exit(1);
}
diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py
new file mode 100644
index 0000000000000..540dae785f6ea
--- /dev/null
+++ b/examples/src/main/python/mllib/dataset_example.py
@@ -0,0 +1,62 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+An example of how to use SchemaRDD as a dataset for ML. Run with::
+ bin/spark-submit examples/src/main/python/mllib/dataset_example.py
+"""
+
+import os
+import sys
+import tempfile
+import shutil
+
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+from pyspark.mllib.util import MLUtils
+from pyspark.mllib.stat import Statistics
+
+
+def summarize(dataset):
+ print "schema: %s" % dataset.schema().json()
+ labels = dataset.map(lambda r: r.label)
+ print "label average: %f" % labels.mean()
+ features = dataset.map(lambda r: r.features)
+ summary = Statistics.colStats(features)
+ print "features average: %r" % summary.mean()
+
+if __name__ == "__main__":
+ if len(sys.argv) > 2:
+        print >> sys.stderr, "Usage: dataset_example.py <libsvm file>"
+ exit(-1)
+ sc = SparkContext(appName="DatasetExample")
+ sqlCtx = SQLContext(sc)
+ if len(sys.argv) == 2:
+ input = sys.argv[1]
+ else:
+ input = "data/mllib/sample_libsvm_data.txt"
+ points = MLUtils.loadLibSVMFile(sc, input)
+ dataset0 = sqlCtx.inferSchema(points).setName("dataset0").cache()
+ summarize(dataset0)
+ tempdir = tempfile.NamedTemporaryFile(delete=False).name
+ os.unlink(tempdir)
+ print "Save dataset as a Parquet file to %s." % tempdir
+ dataset0.saveAsParquetFile(tempdir)
+ print "Load it back and summarize it again."
+ dataset1 = sqlCtx.parquetFile(tempdir).setName("dataset1").cache()
+ summarize(dataset1)
+ shutil.rmtree(tempdir)
diff --git a/examples/src/main/python/mllib/word2vec.py b/examples/src/main/python/mllib/word2vec.py
new file mode 100644
index 0000000000000..99fef4276a369
--- /dev/null
+++ b/examples/src/main/python/mllib/word2vec.py
@@ -0,0 +1,50 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# This example uses the text8 file from http://mattmahoney.net/dc/text8.zip
+# The file was downloaded, unzipped and split into multiple lines using
+#
+# wget http://mattmahoney.net/dc/text8.zip
+# unzip text8.zip
+# grep -o -E '\w+(\W+\w+){0,15}' text8 > text8_lines
+# This was done so that the example can be run in local mode
+
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.mllib.feature import Word2Vec
+
+USAGE = ("bin/spark-submit --driver-memory 4g "
+ "examples/src/main/python/mllib/word2vec.py text8_lines")
+
+if __name__ == "__main__":
+ if len(sys.argv) < 2:
+ print USAGE
+ sys.exit("Argument for file not provided")
+ file_path = sys.argv[1]
+ sc = SparkContext(appName='Word2Vec')
+ inp = sc.textFile(file_path).map(lambda row: row.split(" "))
+
+ word2vec = Word2Vec()
+ model = word2vec.fit(inp)
+
+ synonyms = model.findSynonyms('china', 40)
+
+ for word, cosine_distance in synonyms:
+ print "{}: {}".format(word, cosine_distance)
+ sc.stop()
diff --git a/examples/src/main/python/streaming/hdfs_wordcount.py b/examples/src/main/python/streaming/hdfs_wordcount.py
index 40faff0ccc7db..f7ffb5379681e 100644
--- a/examples/src/main/python/streaming/hdfs_wordcount.py
+++ b/examples/src/main/python/streaming/hdfs_wordcount.py
@@ -21,7 +21,7 @@
is the directory that Spark Streaming will use to find and read new text files.
To run this on your local machine on directory `localdir`, run this example
- $ bin/spark-submit examples/src/main/python/streaming/network_wordcount.py localdir
+ $ bin/spark-submit examples/src/main/python/streaming/hdfs_wordcount.py localdir
Then create a text file in `localdir` and the words in the file will get counted.
"""
diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
index 1f576319b3ca8..3d5259463003d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala
@@ -17,11 +17,7 @@
package org.apache.spark.examples
-import scala.math.sqrt
-
-import cern.colt.matrix._
-import cern.colt.matrix.linalg._
-import cern.jet.math._
+import org.apache.commons.math3.linear._
/**
* Alternating least squares matrix factorization.
@@ -30,84 +26,70 @@ import cern.jet.math._
* please refer to org.apache.spark.mllib.recommendation.ALS
*/
object LocalALS {
+
// Parameters set through command line arguments
var M = 0 // Number of movies
var U = 0 // Number of users
var F = 0 // Number of features
var ITERATIONS = 0
-
val LAMBDA = 0.01 // Regularization coefficient
- // Some COLT objects
- val factory2D = DoubleFactory2D.dense
- val factory1D = DoubleFactory1D.dense
- val algebra = Algebra.DEFAULT
- val blas = SeqBlas.seqBlas
-
- def generateR(): DoubleMatrix2D = {
- val mh = factory2D.random(M, F)
- val uh = factory2D.random(U, F)
- algebra.mult(mh, algebra.transpose(uh))
+ def generateR(): RealMatrix = {
+ val mh = randomMatrix(M, F)
+ val uh = randomMatrix(U, F)
+ mh.multiply(uh.transpose())
}
- def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
- us: Array[DoubleMatrix1D]): Double =
- {
- val r = factory2D.make(M, U)
+ def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = {
+ val r = new Array2DRowRealMatrix(M, U)
for (i <- 0 until M; j <- 0 until U) {
- r.set(i, j, blas.ddot(ms(i), us(j)))
+ r.setEntry(i, j, ms(i).dotProduct(us(j)))
}
- blas.daxpy(-1, targetR, r)
- val sumSqs = r.aggregate(Functions.plus, Functions.square)
- sqrt(sumSqs / (M * U))
+ val diffs = r.subtract(targetR)
+ var sumSqs = 0.0
+ for (i <- 0 until M; j <- 0 until U) {
+ val diff = diffs.getEntry(i, j)
+ sumSqs += diff * diff
+ }
+ math.sqrt(sumSqs / (M.toDouble * U.toDouble))
}
- def updateMovie(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ def updateMovie(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = {
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each user that rated the movie
for (j <- 0 until U) {
val u = us(j)
// Add u * u^t to XtX
- blas.dger(1, u, u, XtX)
+ XtX = XtX.add(u.outerProduct(u))
// Add u * rating to Xty
- blas.daxpy(R.get(i, j), u, Xty)
+ Xty = Xty.add(u.mapMultiply(R.getEntry(i, j)))
}
- // Add regularization coefs to diagonal terms
+ // Add regularization coefficients to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ XtX.addToEntry(d, d, LAMBDA * U)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
- def updateUser(j: Int, u: DoubleMatrix1D, ms: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ def updateUser(j: Int, u: RealVector, ms: Array[RealVector], R: RealMatrix) : RealVector = {
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each movie that the user rated
for (i <- 0 until M) {
val m = ms(i)
// Add m * m^t to XtX
- blas.dger(1, m, m, XtX)
+ XtX = XtX.add(m.outerProduct(m))
// Add m * rating to Xty
- blas.daxpy(R.get(i, j), m, Xty)
+ Xty = Xty.add(m.mapMultiply(R.getEntry(i, j)))
}
- // Add regularization coefs to diagonal terms
+ // Add regularization coefficients to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * M)
+ XtX.addToEntry(d, d, LAMBDA * M)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
def showWarning() {
@@ -135,21 +117,28 @@ object LocalALS {
showWarning()
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
+ println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
val R = generateR()
// Initialize m and u randomly
- var ms = Array.fill(M)(factory1D.random(F))
- var us = Array.fill(U)(factory1D.random(F))
+ var ms = Array.fill(M)(randomVector(F))
+ var us = Array.fill(U)(randomVector(F))
// Iteratively update movies then users
for (iter <- 1 to ITERATIONS) {
- println("Iteration " + iter + ":")
+ println(s"Iteration $iter:")
ms = (0 until M).map(i => updateMovie(i, ms(i), us, R)).toArray
us = (0 until U).map(j => updateUser(j, us(j), ms, R)).toArray
println("RMSE = " + rmse(R, ms, us))
println()
}
}
+
+ private def randomVector(n: Int): RealVector =
+ new ArrayRealVector(Array.fill(n)(math.random))
+
+ private def randomMatrix(rows: Int, cols: Int): RealMatrix =
+ new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random))
+
}
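
The rewrite above replaces the COLT linear algebra with commons-math3: each least-squares subproblem accumulates the normal equations (XtX and Xty), adds the regularization term to the diagonal, and solves the system with a Cholesky factorization. Below is a minimal standalone sketch of that solve, assuming only commons-math3 on the classpath; the data points and the lambda value are illustrative.

import org.apache.commons.math3.linear._

object NormalEquationsSketch {
  // Solve min_w ||X w - y||^2 + lambda * n * ||w||^2 via the normal equations
  // (X^T X + lambda * n * I) w = X^T y, factored with Cholesky as in LocalALS above.
  def ridgeSolve(xs: Array[RealVector], ys: Array[Double], lambda: Double): RealVector = {
    val f = xs(0).getDimension
    var xtx: RealMatrix = new Array2DRowRealMatrix(f, f)
    var xty: RealVector = new ArrayRealVector(f)
    for ((x, y) <- xs.zip(ys)) {
      xtx = xtx.add(x.outerProduct(x))   // accumulate X^T X
      xty = xty.add(x.mapMultiply(y))    // accumulate X^T y
    }
    for (d <- 0 until f) {
      xtx.addToEntry(d, d, lambda * xs.length)  // regularize the diagonal
    }
    new CholeskyDecomposition(xtx).getSolver.solve(xty)
  }

  def main(args: Array[String]): Unit = {
    val xs = Array(new ArrayRealVector(Array(1.0, 0.0)), new ArrayRealVector(Array(0.0, 1.0)))
    val ys = Array(2.0, 3.0)
    println(ridgeSolve(xs, ys, lambda = 0.01))  // a 2-element vector close to (2, 3)
  }
}
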
diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
index fde8ffeedf8b4..6c0ac8013ce34 100644
--- a/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/SparkALS.scala
@@ -17,11 +17,7 @@
package org.apache.spark.examples
-import scala.math.sqrt
-
-import cern.colt.matrix._
-import cern.colt.matrix.linalg._
-import cern.jet.math._
+import org.apache.commons.math3.linear._
import org.apache.spark._
@@ -32,62 +28,53 @@ import org.apache.spark._
* please refer to org.apache.spark.mllib.recommendation.ALS
*/
object SparkALS {
+
// Parameters set through command line arguments
var M = 0 // Number of movies
var U = 0 // Number of users
var F = 0 // Number of features
var ITERATIONS = 0
-
val LAMBDA = 0.01 // Regularization coefficient
- // Some COLT objects
- val factory2D = DoubleFactory2D.dense
- val factory1D = DoubleFactory1D.dense
- val algebra = Algebra.DEFAULT
- val blas = SeqBlas.seqBlas
-
- def generateR(): DoubleMatrix2D = {
- val mh = factory2D.random(M, F)
- val uh = factory2D.random(U, F)
- algebra.mult(mh, algebra.transpose(uh))
+ def generateR(): RealMatrix = {
+ val mh = randomMatrix(M, F)
+ val uh = randomMatrix(U, F)
+ mh.multiply(uh.transpose())
}
- def rmse(targetR: DoubleMatrix2D, ms: Array[DoubleMatrix1D],
- us: Array[DoubleMatrix1D]): Double =
- {
- val r = factory2D.make(M, U)
+ def rmse(targetR: RealMatrix, ms: Array[RealVector], us: Array[RealVector]): Double = {
+ val r = new Array2DRowRealMatrix(M, U)
for (i <- 0 until M; j <- 0 until U) {
- r.set(i, j, blas.ddot(ms(i), us(j)))
+ r.setEntry(i, j, ms(i).dotProduct(us(j)))
}
- blas.daxpy(-1, targetR, r)
- val sumSqs = r.aggregate(Functions.plus, Functions.square)
- sqrt(sumSqs / (M * U))
+ val diffs = r.subtract(targetR)
+ var sumSqs = 0.0
+ for (i <- 0 until M; j <- 0 until U) {
+ val diff = diffs.getEntry(i, j)
+ sumSqs += diff * diff
+ }
+ math.sqrt(sumSqs / (M.toDouble * U.toDouble))
}
- def update(i: Int, m: DoubleMatrix1D, us: Array[DoubleMatrix1D],
- R: DoubleMatrix2D) : DoubleMatrix1D =
- {
+ def update(i: Int, m: RealVector, us: Array[RealVector], R: RealMatrix) : RealVector = {
val U = us.size
- val F = us(0).size
- val XtX = factory2D.make(F, F)
- val Xty = factory1D.make(F)
+ val F = us(0).getDimension
+ var XtX: RealMatrix = new Array2DRowRealMatrix(F, F)
+ var Xty: RealVector = new ArrayRealVector(F)
// For each user that rated the movie
for (j <- 0 until U) {
val u = us(j)
// Add u * u^t to XtX
- blas.dger(1, u, u, XtX)
+ XtX = XtX.add(u.outerProduct(u))
// Add u * rating to Xty
- blas.daxpy(R.get(i, j), u, Xty)
+ Xty = Xty.add(u.mapMultiply(R.getEntry(i, j)))
}
// Add regularization coefs to diagonal terms
for (d <- 0 until F) {
- XtX.set(d, d, XtX.get(d, d) + LAMBDA * U)
+ XtX.addToEntry(d, d, LAMBDA * U)
}
// Solve it with Cholesky
- val ch = new CholeskyDecomposition(XtX)
- val Xty2D = factory2D.make(Xty.toArray, F)
- val solved2D = ch.solve(Xty2D)
- solved2D.viewColumn(0)
+ new CholeskyDecomposition(XtX).getSolver.solve(Xty)
}
def showWarning() {
@@ -118,7 +105,7 @@ object SparkALS {
showWarning()
- printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS)
+ println(s"Running with M=$M, U=$U, F=$F, iters=$ITERATIONS")
val sparkConf = new SparkConf().setAppName("SparkALS")
val sc = new SparkContext(sparkConf)
@@ -126,21 +113,21 @@ object SparkALS {
val R = generateR()
// Initialize m and u randomly
- var ms = Array.fill(M)(factory1D.random(F))
- var us = Array.fill(U)(factory1D.random(F))
+ var ms = Array.fill(M)(randomVector(F))
+ var us = Array.fill(U)(randomVector(F))
// Iteratively update movies then users
val Rc = sc.broadcast(R)
var msb = sc.broadcast(ms)
var usb = sc.broadcast(us)
for (iter <- 1 to ITERATIONS) {
- println("Iteration " + iter + ":")
+ println(s"Iteration $iter:")
ms = sc.parallelize(0 until M, slices)
.map(i => update(i, msb.value(i), usb.value, Rc.value))
.collect()
msb = sc.broadcast(ms) // Re-broadcast ms because it was updated
us = sc.parallelize(0 until U, slices)
- .map(i => update(i, usb.value(i), msb.value, algebra.transpose(Rc.value)))
+ .map(i => update(i, usb.value(i), msb.value, Rc.value.transpose()))
.collect()
usb = sc.broadcast(us) // Re-broadcast us because it was updated
println("RMSE = " + rmse(R, ms, us))
@@ -149,4 +136,11 @@ object SparkALS {
sc.stop()
}
+
+ private def randomVector(n: Int): RealVector =
+ new ArrayRealVector(Array.fill(n)(math.random))
+
+ private def randomMatrix(rows: Int, cols: Int): RealMatrix =
+ new Array2DRowRealMatrix(Array.fill(rows, cols)(math.random))
+
}
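
SparkALS keeps the factor arrays on the driver and re-broadcasts them after each half-iteration because a Broadcast value is an immutable snapshot: tasks only ever see the state captured when the broadcast was created. A minimal sketch of that broadcast/re-broadcast loop on a toy update, assuming a local master; all names and values are illustrative.

import org.apache.spark.{SparkConf, SparkContext}

object RebroadcastSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("RebroadcastSketch"))
    var factors = Array.fill(4)(1.0)
    var factorsBc = sc.broadcast(factors)
    for (iter <- 1 to 3) {
      // Tasks read the broadcast snapshot, never the driver-side array.
      factors = sc.parallelize(0 until factors.length)
        .map(i => factorsBc.value(i) * 2.0)
        .collect()
      factorsBc = sc.broadcast(factors)  // re-broadcast because the array was replaced
      println(s"Iteration $iter: ${factors.mkString(",")}")
    }
    sc.stop()
  }
}
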
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index e06f4dcd54442..e322d4ce5a745 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -18,17 +18,7 @@
package org.apache.spark.examples.bagel
import org.apache.spark._
-import org.apache.spark.SparkContext._
-import org.apache.spark.serializer.KryoRegistrator
-
import org.apache.spark.bagel._
-import org.apache.spark.bagel.Bagel._
-
-import scala.collection.mutable.ArrayBuffer
-
-import java.io.{InputStream, OutputStream, DataInputStream, DataOutputStream}
-
-import com.esotericsoftware.kryo._
class PageRankUtils extends Serializable {
def computeWithCombiner(numVertices: Long, epsilon: Double)(
@@ -99,13 +89,6 @@ class PRMessage() extends Message[String] with Serializable {
}
}
-class PRKryoRegistrator extends KryoRegistrator {
- def registerClasses(kryo: Kryo) {
- kryo.register(classOf[PRVertex])
- kryo.register(classOf[PRMessage])
- }
-}
-
class CustomPartitioner(partitions: Int) extends Partitioner {
def numPartitions = partitions
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
index e4db3ec51313d..859abedf2a55e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala
@@ -38,8 +38,7 @@ object WikipediaPageRank {
}
val sparkConf = new SparkConf()
sparkConf.setAppName("WikipediaPageRank")
- sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName)
+ sparkConf.registerKryoClasses(Array(classOf[PRVertex], classOf[PRMessage]))
val inputFile = args(0)
val threshold = args(1).toDouble
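
Several files in this patch replace the explicit spark.serializer / spark.kryo.registrator settings with the new SparkConf.registerKryoClasses helper. A minimal sketch of that pattern, assuming (as the replaced settings suggest) that the helper both records the classes and selects the Kryo serializer; the registered case class and master are illustrative.

import org.apache.spark.{SparkConf, SparkContext}

object KryoRegistrationSketch {
  case class Vertex(id: Long, rank: Double)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("KryoRegistrationSketch")
      // Registers the classes with Kryo, replacing the old
      // spark.serializer / spark.kryo.registrator pair.
      .registerKryoClasses(Array(classOf[Vertex]))
    val sc = new SparkContext(conf)
    println(sc.parallelize(Seq(Vertex(1L, 0.15))).count())
    sc.stop()
  }
}
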
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
index 45527d9382fd0..828cffb01ca1e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala
@@ -46,10 +46,8 @@ object Analytics extends Logging {
}
val options = mutable.Map(optionsList: _*)
- val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
- .set("spark.locality.wait", "100000")
+ val conf = new SparkConf().set("spark.locality.wait", "100000")
+ GraphXUtils.registerKryoClasses(conf)
val numEPart = options.remove("numEPart").map(_.toInt).getOrElse {
println("Set the number of edge partitions using --numEPart.")
@@ -79,7 +77,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("PageRank(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
@@ -112,7 +110,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("ConnectedComponents(" + fname + ")"))
val unpartitionedGraph = GraphLoader.edgeListFile(sc, fname,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel).cache()
val graph = partitionStrategy.foldLeft(unpartitionedGraph)(_.partitionBy(_))
@@ -133,7 +131,7 @@ object Analytics extends Logging {
val sc = new SparkContext(conf.setAppName("TriangleCount(" + fname + ")"))
val graph = GraphLoader.edgeListFile(sc, fname,
canonicalOrientation = true,
- minEdgePartitions = numEPart,
+ numEdgePartitions = numEPart,
edgeStorageLevel = edgeStorageLevel,
vertexStorageLevel = vertexStorageLevel)
// TriangleCount requires the graph to be partitioned
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
index 5f35a5836462e..3ec20d594b784 100644
--- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala
@@ -18,7 +18,7 @@
package org.apache.spark.examples.graphx
import org.apache.spark.SparkContext._
-import org.apache.spark.graphx.PartitionStrategy
+import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy}
import org.apache.spark.{SparkContext, SparkConf}
import org.apache.spark.graphx.util.GraphGenerators
import java.io.{PrintWriter, FileOutputStream}
@@ -67,7 +67,7 @@ object SynthBenchmark {
options.foreach {
case ("app", v) => app = v
- case ("niter", v) => niter = v.toInt
+ case ("niters", v) => niter = v.toInt
case ("nverts", v) => numVertices = v.toInt
case ("numEPart", v) => numEPart = Some(v.toInt)
case ("partStrategy", v) => partitionStrategy = Some(PartitionStrategy.fromString(v))
@@ -80,8 +80,7 @@ object SynthBenchmark {
val conf = new SparkConf()
.setAppName(s"GraphX Synth Benchmark (nverts = $numVertices, app = $app)")
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
new file mode 100644
index 0000000000000..f8d83f4ec7327
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import java.io.File
+
+import com.google.common.io.Files
+import scopt.OptionParser
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext, SchemaRDD}
+
+/**
+ * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options]
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DatasetExample {
+
+ case class Params(
+ input: String = "data/mllib/sample_libsvm_data.txt",
+ dataFormat: String = "libsvm") extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("DatasetExample") {
+ head("Dataset: an example app using SchemaRDD as a Dataset for ML.")
+ opt[String]("input")
+ .text(s"input path to dataset")
+ .action((x, c) => c.copy(input = x))
+ opt[String]("dataFormat")
+ .text("data format: libsvm (default), dense (deprecated in Spark v1.1)")
+ .action((x, c) => c.copy(dataFormat = x))
+ checkConfig { params =>
+ success
+ }
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+
+ val conf = new SparkConf().setAppName(s"DatasetExample with $params")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext._ // for implicit conversions
+
+ // Load input data
+ val origData: RDD[LabeledPoint] = params.dataFormat match {
+ case "dense" => MLUtils.loadLabeledPoints(sc, params.input)
+ case "libsvm" => MLUtils.loadLibSVMFile(sc, params.input)
+ }
+ println(s"Loaded ${origData.count()} instances from file: ${params.input}")
+
+ // Convert input data to SchemaRDD explicitly.
+ val schemaRDD: SchemaRDD = origData
+ println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}")
+ println(s"Converted to SchemaRDD with ${schemaRDD.count()} records")
+
+ // Select columns, using implicit conversion to SchemaRDD.
+ val labelsSchemaRDD: SchemaRDD = origData.select('label)
+ val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v }
+ val numLabels = labels.count()
+ val meanLabel = labels.fold(0.0)(_ + _) / numLabels
+ println(s"Selected label column with average value $meanLabel")
+
+ val featuresSchemaRDD: SchemaRDD = origData.select('features)
+ val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v }
+ val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${featureSummary.mean.toString}")
+
+ val tmpDir = Files.createTempDir()
+ tmpDir.deleteOnExit()
+ val outputDir = new File(tmpDir, "dataset").toString
+ println(s"Saving to $outputDir as Parquet file.")
+ schemaRDD.saveAsParquetFile(outputDir)
+
+ println(s"Loading Parquet file with UDT from $outputDir.")
+ val newDataset = sqlContext.parquetFile(outputDir)
+
+ println(s"Schema from Parquet: ${newDataset.schema.prettyJson}")
+ val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v }
+ val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, feat) => summary.add(feat),
+ (sum1, sum2) => sum1.merge(sum2))
+ println(s"Selected features column with average values:\n ${newFeaturesSummary.mean.toString}")
+
+ sc.stop()
+ }
+
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 0890e6263e165..49751a30491d0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{RandomForest, DecisionTree, impurity}
import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2) extends AbstractParams[Params]
+ fracTest: Double = 0.2,
+ useNodeIdCache: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("useNodeIdCache")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.useNodeIdCache}")
+ .action((x, c) => c.copy(useNodeIdCache = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
- minInfoGain = params.minInfoGain)
+ minInfoGain = params.minInfoGain,
+ useNodeIdCache = params.useNodeIdCache,
+ checkpointDir = params.checkpointDir,
+ checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
@@ -317,7 +338,7 @@ object DecisionTreeRunner {
/**
* Calculates the mean squared error for regression.
*/
- private def meanSquaredError(tree: RandomForestModel, data: RDD[LabeledPoint]): Double = {
+ private def meanSquaredError(tree: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
data.map { y =>
val err = tree.predict(y.features) - y.label
err * err
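
The new runner options feed directly into the tree Strategy: useNodeIdCache keeps per-instance node IDs cached between iterations, while checkpointDir and checkpointInterval bound the lineage of that cache. A minimal sketch of building such a Strategy by hand, assuming its named parameters match the ones the runner passes above; the impurity, depth, and checkpoint path are illustrative.

import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.impurity.Gini

object NodeIdCacheStrategySketch {
  def main(args: Array[String]): Unit = {
    // Named parameters mirror those used by DecisionTreeRunner above.
    val strategy = new Strategy(
      algo = Algo.Classification,
      impurity = Gini,
      maxDepth = 5,
      numClassesForClassification = 2,
      useNodeIdCache = true,                      // cache node IDs instead of re-predicting each iteration
      checkpointDir = Some("/tmp/node-id-cache"), // illustrative path for cache checkpoints
      checkpointInterval = 10)                    // checkpoint the cache every 10 iterations
    println("Constructed strategy: " + strategy)
  }
}
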
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
index fc6678013b932..8796c28db8a66 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MovieLensALS.scala
@@ -19,7 +19,6 @@ package org.apache.spark.examples.mllib
import scala.collection.mutable
-import com.esotericsoftware.kryo.Kryo
import org.apache.log4j.{Level, Logger}
import scopt.OptionParser
@@ -27,7 +26,6 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.recommendation.{ALS, MatrixFactorizationModel, Rating}
import org.apache.spark.rdd.RDD
-import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
/**
* An example app for ALS on MovieLens data (http://grouplens.org/datasets/movielens/).
@@ -40,13 +38,6 @@ import org.apache.spark.serializer.{KryoSerializer, KryoRegistrator}
*/
object MovieLensALS {
- class ALSRegistrator extends KryoRegistrator {
- override def registerClasses(kryo: Kryo) {
- kryo.register(classOf[Rating])
- kryo.register(classOf[mutable.BitSet])
- }
- }
-
case class Params(
input: String = null,
kryo: Boolean = false,
@@ -108,8 +99,7 @@ object MovieLensALS {
def run(params: Params) {
val conf = new SparkConf().setAppName(s"MovieLensALS with $params")
if (params.kryo) {
- conf.set("spark.serializer", classOf[KryoSerializer].getName)
- .set("spark.kryo.registrator", classOf[ALSRegistrator].getName)
+ conf.registerKryoClasses(Array(classOf[mutable.BitSet], classOf[Rating]))
.set("spark.kryoserializer.buffer.mb", "8")
}
val sc = new SparkContext(conf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
new file mode 100644
index 0000000000000..33e5760aed997
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingKMeans.scala
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.clustering.StreamingKMeans
+import org.apache.spark.SparkConf
+import org.apache.spark.streaming.{Seconds, StreamingContext}
+
+/**
+ * Estimate clusters on one stream of data and make predictions
+ * on another stream, where the data streams arrive as text files
+ * into two different directories.
+ *
+ * The rows of the training text files must be vector data in the form
+ * `[x1,x2,x3,...,xn]`
+ * Where n is the number of dimensions.
+ *
+ * The rows of the test text files must be labeled data in the form
+ * `(y,[x1,x2,x3,...,xn])`
+ * Where y is some identifier. n must be the same for train and test.
+ *
+ * Usage: StreamingKMeans <trainingDir> <testDir> <batchDuration> <numClusters> <numDimensions>
+ *
+ * To run on your local machine using the two directories `trainingDir` and `testDir`,
+ * with updates every 5 seconds, 2 dimensions per data point, and 3 clusters, call:
+ * $ bin/run-example \
+ * org.apache.spark.examples.mllib.StreamingKMeans trainingDir testDir 5 3 2
+ *
+ * As you add text files to `trainingDir` the clusters will continuously update.
+ * Anytime you add text files to `testDir`, you'll see predicted labels using the current model.
+ *
+ */
+object StreamingKMeans {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ System.err.println(
+ "Usage: StreamingKMeans " +
+ "")
+ System.exit(1)
+ }
+
+ val conf = new SparkConf().setMaster("local").setAppName("StreamingKMeans")
+ val ssc = new StreamingContext(conf, Seconds(args(2).toLong))
+
+ val trainingData = ssc.textFileStream(args(0)).map(Vectors.parse)
+ val testData = ssc.textFileStream(args(1)).map(LabeledPoint.parse)
+
+ val model = new StreamingKMeans()
+ .setK(args(3).toInt)
+ .setDecayFactor(1.0)
+ .setRandomCenters(args(4).toInt, 0.0)
+
+ model.trainOn(trainingData)
+ model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
+
+ ssc.start()
+ ssc.awaitTermination()
+ }
+}
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
index 4b2ea45fb81d0..2de2a7926bfd1 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala
@@ -66,7 +66,7 @@ class SparkFlumeEvent() extends Externalizable {
var event : AvroFlumeEvent = new AvroFlumeEvent()
/* De-serialize from bytes. */
- def readExternal(in: ObjectInput) {
+ def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException {
val bodyLength = in.readInt()
val bodyBuff = new Array[Byte](bodyLength)
in.readFully(bodyBuff)
@@ -93,7 +93,7 @@ class SparkFlumeEvent() extends Externalizable {
}
/* Serialize to bytes. */
- def writeExternal(out: ObjectOutput) {
+ def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException {
val body = event.getBody.array()
out.writeInt(body.length)
out.write(body)
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index 32a19787a28e1..475026e8eb140 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -145,11 +145,16 @@ class FlumePollingStreamSuite extends TestSuiteBase {
outputStream.register()
ssc.start()
- writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
- assertChannelIsEmpty(channel)
- assertChannelIsEmpty(channel2)
- sink.stop()
- channel.stop()
+ try {
+ writeAndVerify(Seq(channel, channel2), ssc, outputBuffer)
+ assertChannelIsEmpty(channel)
+ assertChannelIsEmpty(channel2)
+ } finally {
+ sink.stop()
+ sink2.stop()
+ channel.stop()
+ channel2.stop()
+ }
}
def writeAndVerify(channels: Seq[MemoryChannel], ssc: StreamingContext,
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
index 5bcb96b136ed7..5267560b3e5ce 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala
@@ -82,12 +82,17 @@ class EdgeRDD[@specialized ED: ClassTag, VD: ClassTag](
this
}
- /** Persists the vertex partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
+ /** Persists the edge partitions using `targetStorageLevel`, which defaults to MEMORY_ONLY. */
override def cache(): this.type = {
partitionsRDD.persist(targetStorageLevel)
this
}
+ /** The number of edges in the RDD. */
+ override def count(): Long = {
+ partitionsRDD.map(_._2.size.toLong).reduce(_ + _)
+ }
+
private[graphx] def mapEdgePartitions[ED2: ClassTag, VD2: ClassTag](
f: (PartitionID, EdgePartition[ED, VD]) => EdgePartition[ED2, VD2]): EdgeRDD[ED2, VD2] = {
this.withPartitionsRDD[ED2, VD2](partitionsRDD.mapPartitions({ iter =>
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
index 1948c978c30bf..563c948957ecf 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala
@@ -27,10 +27,10 @@ import org.apache.spark.graphx.impl._
import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
import org.apache.spark.util.collection.OpenHashSet
-
/**
* Registers GraphX classes with Kryo for improved performance.
*/
+@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0")
class GraphKryoRegistrator extends KryoRegistrator {
def registerClasses(kryo: Kryo) {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
index f4c79365b16da..4933aecba1286 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala
@@ -48,7 +48,8 @@ object GraphLoader extends Logging {
* @param path the path to the file (e.g., /home/data/file or hdfs://file)
* @param canonicalOrientation whether to orient edges in the positive
* direction
- * @param minEdgePartitions the number of partitions for the edge RDD
+ * @param numEdgePartitions the number of partitions for the edge RDD
+ * Setting this value to -1 will use the default parallelism.
* @param edgeStorageLevel the desired storage level for the edge partitions
* @param vertexStorageLevel the desired storage level for the vertex partitions
*/
@@ -56,7 +57,7 @@ object GraphLoader extends Logging {
sc: SparkContext,
path: String,
canonicalOrientation: Boolean = false,
- minEdgePartitions: Int = 1,
+ numEdgePartitions: Int = -1,
edgeStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY,
vertexStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY)
: Graph[Int, Int] =
@@ -64,7 +65,12 @@ object GraphLoader extends Logging {
val startTime = System.currentTimeMillis
// Parse the edge data table directly into edge partitions
- val lines = sc.textFile(path, minEdgePartitions).coalesce(minEdgePartitions)
+ val lines =
+ if (numEdgePartitions > 0) {
+ sc.textFile(path, numEdgePartitions).coalesce(numEdgePartitions)
+ } else {
+ sc.textFile(path)
+ }
val edges = lines.mapPartitionsWithIndex { (pid, iter) =>
val builder = new EdgePartitionBuilder[Int, Int]
iter.foreach { line =>
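
With this change, a non-positive numEdgePartitions defers to sc.textFile's default parallelism instead of forcing the old single-partition default, while a positive value still coalesces the edge RDD to exactly that many partitions. A minimal sketch of loading an edge list with the new parameter, assuming a local master and the followers.txt sample edge list shipped with Spark (the path is illustrative).

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.graphx.{GraphLoader, GraphXUtils}

object EdgeListLoadSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[2]").setAppName("EdgeListLoadSketch")
    GraphXUtils.registerKryoClasses(conf)  // new helper introduced in this patch
    val sc = new SparkContext(conf)
    // -1 (the new default) uses sc.textFile's default parallelism;
    // a positive value coalesces the edges to that many partitions.
    val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt", numEdgePartitions = -1)
    println(s"edges: ${graph.edges.count()}, vertices: ${graph.vertices.count()}")
    sc.stop()
  }
}
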
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
new file mode 100644
index 0000000000000..2cb07937eaa2a
--- /dev/null
+++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.graphx
+
+import org.apache.spark.SparkConf
+
+import org.apache.spark.graphx.impl._
+import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap
+
+import org.apache.spark.util.collection.{OpenHashSet, BitSet}
+import org.apache.spark.util.BoundedPriorityQueue
+
+object GraphXUtils {
+ /**
+ * Registers classes that GraphX uses with Kryo.
+ */
+ def registerKryoClasses(conf: SparkConf) {
+ conf.registerKryoClasses(Array(
+ classOf[Edge[Object]],
+ classOf[(VertexId, Object)],
+ classOf[EdgePartition[Object, Object]],
+ classOf[BitSet],
+ classOf[VertexIdToIndexMap],
+ classOf[VertexAttributeBlock[Object]],
+ classOf[PartitionStrategy],
+ classOf[BoundedPriorityQueue[Object]],
+ classOf[EdgeDirection],
+ classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]],
+ classOf[OpenHashSet[Int]],
+ classOf[OpenHashSet[Long]]))
+ }
+}
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
index 47594a800a3b1..a3e28efc75a98 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/LocalSparkContext.scala
@@ -17,9 +17,6 @@
package org.apache.spark.graphx
-import org.scalatest.Suite
-import org.scalatest.BeforeAndAfterEach
-
import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
@@ -31,8 +28,7 @@ trait LocalSparkContext {
/** Runs `f` on a new SparkContext and ensures that it is stopped afterwards. */
def withSpark[T](f: SparkContext => T) = {
val conf = new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator")
+ GraphXUtils.registerKryoClasses(conf)
val sc = new SparkContext("local", "test", conf)
try {
f(sc)
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
index 9d00f76327e4c..db1dac6160080 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala
@@ -129,9 +129,9 @@ class EdgePartitionSuite extends FunSuite {
val aList = List((0, 1, 0), (1, 0, 0), (1, 2, 0), (5, 4, 0), (5, 5, 0))
val a: EdgePartition[Int, Int] = makeEdgePartition(aList)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val aSer: EdgePartition[Int, Int] = s.deserialize(s.serialize(a))
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
index f9e771a900013..fe8304c1cdc32 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala
@@ -125,9 +125,9 @@ class VertexPartitionSuite extends FunSuite {
val verts = Set((0L, 1), (1L, 1), (2L, 1))
val vp = VertexPartition(verts.iterator)
val javaSer = new JavaSerializer(new SparkConf())
- val kryoSer = new KryoSerializer(new SparkConf()
- .set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
- .set("spark.kryo.registrator", "org.apache.spark.graphx.GraphKryoRegistrator"))
+ val conf = new SparkConf()
+ GraphXUtils.registerKryoClasses(conf)
+ val kryoSer = new KryoSerializer(conf)
for (ser <- List(javaSer, kryoSer); s = ser.newInstance()) {
val vpSer: VertexPartition[Int] = s.deserialize(s.serialize(vp))
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 696e9396f627c..87a7ddaba97f2 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -45,6 +45,11 @@
      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
      <version>${project.version}</version>
    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
    <dependency>
      <groupId>org.eclipse.jetty</groupId>
      <artifactId>jetty-server</artifactId>
@@ -71,6 +76,10 @@
+    <dependency>
+      <groupId>org.apache.commons</groupId>
+      <artifactId>commons-math3</artifactId>
+    </dependency>
    <dependency>
      <groupId>org.scalatest</groupId>
      <artifactId>scalatest_${scala.binary.version}</artifactId>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9a100170b75c6..65b98a8ceea55 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
-import java.util.{ArrayList => JArrayList}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.language.existentials
@@ -31,8 +31,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.feature.Word2Vec
-import org.apache.spark.mllib.feature.Word2VecModel
+import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
@@ -73,15 +72,11 @@ class PythonMLLibAPI extends Serializable {
private def trainRegressionModel(
learner: GeneralizedLinearAlgorithm[_ <: GeneralizedLinearModel],
data: JavaRDD[LabeledPoint],
- initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
+ initialWeights: Vector): JList[Object] = {
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
learner.disableUncachedWarning()
val model = learner.run(data.rdd, initialWeights)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.dumps(model.weights))
- ret.add(model.intercept: java.lang.Double)
- ret
+ List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
/**
@@ -92,10 +87,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
lrAlg.optimizer
@@ -114,7 +109,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lrAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -126,7 +121,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val lassoAlg = new LassoWithSGD()
lassoAlg.optimizer
.setNumIterations(numIterations)
@@ -136,7 +131,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
lassoAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -148,7 +143,7 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte]): java.util.List[java.lang.Object] = {
+ initialWeights: Vector): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
ridgeAlg.optimizer
.setNumIterations(numIterations)
@@ -158,7 +153,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
ridgeAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -170,9 +165,9 @@ class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
SVMAlg.optimizer
@@ -191,7 +186,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
SVMAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -202,10 +197,10 @@ class PythonMLLibAPI extends Serializable {
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
- initialWeightsBA: Array[Byte],
+ initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): java.util.List[java.lang.Object] = {
+ intercept: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
LogRegAlg.optimizer
@@ -224,7 +219,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
LogRegAlg,
data,
- initialWeightsBA)
+ initialWeights)
}
/**
@@ -232,13 +227,10 @@ class PythonMLLibAPI extends Serializable {
*/
def trainNaiveBayes(
data: JavaRDD[LabeledPoint],
- lambda: Double): java.util.List[java.lang.Object] = {
+ lambda: Double): JList[Object] = {
val model = NaiveBayes.train(data.rdd, lambda)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(Vectors.dense(model.labels))
- ret.add(Vectors.dense(model.pi))
- ret.add(model.theta)
- ret
+ List(Vectors.dense(model.labels), Vectors.dense(model.pi), model.theta).
+ map(_.asInstanceOf[Object]).asJava
}
/**
@@ -260,6 +252,21 @@ class PythonMLLibAPI extends Serializable {
return kMeansAlg.run(data.rdd)
}
+ /**
+ * A wrapper of MatrixFactorizationModel to provide helper methods for Python.
+ */
+ private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)
+ extends MatrixFactorizationModel(model.rank, model.userFeatures, model.productFeatures) {
+
+ def predict(userAndProducts: JavaRDD[Array[Any]]): RDD[Rating] =
+ predict(SerDe.asTupleRDD(userAndProducts.rdd))
+
+ def getUserFeatures = SerDe.fromTuple2RDD(userFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ def getProductFeatures = SerDe.fromTuple2RDD(productFeatures.asInstanceOf[RDD[(Any, Any)]])
+
+ }
+
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@@ -272,7 +279,7 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
- ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
+ new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks))
}
/**
@@ -288,7 +295,45 @@ class PythonMLLibAPI extends Serializable {
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
+ new MatrixFactorizationModelWrapper(
+ ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha))
+ }
+
+ /**
+ * Java stub for Normalizer.transform()
+ */
+ def normalizeVector(p: Double, vector: Vector): Vector = {
+ new Normalizer(p).transform(vector)
+ }
+
+ /**
+ * Java stub for Normalizer.transform()
+ */
+ def normalizeVector(p: Double, rdd: JavaRDD[Vector]): JavaRDD[Vector] = {
+ new Normalizer(p).transform(rdd)
+ }
+
+ /**
+ * Java stub for StandardScaler.fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def fitStandardScaler(
+ withMean: Boolean,
+ withStd: Boolean,
+ data: JavaRDD[Vector]): StandardScalerModel = {
+ new StandardScaler(withMean, withStd).fit(data.rdd)
+ }
+
+ /**
+ * Java stub for IDF.fit(). This stub returns a
+ * handle to the Java object instead of the content of the Java object.
+ * Extra care needs to be taken in the Python code to ensure it gets freed on
+ * exit; see the Py4J documentation.
+ */
+ def fitIDF(minDocFreq: Int, dataset: JavaRDD[Vector]): IDFModel = {
+ new IDF(minDocFreq).fit(dataset)
}
/**
@@ -328,19 +373,25 @@ class PythonMLLibAPI extends Serializable {
model.transform(word)
}
- def findSynonyms(word: String, num: Int): java.util.List[java.lang.Object] = {
+ /**
+ * Transforms an RDD of words to its vector representation
+ * @param rdd an RDD of words
+ * @return an RDD of vector representations of words
+ */
+ def transform(rdd: JavaRDD[String]): JavaRDD[Vector] = {
+ rdd.rdd.map(model.transform)
+ }
+
+ def findSynonyms(word: String, num: Int): JList[Object] = {
val vec = transform(word)
findSynonyms(vec, num)
}
- def findSynonyms(vector: Vector, num: Int): java.util.List[java.lang.Object] = {
+ def findSynonyms(vector: Vector, num: Int): JList[Object] = {
val result = model.findSynonyms(vector, num)
val similarity = Vectors.dense(result.map(_._2))
val words = result.map(_._1)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(words)
- ret.add(similarity)
- ret
+ List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
}
@@ -350,13 +401,13 @@ class PythonMLLibAPI extends Serializable {
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
* @param data Training data
- * @param categoricalFeaturesInfoJMap Categorical features info, as Java map
+ * @param categoricalFeaturesInfo Categorical features info, as Java map
*/
def trainDecisionTreeModel(
data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
- categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
+ categoricalFeaturesInfo: JMap[Int, Int],
impurityStr: String,
maxDepth: Int,
maxBins: Int,
@@ -372,7 +423,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth = maxDepth,
numClassesForClassification = numClasses,
maxBins = maxBins,
- categoricalFeaturesInfo = categoricalFeaturesInfoJMap.asScala.toMap,
+ categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap,
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
@@ -544,7 +595,7 @@ private[spark] object SerDe extends Serializable {
if (objects.length == 0 || objects.length > 3) {
out.write(Opcodes.MARK)
}
- objects.foreach(pickler.save(_))
+ objects.foreach(pickler.save)
val code = objects.length match {
case 1 => Opcodes.TUPLE1
case 2 => Opcodes.TUPLE2
@@ -673,6 +724,11 @@ private[spark] object SerDe extends Serializable {
rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
+ /* convert RDD[(Any, Any)] to RDD[Array[Any]] */
+ def fromTuple2RDD(rdd: RDD[(Any, Any)]): RDD[Array[Any]] = {
+ rdd.map(x => Array(x._1, x._2))
+ }
+
/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
@@ -680,7 +736,7 @@ private[spark] object SerDe extends Serializable {
def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
jRDD.rdd.mapPartitions { iter =>
initialize() // let it called in executor
- new PythonRDD.AutoBatchedPickler(iter)
+ new SerDeUtil.AutoBatchedPickler(iter)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
index 87bdc8558aaf5..c67a6d3ae6cce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/package.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.api
/**
- * Internal support for MLLib Python API.
+ * Internal support for MLlib Python API.
*
* @see [[org.apache.spark.mllib.api.python.PythonMLLibAPI]]
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
new file mode 100644
index 0000000000000..6189dce9b27da
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -0,0 +1,268 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext._
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeansModel extends MLlib's KMeansModel for streaming
+ * algorithms, so it can keep track of a continuously updated weight
+ * associated with each cluster, and also update the model by
+ * doing a single iteration of the standard k-means algorithm.
+ *
+ * The update algorithm uses the "mini-batch" KMeans rule,
+ * generalized to incorporate forgetfulness (i.e. decay).
+ * The update rule (for each cluster) is:
+ *
+ * c_t+1 = [(c_t * n_t * a) + (x_t * m_t)] / [n_t * a + m_t]
+ * n_t+1 = n_t * a + m_t
+ *
+ * Where c_t is the previously estimated centroid for that cluster,
+ * n_t is the number of points assigned to it thus far, x_t is the centroid
+ * estimated on the current batch, and m_t is the number of points assigned
+ * to that centroid in the current batch.
+ *
+ * The decay factor 'a' scales the contribution of the clusters as estimated thus far,
+ * by applying a as a discount weighting on the current point when evaluating
+ * new incoming data. If a=1, all batches are weighted equally. If a=0, new centroids
+ * are determined entirely by recent data. Lower values correspond to
+ * more forgetting.
+ *
+ * Decay can optionally be specified by a half life and associated
+ * time unit. The time unit can either be a batch of data or a single
+ * data point. Considering data arrived at time t, the half life h is defined
+ * such that at time t + h the discount applied to the data from t is 0.5.
+ * The definition remains the same whether the time unit is given
+ * as batches or points.
+ *
+ */
+@DeveloperApi
+class StreamingKMeansModel(
+ override val clusterCenters: Array[Vector],
+ val clusterWeights: Array[Double]) extends KMeansModel(clusterCenters) with Logging {
+
+ /** Perform a k-means update on a batch of data. */
+ def update(data: RDD[Vector], decayFactor: Double, timeUnit: String): StreamingKMeansModel = {
+
+ // find nearest cluster to each point
+ val closest = data.map(point => (this.predict(point), (point, 1L)))
+
+ // get sums and counts for updating each cluster
+ val mergeContribs: ((Vector, Long), (Vector, Long)) => (Vector, Long) = (p1, p2) => {
+ BLAS.axpy(1.0, p2._1, p1._1)
+ (p1._1, p1._2 + p2._2)
+ }
+ val dim = clusterCenters(0).size
+ val pointStats: Array[(Int, (Vector, Long))] = closest
+ .aggregateByKey((Vectors.zeros(dim), 0L))(mergeContribs, mergeContribs)
+ .collect()
+
+ val discount = timeUnit match {
+ case StreamingKMeans.BATCHES => decayFactor
+ case StreamingKMeans.POINTS =>
+ val numNewPoints = pointStats.view.map { case (_, (_, n)) =>
+ n
+ }.sum
+ math.pow(decayFactor, numNewPoints)
+ }
+
+ // apply discount to weights
+ BLAS.scal(discount, Vectors.dense(clusterWeights))
+
+ // implement update rule
+ pointStats.foreach { case (label, (sum, count)) =>
+ val centroid = clusterCenters(label)
+
+ val updatedWeight = clusterWeights(label) + count
+ val lambda = count / math.max(updatedWeight, 1e-16)
+
+ clusterWeights(label) = updatedWeight
+ BLAS.scal(1.0 - lambda, centroid)
+ BLAS.axpy(lambda / count, sum, centroid)
+
+ // display the updated cluster centers
+ val display = clusterCenters(label).size match {
+ case x if x > 100 => centroid.toArray.take(100).mkString("[", ",", "...")
+ case _ => centroid.toArray.mkString("[", ",", "]")
+ }
+
+ logInfo(s"Cluster $label updated with weight $updatedWeight and centroid: $display")
+ }
+
+ // Check whether the smallest cluster is dying. If so, split the largest cluster.
+ val weightsWithIndex = clusterWeights.view.zipWithIndex
+ val (maxWeight, largest) = weightsWithIndex.maxBy(_._1)
+ val (minWeight, smallest) = weightsWithIndex.minBy(_._1)
+ if (minWeight < 1e-8 * maxWeight) {
+ logInfo(s"Cluster $smallest is dying. Split the largest cluster $largest into two.")
+ val weight = (maxWeight + minWeight) / 2.0
+ clusterWeights(largest) = weight
+ clusterWeights(smallest) = weight
+ val largestClusterCenter = clusterCenters(largest)
+ val smallestClusterCenter = clusterCenters(smallest)
+ var j = 0
+ while (j < dim) {
+ val x = largestClusterCenter(j)
+ val p = 1e-14 * math.max(math.abs(x), 1.0)
+ largestClusterCenter.toBreeze(j) = x + p
+ smallestClusterCenter.toBreeze(j) = x - p
+ j += 1
+ }
+ }
+
+ this
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * StreamingKMeans provides methods for configuring a
+ * streaming k-means analysis, training the model on streaming,
+ * and using the model to make predictions on streaming data.
+ * See KMeansModel for details on algorithm and update rules.
+ *
+ * Use a builder pattern to construct a streaming k-means analysis
+ * in an application, like:
+ *
+ * val model = new StreamingKMeans()
+ * .setDecayFactor(0.5)
+ * .setK(3)
+ * .setRandomCenters(5, 100.0)
+ * .trainOn(DStream)
+ */
+@DeveloperApi
+class StreamingKMeans(
+ var k: Int,
+ var decayFactor: Double,
+ var timeUnit: String) extends Logging {
+
+ def this() = this(2, 1.0, StreamingKMeans.BATCHES)
+
+ protected var model: StreamingKMeansModel = new StreamingKMeansModel(null, null)
+
+ /** Set the number of clusters. */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /** Set the decay factor directly (for forgetful algorithms). */
+ def setDecayFactor(a: Double): this.type = {
+ this.decayFactor = a
+ this
+ }
+
+ /** Set the half life and time unit ("batches" or "points") for forgetful algorithms. */
+ def setHalfLife(halfLife: Double, timeUnit: String): this.type = {
+ if (timeUnit != StreamingKMeans.BATCHES && timeUnit != StreamingKMeans.POINTS) {
+ throw new IllegalArgumentException("Invalid time unit for decay: " + timeUnit)
+ }
+ this.decayFactor = math.exp(math.log(0.5) / halfLife)
+ logInfo("Setting decay factor to: %g ".format (this.decayFactor))
+ this.timeUnit = timeUnit
+ this
+ }
+
+ /** Specify initial centers directly. */
+ def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /**
+ * Initialize random centers, requiring only the number of dimensions.
+ *
+ * @param dim Number of dimensions
+ * @param weight Weight for each center
+ * @param seed Random seed
+ */
+ def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ val random = new XORShiftRandom(seed)
+ val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
+ val weights = Array.fill(k)(weight)
+ model = new StreamingKMeansModel(centers, weights)
+ this
+ }
+
+ /** Return the latest model. */
+ def latestModel(): StreamingKMeansModel = {
+ model
+ }
+
+ /**
+ * Update the clustering model by training on batches of data from a DStream.
+ * This operation registers a DStream for training the model,
+ * checks whether the cluster centers have been initialized,
+ * and updates the model using each batch of data from the stream.
+ *
+ * @param data DStream containing vector data
+ */
+ def trainOn(data: DStream[Vector]) {
+ assertInitialized()
+ data.foreachRDD { (rdd, time) =>
+ model = model.update(rdd, decayFactor, timeUnit)
+ }
+ }
+
+ /**
+ * Use the clustering model to make predictions on batches of data from a DStream.
+ *
+ * @param data DStream containing vector data
+ * @return DStream containing predictions
+ */
+ def predictOn(data: DStream[Vector]): DStream[Int] = {
+ assertInitialized()
+ data.map(model.predict)
+ }
+
+ /**
+ * Use the model to make predictions on the values of a DStream and carry over its keys.
+ *
+ * @param data DStream containing (key, feature vector) pairs
+ * @tparam K key type
+ * @return DStream containing the input keys and the predictions as values
+ */
+ def predictOnValues[K: ClassTag](data: DStream[(K, Vector)]): DStream[(K, Int)] = {
+ assertInitialized()
+ data.mapValues(model.predict)
+ }
+
+ /** Check whether cluster centers have been initialized. */
+ private[this] def assertInitialized(): Unit = {
+ if (model.clusterCenters == null) {
+ throw new IllegalStateException(
+ "Initial cluster centers must be set before starting predictions")
+ }
+ }
+}
+
+private[clustering] object StreamingKMeans {
+ final val BATCHES = "batches"
+ final val POINTS = "points"
+}
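
The builder-style usage above can be wired into an application roughly as follows. This is an illustrative sketch, not part of the patch: it assumes an existing StreamingContext `ssc` and two hypothetical DStreams, `trainingData: DStream[Vector]` and `testData: DStream[(Long, Vector)]`.

    import org.apache.spark.mllib.clustering.StreamingKMeans

    val model = new StreamingKMeans()
      .setK(3)
      .setDecayFactor(1.0)
      .setRandomCenters(dim = 5, weight = 0.0)

    model.trainOn(trainingData)              // update cluster centers on every batch
    model.predictOnValues(testData).print()  // keep keys, emit cluster indices as values
    ssc.start()
    ssc.awaitTermination()
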
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
new file mode 100644
index 0000000000000..ea10bde5fa252
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+/**
+ * Evaluator for multilabel classification.
+ * @param predictionAndLabels an RDD of (predictions, labels) pairs,
+ * both are non-null Arrays, each with unique elements.
+ */
+class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+
+ private lazy val numDocs: Long = predictionAndLabels.count()
+
+ private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
+ labels}.distinct().count()
+
+ /**
+ * Returns subset accuracy
+ * (for equal sets of labels)
+ */
+ lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
+ predictions.deep == labels.deep
+ }.count().toDouble / numDocs
+
+ /**
+ * Returns accuracy
+ */
+ lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.intersect(predictions).size.toDouble /
+ (labels.size + predictions.size - labels.intersect(predictions).size)}.sum / numDocs
+
+
+ /**
+ * Returns Hamming-loss
+ */
+ lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.size + predictions.size - 2 * labels.intersect(predictions).size
+ }.sum / (numDocs * numLabels)
+
+ /**
+ * Returns document-based precision averaged by the number of documents
+ */
+ lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
+ if (predictions.size > 0) {
+ predictions.intersect(labels).size.toDouble / predictions.size
+ } else {
+ 0
+ }
+ }.sum / numDocs
+
+ /**
+ * Returns document-based recall averaged by the number of documents
+ */
+ lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
+ labels.intersect(predictions).size.toDouble / labels.size
+ }.sum / numDocs
+
+ /**
+ * Returns document-based f1-measure averaged by the number of documents
+ */
+ lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
+ 2.0 * predictions.intersect(labels).size / (predictions.size + labels.size)
+ }.sum / numDocs
+
+ private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.intersect(labels)
+ }.countByValue()
+
+ private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+ predictions.diff(labels)
+ }.countByValue()
+
+ private lazy val fnPerClass = predictionAndLabels.flatMap { case(predictions, labels) =>
+ labels.diff(predictions)
+ }.countByValue()
+
+ /**
+ * Returns precision for a given label (category)
+ * @param label the label.
+ */
+ def precision(label: Double) = {
+ val tp = tpPerClass(label)
+ val fp = fpPerClass.getOrElse(label, 0L)
+ if (tp + fp == 0) 0 else tp.toDouble / (tp + fp)
+ }
+
+ /**
+ * Returns recall for a given label (category)
+ * @param label the label.
+ */
+ def recall(label: Double) = {
+ val tp = tpPerClass(label)
+ val fn = fnPerClass.getOrElse(label, 0L)
+ if (tp + fn == 0) 0 else tp.toDouble / (tp + fn)
+ }
+
+ /**
+ * Returns f1-measure for a given label (category)
+ * @param label the label.
+ */
+ def f1Measure(label: Double) = {
+ val p = precision(label)
+ val r = recall(label)
+ if((p + r) == 0) 0 else 2 * p * r / (p + r)
+ }
+
+ private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
+ private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
+ private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
+
+ /**
+ * Returns micro-averaged label-based precision
+ * (equals to micro-averaged document-based precision)
+ */
+ lazy val microPrecision = {
+ val sumFp = fpPerClass.foldLeft(0L){ case(cum, (_, fp)) => cum + fp}
+ sumTp.toDouble / (sumTp + sumFp)
+ }
+
+ /**
+ * Returns micro-averaged label-based recall
+ * (equals to micro-averaged document-based recall)
+ */
+ lazy val microRecall = {
+ val sumFn = fnPerClass.foldLeft(0.0){ case(cum, (_, fn)) => cum + fn}
+ sumTp.toDouble / (sumTp + sumFn)
+ }
+
+ /**
+ * Returns micro-averaged label-based f1-measure
+ * (equals to micro-averaged document-based f1-measure)
+ */
+ lazy val microF1Measure = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+
+ /**
+ * Returns the sequence of labels in ascending order
+ */
+ lazy val labels: Array[Double] = tpPerClass.keys.toArray.sorted
+}
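
A minimal usage sketch for the new evaluator (not part of the patch), assuming an existing SparkContext `sc`. Each pair holds the predicted label set first and the ground-truth label set second, as required by the constructor.

    import org.apache.spark.mllib.evaluation.MultilabelMetrics

    val predictionAndLabels = sc.parallelize(Seq(
      (Array(0.0, 1.0), Array(0.0, 2.0)),
      (Array(0.0, 2.0), Array(0.0, 1.0)),
      (Array.empty[Double], Array(0.0))))
    val metrics = new MultilabelMetrics(predictionAndLabels)
    println(s"subset accuracy        = ${metrics.subsetAccuracy}")
    println(s"micro-averaged F1      = ${metrics.microF1Measure}")
    println(s"precision of label 0.0 = ${metrics.precision(0.0)}")
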
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000000..93a7353e2c070
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+ extends Logging with Serializable {
+
+ /**
+ * Compute the average precision of all the queries, truncated at ranking position k.
+ *
+ * If for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+ * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+ * ground truth set is less than k.
+ *
+ * If a query has an empty ground truth set, zero will be used as precision together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated precision, must be positive
+ * @return the average precision at the first k ranking positions
+ */
+ def precisionAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val n = math.min(pred.length, k)
+ var i = 0
+ var cnt = 0
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ }
+ i += 1
+ }
+ cnt.toDouble / k
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Returns the mean average precision (MAP) of all the queries.
+ * If a query has an empty ground truth set, the average precision will be zero and a log
+   * warning is generated.
+ */
+ lazy val meanAveragePrecision: Double = {
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ var i = 0
+ var cnt = 0
+ var precSum = 0.0
+ val n = pred.length
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ precSum += cnt.toDouble / (i + 1)
+ }
+ i += 1
+ }
+ precSum / labSet.size
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+ /**
+ * Compute the average NDCG value of all the queries, truncated at ranking position k.
+ * The discounted cumulative gain at position k is computed as:
+ * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+   * and the NDCG is obtained by dividing the DCG value by the maximum DCG attainable
+   * on the ground truth set. In the current implementation, the relevance value is binary.
+   *
+ * If a query has an empty ground truth set, zero will be used as ndcg together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated ndcg, must be positive
+ * @return the average ndcg at the first k ranking positions
+ */
+ def ndcgAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val labSetSize = labSet.size
+ val n = math.min(math.max(pred.length, labSetSize), k)
+ var maxDcg = 0.0
+ var dcg = 0.0
+ var i = 0
+ while (i < n) {
+ val gain = 1.0 / math.log(i + 2)
+ if (labSet.contains(pred(i))) {
+ dcg += gain
+ }
+ if (i < labSetSize) {
+ maxDcg += gain
+ }
+ i += 1
+ }
+ dcg / maxDcg
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean
+ }
+
+}
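
A small usage sketch (not part of the patch), assuming an existing SparkContext `sc`. Each pair holds a predicted ranking and the corresponding ground-truth set; the label type T is inferred as Int here.

    import org.apache.spark.mllib.evaluation.RankingMetrics

    val predictionAndLabels = sc.parallelize(Seq(
      (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
      (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3))))
    val metrics = new RankingMetrics(predictionAndLabels)
    println(s"precision@5 = ${metrics.precisionAt(5)}")
    println(s"NDCG@5      = ${metrics.ndcgAt(5)}")
    println(s"MAP         = ${metrics.meanAveragePrecision}")
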
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
new file mode 100644
index 0000000000000..693117d820580
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer}
+
+/**
+ * :: Experimental ::
+ * Evaluator for regression.
+ *
+ * @param predictionAndObservations an RDD of (prediction, observation) pairs.
+ */
+@Experimental
+class RegressionMetrics(predictionAndObservations: RDD[(Double, Double)]) extends Logging {
+
+ /**
+ * Use MultivariateOnlineSummarizer to calculate summary statistics of observations and errors.
+ */
+ private lazy val summary: MultivariateStatisticalSummary = {
+ val summary: MultivariateStatisticalSummary = predictionAndObservations.map {
+ case (prediction, observation) => Vectors.dense(observation, observation - prediction)
+ }.aggregate(new MultivariateOnlineSummarizer())(
+ (summary, v) => summary.add(v),
+ (sum1, sum2) => sum1.merge(sum2)
+ )
+ summary
+ }
+
+ /**
+ * Returns the explained variance regression score.
+ * explainedVariance = 1 - variance(y - \hat{y}) / variance(y)
+ * Reference: [[http://en.wikipedia.org/wiki/Explained_variation]]
+ */
+ def explainedVariance: Double = {
+ 1 - summary.variance(1) / summary.variance(0)
+ }
+
+ /**
+ * Returns the mean absolute error, which is a risk function corresponding to the
+ * expected value of the absolute error loss or l1-norm loss.
+ */
+ def meanAbsoluteError: Double = {
+ summary.normL1(1) / summary.count
+ }
+
+ /**
+ * Returns the mean squared error, which is a risk function corresponding to the
+ * expected value of the squared error loss or quadratic loss.
+ */
+ def meanSquaredError: Double = {
+ val rmse = summary.normL2(1) / math.sqrt(summary.count)
+ rmse * rmse
+ }
+
+ /**
+ * Returns the root mean squared error, which is defined as the square root of
+ * the mean squared error.
+ */
+ def rootMeanSquaredError: Double = {
+ summary.normL2(1) / math.sqrt(summary.count)
+ }
+
+ /**
+ * Returns R^2^, the coefficient of determination.
+ * Reference: [[http://en.wikipedia.org/wiki/Coefficient_of_determination]]
+ */
+ def r2: Double = {
+ 1 - math.pow(summary.normL2(1), 2) / (summary.variance(0) * (summary.count - 1))
+ }
+}
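
A short usage sketch (not part of the patch), assuming an existing SparkContext `sc`; each pair is (prediction, observation) as expected by the constructor.

    import org.apache.spark.mllib.evaluation.RegressionMetrics

    val predictionAndObservations = sc.parallelize(Seq(
      (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)))
    val metrics = new RegressionMetrics(predictionAndObservations)
    println(s"MAE  = ${metrics.meanAbsoluteError}")
    println(s"MSE  = ${metrics.meanSquaredError}")
    println(s"RMSE = ${metrics.rootMeanSquaredError}")
    println(s"R^2  = ${metrics.r2}")
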
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
index 4734251127bb4..dfad25d57c947 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Normalizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
* :: Experimental ::
* Normalizes samples individually to unit L^p^ norm
*
- * For any 1 <= p < Double.PositiveInfinity, normalizes samples using
+ * For any 1 <= p < Double.PositiveInfinity, normalizes samples using
* sum(abs(vector).^p^)^(1/p)^ as norm.
*
* For p = Double.PositiveInfinity, max(abs(vector)) will be used as norm for normalization.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
index 415a845332d45..7358c1c84f79c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/VectorTransformer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.feature
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
@@ -48,4 +49,14 @@ trait VectorTransformer extends Serializable {
data.map(x => this.transform(x))
}
+ /**
+ * Applies transformation on an JavaRDD[Vector].
+ *
+ * @param data JavaRDD[Vector] to be transformed.
+ * @return transformed JavaRDD[Vector].
+ */
+ def transform(data: JavaRDD[Vector]): JavaRDD[Vector] = {
+ transform(data.rdd)
+ }
+
}
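
The new Java-friendly overload simply unwraps the underlying RDD. A sketch (not part of the patch), assuming a JavaRDD[Vector] named `javaVectors` obtained from a JavaSparkContext:

    import org.apache.spark.mllib.feature.Normalizer

    val normalizer = new Normalizer()                    // L^2 normalization by default
    val normalized = normalizer.transform(javaVectors)   // JavaRDD[Vector] in, JavaRDD[Vector] out
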
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index d321994c2a651..f5f7ad613d4c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -432,7 +432,7 @@ class Word2VecModel private[mllib] (
throw new IllegalStateException(s"$word not in vocabulary")
}
}
-
+
/**
* Find synonyms of a word
* @param word a word
@@ -443,7 +443,7 @@ class Word2VecModel private[mllib] (
val vector = transform(word)
findSynonyms(vector,num)
}
-
+
/**
* Find synonyms of the vector representation of a word
* @param vector vector representation of a word
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6af225b7f49f7..ac217edc619ab 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -17,22 +17,26 @@
package org.apache.spark.mllib.linalg
-import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import java.util
+import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable}
import scala.annotation.varargs
import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
-import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.SparkException
+import org.apache.spark.mllib.util.NumericParser
+import org.apache.spark.sql.catalyst.annotation.SQLUserDefinedType
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.catalyst.types._
/**
* Represents a numeric vector, whose index type is Int and value type is Double.
*
* Note: Users should not implement this interface.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
sealed trait Vector extends Serializable {
/**
@@ -74,6 +78,65 @@ sealed trait Vector extends Serializable {
}
}
+/**
+ * User-defined type for [[Vector]] which allows easy interaction with SQL
+ * via [[org.apache.spark.sql.SchemaRDD]].
+ */
+private[spark] class VectorUDT extends UserDefinedType[Vector] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // We only use "values" for dense vectors, and "size", "indices", and "values" for sparse
+ // vectors. The "values" field is nullable because we might want to add binary vectors later,
+    // which use "size" and "indices", but not "values".
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("size", IntegerType, nullable = true),
+ StructField("indices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(4)
+ obj match {
+ case sv: SparseVector =>
+ row.setByte(0, 0)
+ row.setInt(1, sv.size)
+ row.update(2, sv.indices.toSeq)
+ row.update(3, sv.values.toSeq)
+ case dv: DenseVector =>
+ row.setByte(0, 1)
+ row.setNullAt(1)
+ row.setNullAt(2)
+ row.update(3, dv.values.toSeq)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Vector = {
+ datum match {
+ case row: Row =>
+ require(row.length == 4,
+ s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
+ val tpe = row.getByte(0)
+ tpe match {
+ case 0 =>
+ val size = row.getInt(1)
+ val indices = row.getAs[Iterable[Int]](2).toArray
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new SparseVector(size, indices, values)
+ case 1 =>
+ val values = row.getAs[Iterable[Double]](3).toArray
+ new DenseVector(values)
+ }
+ }
+ }
+
+ override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
+
+ override def userClass: Class[Vector] = classOf[Vector]
+}
+
/**
* Factory methods for [[org.apache.spark.mllib.linalg.Vector]].
* We don't use the name `Vector` because Scala imports
@@ -191,6 +254,7 @@ object Vectors {
/**
* A dense vector represented by a value array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class DenseVector(val values: Array[Double]) extends Vector {
override def size: Int = values.length
@@ -215,6 +279,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
* @param indices index array, assume to be strictly increasing.
* @param values value array, must have the same length as the index array.
*/
+@SQLUserDefinedType(udt = classOf[VectorUDT])
class SparseVector(
override val size: Int,
val indices: Array[Int],
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index ec2d481dccc22..10a515af88802 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -152,7 +152,7 @@ class RowMatrix(
* storing the right singular vectors, is computed via matrix multiplication as
* U = A * (V * S^-1^), if requested by user. The actual method to use is determined
* automatically based on the cost:
- * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian
+ * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian
* matrix first and then compute its top eigenvalues and eigenvectors locally on the driver.
* This requires a single pass with O(n^2^) storage on each executor and on the driver, and
* O(n^2^ k) time on the driver.
@@ -169,7 +169,8 @@ class RowMatrix(
* @note The conditions that decide which method to use internally and the default parameters are
* subject to change.
*
- * @param k number of leading singular values to keep (0 < k <= n). It might return less than k if
+ * @param k number of leading singular values to keep (0 < k <= n).
+ * It might return less than k if
* there are numerically zero singular values or there are not enough Ritz values
* converged before the maximum number of Arnoldi update iterations is reached (in case
* that matrix A is ill-conditioned).
@@ -192,7 +193,7 @@ class RowMatrix(
/**
* The actual SVD implementation, visible for testing.
*
- * @param k number of leading singular values to keep (0 < k <= n)
+ * @param k number of leading singular values to keep (0 < k <= n)
* @param computeU whether to compute U
* @param rCond the reciprocal condition number
* @param maxIter max number of iterations (if ARPACK is used)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
index e4b436b023794..fef062e02b6ec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/NNLS.scala
@@ -79,7 +79,7 @@ private[mllib] object NNLS {
// stopping condition
def stop(step: Double, ndir: Double, nx: Double): Boolean = {
((step.isNaN) // NaN
- || (step < 1e-6) // too small or negative
+ || (step < 1e-7) // too small or negative
|| (step > 1e40) // too small; almost certainly numerical problems
|| (ndir < 1e-12 * nx) // gradient relatively too small
|| (ndir < 1e-32) // gradient absolutely too small; numerical issues may lurk
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
index 28179fbc450c0..51f9b8657c640 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala
@@ -17,8 +17,7 @@
package org.apache.spark.mllib.random
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom}
@@ -89,12 +88,13 @@ class StandardNormalGenerator extends RandomDataGenerator[Double] {
@DeveloperApi
class PoissonGenerator(val mean: Double) extends RandomDataGenerator[Double] {
- private var rng = new Poisson(mean, new DRand)
+ private var rng = new PoissonDistribution(mean)
- override def nextValue(): Double = rng.nextDouble()
+ override def nextValue(): Double = rng.sample()
override def setSeed(seed: Long) {
- rng = new Poisson(mean, new DRand(seed.toInt))
+ rng = new PoissonDistribution(mean)
+ rng.reseedRandomGenerator(seed)
}
override def copy(): PoissonGenerator = new PoissonGenerator(mean)
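
For reference, the commons-math3 calls this change switches to can be exercised standalone; a small sketch (not part of the patch):

    import org.apache.commons.math3.distribution.PoissonDistribution

    val poisson = new PoissonDistribution(4.0)            // mean = 4.0
    poisson.reseedRandomGenerator(42L)                    // replaces the old DRand-based reseeding
    val draws = Array.fill(5)(poisson.sample().toDouble)  // sample() returns an Int count
    println(draws.mkString(", "))
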
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
index 3025d4837cab4..fab7c4405c65d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.mllib.stat
import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vectors, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector}
/**
* :: DeveloperApi ::
@@ -72,9 +72,8 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
require(n == sample.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $n but got ${sample.size}.")
- sample.toBreeze.activeIterator.foreach {
- case (_, 0.0) => // Skip explicit zero elements.
- case (i, value) =>
+ @inline def update(i: Int, value: Double) = {
+ if (value != 0.0) {
if (currMax(i) < value) {
currMax(i) = value
}
@@ -89,6 +88,24 @@ class MultivariateOnlineSummarizer extends MultivariateStatisticalSummary with S
currL1(i) += math.abs(value)
nnz(i) += 1.0
+ }
+ }
+
+ sample match {
+ case dv: DenseVector => {
+ var j = 0
+ while (j < dv.size) {
+ update(j, dv.values(j))
+ j += 1
+ }
+ }
+ case sv: SparseVector =>
+ var j = 0
+ while (j < sv.indices.size) {
+ update(sv.indices(j), sv.values(j))
+ j += 1
+ }
+ case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass)
}
totalCnt += 1
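
With this change both dense and sparse vectors take a specialized update path. A brief usage sketch (not part of the patch) showing the summarizer fed with both representations:

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer

    val summarizer = new MultivariateOnlineSummarizer()
    summarizer.add(Vectors.dense(1.0, 0.0, 3.0))
    summarizer.add(Vectors.sparse(3, Array(0, 2), Array(2.0, -1.0)))
    println(summarizer.mean)         // per-column means
    println(summarizer.numNonzeros)  // per-column nonzero counts
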
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
index 0089419c2c5d4..ea82d39b72c03 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.stat.test
import breeze.linalg.{DenseMatrix => BDM}
-import cern.jet.stat.Probability.chiSquareComplemented
+import org.apache.commons.math3.distribution.ChiSquaredDistribution
import org.apache.spark.{SparkException, Logging}
import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors}
@@ -33,7 +33,7 @@ import scala.collection.mutable
* on an input of type `Matrix` in which independence between columns is assessed.
* We also provide a method for computing the chi-squared statistic between each feature and the
* label for an input `RDD[LabeledPoint]`, return an `Array[ChiSquaredTestResult]` of size =
- * number of features in the inpuy RDD.
+ * number of features in the input RDD.
*
* Supported methods for goodness of fit: `pearson` (default)
* Supported methods for independence: `pearson` (default)
@@ -139,7 +139,7 @@ private[stat] object ChiSqTest extends Logging {
}
/*
- * Pearon's goodness of fit test on the input observed and expected counts/relative frequencies.
+ * Pearson's goodness of fit test on the input observed and expected counts/relative frequencies.
* Uniform distribution is assumed when `expected` is not passed in.
*/
def chiSquared(observed: Vector,
@@ -188,12 +188,12 @@ private[stat] object ChiSqTest extends Logging {
}
}
val df = size - 1
- val pValue = chiSquareComplemented(df, statistic)
+ val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
new ChiSqTestResult(pValue, df, statistic, PEARSON.name, NullHypothesis.goodnessOfFit.toString)
}
/*
- * Pearon's independence test on the input contingency matrix.
+ * Pearson's independence test on the input contingency matrix.
* TODO: optimize for SparseMatrix when it becomes supported.
*/
def chiSquaredMatrix(counts: Matrix, methodName:String = PEARSON.name): ChiSqTestResult = {
@@ -238,7 +238,13 @@ private[stat] object ChiSqTest extends Logging {
j += 1
}
val df = (numCols - 1) * (numRows - 1)
- val pValue = chiSquareComplemented(df, statistic)
- new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString)
+ if (df == 0) {
+ // 1 column or 1 row. Constant distribution is independent of anything.
+ // pValue = 1.0 and statistic = 0.0 in this case.
+ new ChiSqTestResult(1.0, 0, 0.0, methodName, NullHypothesis.independence.toString)
+ } else {
+ val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
+ new ChiSqTestResult(pValue, df, statistic, methodName, NullHypothesis.independence.toString)
+ }
}
}
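
The replacement p-value computation can be checked in isolation; a quick sketch (not part of the patch) using the same commons-math3 call as above:

    import org.apache.commons.math3.distribution.ChiSquaredDistribution

    val statistic = 10.83
    val df = 1
    val pValue = 1.0 - new ChiSquaredDistribution(df).cumulativeProbability(statistic)
    println(f"p-value = $pValue%.4f")  // roughly 0.001 for this statistic at df = 1
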
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 03eeaa707715b..78acc17f901c1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
@@ -60,7 +62,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Note: random seed will not be used since numTrees = 1.
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = 0)
val rfModel = rf.train(input)
- rfModel.trees(0)
+ rfModel.weakHypotheses(0)
}
}
@@ -435,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for a corresponding tree. This is used to prevent the need
+ * to pass the entire tree to the executors during
+ * the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -445,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
- timer: TimeTracker = new TimeTracker): Unit = {
+ timer: TimeTracker = new TimeTracker,
+ nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -477,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
+ logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+ /**
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: RandomForest.NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+ instanceWeight, featuresForNode)
+ }
+ }
+ }
/**
* Performs a sequential aggregation over a partition.
@@ -495,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
- val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
- // If the example does not reach a node in this group, then nodeIndex = null.
- if (nodeInfo != null) {
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
- } else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
- }
- }
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
+
+ agg
+ }
+
+ /**
+ * Do the same thing as binSeqOp, but with nodeIdCache.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIdCache = dataPoint._2
+ val nodeIndex = nodeIdCache(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+
agg
}
@@ -551,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val nodeToBestSplits =
+
+ val partitionAggregates : RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+        // iterate over all instances in the current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ } else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -568,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }.reduceByKey((a, b) => a.merge(b))
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -582,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
timer.stop("chooseSplits")
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+ Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+ metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+ } else {
+ null
+ }
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -611,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+ if (nodeIdCache.nonEmpty) {
+ val nodeIndexUpdater = NodeIndexUpdater(
+ split = split,
+ nodeIndex = nodeIndex)
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+ }
+
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -627,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
}
}
+ if (nodeIdCache.nonEmpty) {
+ // Update the cache if needed.
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+ }
}
/**
@@ -909,32 +993,39 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- val numSplits = metadata.numSplits(featureIndex)
- val numBins = metadata.numBins(featureIndex)
if (metadata.isContinuous(featureIndex)) {
- val numSamples = sampledInput.length
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
+ val featureSplits = findSplitsForContinuousFeature(featureSamples,
+ metadata, featureIndex)
+
+ val numSplits = featureSplits.length
+ val numBins = numSplits + 1
+ logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
- val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
- val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
- logDebug("stride = " + stride)
- for (splitIndex <- 0 until numSplits) {
- val sampleIndex = splitIndex * stride.toInt
- // Set threshold halfway in between 2 samples.
- val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
+ splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
- for (splitIndex <- 1 until numSplits) {
+
+ splitIndex = 1
+ while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
+ splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
+ val numSplits = metadata.numSplits(featureIndex)
+ val numBins = metadata.numBins(featureIndex)
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
@@ -1011,4 +1102,77 @@ object DecisionTree extends Serializable with Logging {
categories
}
+ /**
+ * Find splits for a continuous feature
+ * NOTE: Returned number of splits is set based on `featureSamples` and
+ * could be different from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
+ * @param featureSamples feature values of each sample
+ * @param metadata decision tree metadata
+   *                       NOTE: `metadata.numBins` will be changed accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of splits
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Array[Double],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ // get count for each distinct value
+ val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+ m + ((x, m.getOrElse(x, 0) + 1))
+ }
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ // if possible splits is not enough or just enough, just return all possible splits
+ val possibleSplits = valueCounts.length
+ if (possibleSplits <= numSplits) {
+ valueCounts.map(_._1)
+ } else {
+ // stride between splits
+ val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ logDebug("stride = " + stride)
+
+ // iterate `valueCount` to find splits
+ val splits = new ArrayBuffer[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // If `currentCount` is closest value to `targetCount`,
+ // then current value is a split threshold.
+ // After finding a split threshold, `targetCount` is added by stride.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding count of current value to currentCount
+ // makes the gap between currentCount and targetCount smaller,
+ // previous value is a split threshold.
+ if (previousGap < currentGap) {
+ splits.append(valueCounts(index - 1)._1)
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splits.toArray
+ }
+ }
+
+ assert(splits.length > 0)
+ // set number of splits accordingly
+ metadata.setNumSplits(featureIndex, splits.length)
+
+ splits
+ }
}
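
To make the stride-based selection in findSplitsForContinuousFeature easier to follow, here is a self-contained sketch (not part of the patch) that reproduces the same walk over distinct-value counts on a toy array. With 10 samples and numSplits = 3 the stride is 2.5, and a split at the previous distinct value is emitted whenever the cumulative count before adding the current value's count is closer to the running target than after.

    // Standalone illustration of the split-finding walk; mirrors the logic above.
    def continuousSplits(featureSamples: Array[Double], numSplits: Int): Array[Double] = {
      // count occurrences of each distinct value and sort by value
      val valueCounts = featureSamples
        .groupBy(identity).mapValues(_.length).toSeq.sortBy(_._1).toArray
      if (valueCounts.length <= numSplits) {
        valueCounts.map(_._1)  // not enough distinct values: use them all as thresholds
      } else {
        val stride = featureSamples.length.toDouble / (numSplits + 1)
        val splits = scala.collection.mutable.ArrayBuffer.empty[Double]
        var index = 1
        var currentCount = valueCounts(0)._2
        var targetCount = stride
        while (index < valueCounts.length) {
          val previousCount = currentCount
          currentCount += valueCounts(index)._2
          // previous value becomes a threshold if the pre-update count tracks the target better
          if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
            splits.append(valueCounts(index - 1)._1)
            targetCount += stride
          }
          index += 1
        }
        splits.toArray
      }
    }

    val samples = Array(1, 1, 2, 2, 2, 3, 3, 4, 5, 5).map(_.toDouble)
    println(continuousSplits(samples, 3).mkString(", "))  // prints: 1.0, 2.0, 4.0
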
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
new file mode 100644
index 0000000000000..1a847201ce157
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoosting.scala
@@ -0,0 +1,314 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.JavaConverters._
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.tree.configuration.{Strategy, BoostingStrategy}
+import org.apache.spark.Logging
+import org.apache.spark.mllib.tree.impl.TimeTracker
+import org.apache.spark.mllib.tree.loss.Losses
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Sum
+
+/**
+ * :: Experimental ::
+ * A class that implements gradient boosting for regression and binary classification problems.
+ * @param boostingStrategy Parameters for the gradient boosting algorithm
+ */
+@Experimental
+class GradientBoosting (
+ private val boostingStrategy: BoostingStrategy) extends Serializable with Logging {
+
+ /**
+ * Method to train a gradient boosting model
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ algo match {
+ case Regression => GradientBoosting.boost(input, boostingStrategy)
+ case Classification =>
+ val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoosting.boost(remappedInput, boostingStrategy)
+ case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+}
+
+
+object GradientBoosting extends Logging {
+
+ /**
+ * Method to train a gradient boosting model.
+ *
+ * Note: Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ * is recommended to clearly specify regression.
+ * Using [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+   *       is recommended to clearly specify classification.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def train(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting classification model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ require(algo == Classification, s"Only Classification algo supported. Provided algo is $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting regression model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param boostingStrategy Configuration options for the boosting algorithm.
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+ val algo = boostingStrategy.algo
+ require(algo == Regression, s"Only Regression algo supported. Provided algo is $algo.")
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting binary classification model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+   *                     learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+ * supported.)
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ val lossType = Losses.fromString(loss)
+ val boostingStrategy = new BoostingStrategy(Classification, numEstimators, lossType,
+ learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+ weakLearnerParams)
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Method to train a gradient boosting regression model.
+ *
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * For classification, labels should take values {0, 1, ..., numClasses-1}.
+ * For regression, labels are real numbers.
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+   *                     learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and
+ * the number of discrete values they take. For example,
+ * an entry (n -> k) implies the feature n is categorical with k
+ * categories 0, 1, 2, ... , k-1. It's important to note that
+ * features are zero-indexed.
+ * @param weakLearnerParams Parameters for the weak learner. (Currently only decision tree is
+ * supported.)
+ * @return WeightedEnsembleModel that can be used for prediction
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: Map[Int, Int],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ val lossType = Losses.fromString(loss)
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, lossType,
+ learningRate, subsamplingRate, numClassesForClassification, categoricalFeaturesInfo,
+ weakLearnerParams)
+ new GradientBoosting(boostingStrategy).train(input)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainClassifier]]
+ */
+ def trainClassifier(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo:java.util.Map[java.lang.Integer, java.lang.Integer],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ trainClassifier(input, numEstimators, loss, learningRate, subsamplingRate,
+ numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ weakLearnerParams)
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoosting$#trainRegressor]]
+ */
+ def trainRegressor(
+ input: RDD[LabeledPoint],
+ numEstimators: Int,
+ loss: String,
+ learningRate: Double,
+ subsamplingRate: Double,
+ numClassesForClassification: Int,
+ categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+ weakLearnerParams: Strategy): WeightedEnsembleModel = {
+ trainRegressor(input, numEstimators, loss, learningRate, subsamplingRate,
+ numClassesForClassification,
+ categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+ weakLearnerParams)
+ }
+
+
+ /**
+ * Internal method for performing regression using trees as base learners.
+ * @param input training dataset
+ * @param boostingStrategy boosting parameters
+   * @return a WeightedEnsembleModel containing the boosted trees and their weights
+ */
+ private def boost(
+ input: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy): WeightedEnsembleModel = {
+
+ val timer = new TimeTracker()
+ timer.start("total")
+ timer.start("init")
+
+ // Initialize gradient boosting parameters
+ val numEstimators = boostingStrategy.numEstimators
+ val baseLearners = new Array[DecisionTreeModel](numEstimators)
+ val baseLearnerWeights = new Array[Double](numEstimators)
+ val loss = boostingStrategy.loss
+ val learningRate = boostingStrategy.learningRate
+ val strategy = boostingStrategy.weakLearnerParams
+
+ // Cache input
+ input.persist(StorageLevel.MEMORY_AND_DISK)
+
+ timer.stop("init")
+
+ logDebug("##########")
+ logDebug("Building tree 0")
+ logDebug("##########")
+ var data = input
+
+ // 1. Initialize tree
+ timer.start("building tree 0")
+ val firstTreeModel = new DecisionTree(strategy).train(data)
+ baseLearners(0) = firstTreeModel
+ baseLearnerWeights(0) = 1.0
+ val startingModel = new WeightedEnsembleModel(Array(firstTreeModel), Array(1.0), Regression,
+ Sum)
+ logDebug("error of gbt = " + loss.computeError(startingModel, input))
+ // Note: A model of type regression is used since we require raw prediction
+ timer.stop("building tree 0")
+
+    // pseudo-residuals for the second iteration
+    data = input.map(point => LabeledPoint(-loss.gradient(startingModel, point),
+ point.features))
+
+ var m = 1
+ while (m < numEstimators) {
+ timer.start(s"building tree $m")
+ logDebug("###################################################")
+ logDebug("Gradient boosting tree iteration " + m)
+ logDebug("###################################################")
+ val model = new DecisionTree(strategy).train(data)
+ timer.stop(s"building tree $m")
+ // Create partial model
+ baseLearners(m) = model
+ baseLearnerWeights(m) = learningRate
+ // Note: A model of type regression is used since we require raw prediction
+ val partialModel = new WeightedEnsembleModel(baseLearners.slice(0, m + 1),
+ baseLearnerWeights.slice(0, m + 1), Regression, Sum)
+ logDebug("error of gbt = " + loss.computeError(partialModel, input))
+ // Update data with pseudo-residuals
+ data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
+ point.features))
+ m += 1
+ }
+
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
+
+ // 3. Output classifier
+ new WeightedEnsembleModel(baseLearners, baseLearnerWeights, boostingStrategy.algo, Sum)
+
+ }
+
+}
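
A usage sketch for the regression entry point (not part of the patch). `sc`, an `RDD[LabeledPoint]` named `trainingData`, and an already-configured weak-learner `Strategy` named `treeStrategy` are assumed to exist elsewhere; the loss name "leastSquaresError" is likewise an assumption about what `Losses.fromString` accepts.

    import org.apache.spark.mllib.tree.GradientBoosting

    val model = GradientBoosting.trainRegressor(
      trainingData,          // RDD[LabeledPoint]
      10,                    // numEstimators: number of boosting iterations
      "leastSquaresError",   // loss (assumed name accepted by Losses.fromString)
      0.1,                   // learningRate in (0, 1]
      1.0,                   // subsamplingRate: fraction of data per tree
      2,                     // numClassesForClassification (ignored for regression)
      Map.empty[Int, Int],   // categoricalFeaturesInfo: all features continuous
      treeStrategy)          // weak-learner parameters
    println(s"trained ${model.weakHypotheses.length} trees")
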
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index ebbd8e0257209..9683916d9b3f1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -26,8 +26,9 @@ import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -59,7 +60,7 @@ import org.apache.spark.util.Utils
* if numTrees == 1, set to "all";
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
- * @param seed Random seed for bootstrapping and choosing feature subsets.
+ * @param seed Random seed for bootstrapping and choosing feature subsets.
*/
@Experimental
private class RandomForest (
@@ -78,9 +79,9 @@ private class RandomForest (
/**
* Method to train a decision tree model over an RDD
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
- def train(input: RDD[LabeledPoint]): RandomForestModel = {
+ def train(input: RDD[LabeledPoint]): WeightedEnsembleModel = {
val timer = new TimeTracker()
@@ -111,11 +112,20 @@ private class RandomForest (
// Bin feature values (TreePoint representation).
// Cache input RDD for speedup during multiple passes.
val treeInput = TreePoint.convertToTreeRDD(retaggedInput, bins, metadata)
- val baggedInput = if (numTrees > 1) {
- BaggedPoint.convertToBaggedRDD(treeInput, numTrees, seed)
- } else {
- BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val (subsample, withReplacement) = {
+ // TODO: Have a stricter check for RF in the strategy
+ val isRandomForest = numTrees > 1
+ if (isRandomForest) {
+ (1.0, true)
+ } else {
+ (strategy.subsamplingRate, false)
+ }
+ }
+
+    val baggedInput =
+      BaggedPoint.convertToBaggedRDD(treeInput, subsample, numTrees, withReplacement, seed)
+        .persist(StorageLevel.MEMORY_AND_DISK)
// depth of the decision tree
val maxDepth = strategy.maxDepth
@@ -150,6 +160,19 @@ private class RandomForest (
* in lower levels).
*/
+ // Create an RDD of node Id cache.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ val nodeIdCache = if (strategy.useNodeIdCache) {
+ Some(NodeIdCache.init(
+ data = baggedInput,
+ numTrees = numTrees,
+ checkpointDir = strategy.checkpointDir,
+ checkpointInterval = strategy.checkpointInterval,
+ initVal = 1))
+ } else {
+ None
+ }
+
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
@@ -172,7 +195,7 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
- treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}
@@ -183,8 +206,14 @@ private class RandomForest (
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ // Delete any remaining checkpoints used for node Id cache.
+ if (nodeIdCache.nonEmpty) {
+ nodeIdCache.get.deleteAllCheckpoints()
+ }
+
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
- RandomForestModel.build(trees)
+ val treeWeights = Array.fill[Double](numTrees)(1.0)
+ new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
}
}
@@ -205,14 +234,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Classification,
s"RandomForest.trainClassifier given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -243,7 +272,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainClassifier(
input: RDD[LabeledPoint],
@@ -254,7 +283,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Classification, impurityType, maxDepth,
numClassesForClassification, maxBins, Sort, categoricalFeaturesInfo)
@@ -273,7 +302,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainClassifier(input.rdd, numClassesForClassification,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
@@ -293,14 +322,14 @@ object RandomForest extends Serializable with Logging {
* if numTrees > 1 (forest) set to "sqrt" for classification and
* to "onethird" for regression.
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
strategy: Strategy,
numTrees: Int,
featureSubsetStrategy: String,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
require(strategy.algo == Regression,
s"RandomForest.trainRegressor given Strategy with invalid algo: ${strategy.algo}")
val rf = new RandomForest(strategy, numTrees, featureSubsetStrategy, seed)
@@ -330,7 +359,7 @@ object RandomForest extends Serializable with Logging {
* @param maxBins maximum number of bins used for splitting features
* (suggested value: 100)
* @param seed Random seed for bootstrapping and choosing feature subsets.
- * @return RandomForestModel that can be used for prediction
+ * @return WeightedEnsembleModel that can be used for prediction
*/
def trainRegressor(
input: RDD[LabeledPoint],
@@ -340,7 +369,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int = Utils.random.nextInt()): RandomForestModel = {
+ seed: Int = Utils.random.nextInt()): WeightedEnsembleModel = {
val impurityType = Impurities.fromString(impurity)
val strategy = new Strategy(Regression, impurityType, maxDepth,
0, maxBins, Sort, categoricalFeaturesInfo)
@@ -358,7 +387,7 @@ object RandomForest extends Serializable with Logging {
impurity: String,
maxDepth: Int,
maxBins: Int,
- seed: Int): RandomForestModel = {
+ seed: Int): WeightedEnsembleModel = {
trainRegressor(input.rdd,
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
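
With this change every train* entry point returns a WeightedEnsembleModel rather than a
RandomForestModel. A small usage sketch, assuming a SparkContext named sc and a hand-built toy
dataset (argument values are illustrative; the argument order follows the pass-through call above):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.RandomForest

    // Tiny binary-classification dataset.
    val points = Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
      LabeledPoint(0.0, Vectors.dense(0.1, 0.9)),
      LabeledPoint(1.0, Vectors.dense(0.9, 0.2)))
    val input = sc.parallelize(points)  // sc: assumed SparkContext

    val model = RandomForest.trainClassifier(
      input,
      2,                // numClassesForClassification
      Map[Int, Int](),  // categoricalFeaturesInfo: empty map => all features continuous
      10,               // numTrees
      "sqrt",           // featureSubsetStrategy
      "gini",           // impurity
      4,                // maxDepth
      32,               // maxBins
      12345)            // seed

    // The returned WeightedEnsembleModel predicts by majority vote (Average combining).
    println(model.predict(Vectors.dense(0.95, 0.1)))
    println(model.totalNumNodes)
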
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
new file mode 100644
index 0000000000000..501d9ff9ea9b7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import scala.beans.BeanProperty
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
+import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
+
+/**
+ * :: Experimental ::
+ * Stores all the configuration options for the boosting algorithms
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param numEstimators Number of estimators used in boosting stages. In other words,
+ * number of boosting iterations performed.
+ * @param loss Loss function used for minimization during gradient boosting.
+ * @param learningRate Learning rate for shrinking the contribution of each estimator. The
+ *                     learning rate should be in the interval (0, 1].
+ * @param subsamplingRate Fraction of the training data used for learning the decision tree.
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * Default value is 2 (binary classification).
+ * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
+ * number of discrete values they take. For example, an entry (n ->
+ * k) implies the feature n is categorical with k categories 0,
+ * 1, 2, ... , k-1. It's important to note that features are
+ * zero-indexed.
+ * @param weakLearnerParams Parameters for weak learners. Currently only decision trees are
+ * supported.
+ */
+@Experimental
+case class BoostingStrategy(
+ // Required boosting parameters
+ algo: Algo,
+ @BeanProperty var numEstimators: Int,
+ @BeanProperty var loss: Loss,
+ // Optional boosting parameters
+ @BeanProperty var learningRate: Double = 0.1,
+ @BeanProperty var subsamplingRate: Double = 1.0,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var weakLearnerParams: Strategy) extends Serializable {
+
+ require(learningRate <= 1, "Learning rate should be <= 1. Provided learning rate is " +
+ s"$learningRate.")
+ require(learningRate > 0, "Learning rate should be > 0. Provided learning rate is " +
+ s"$learningRate.")
+
+ // Ensure values for weak learner are the same as what is provided to the boosting algorithm.
+ weakLearnerParams.categoricalFeaturesInfo = categoricalFeaturesInfo
+ weakLearnerParams.numClassesForClassification = numClassesForClassification
+ weakLearnerParams.subsamplingRate = subsamplingRate
+
+}
+
+@Experimental
+object BoostingStrategy {
+
+ /**
+ * Returns default configuration for the boosting algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for boosting algorithm
+ */
+ def defaultParams(algo: Algo): BoostingStrategy = {
+ val treeStrategy = defaultWeakLearnerParams(algo)
+ algo match {
+ case Classification =>
+ new BoostingStrategy(algo, 100, LogLoss, weakLearnerParams = treeStrategy)
+ case Regression =>
+ new BoostingStrategy(algo, 100, SquaredError, weakLearnerParams = treeStrategy)
+ case _ =>
+        throw new IllegalArgumentException(s"$algo is not supported by gradient boosting.")
+ }
+ }
+
+ /**
+ * Returns default configuration for the weak learner (decision tree) algorithm
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @return Configuration for weak learner
+ */
+ def defaultWeakLearnerParams(algo: Algo): Strategy = {
+    // Note: A regression tree is used as the weak learner even for classification in GBT.
+ new Strategy(Regression, Variance, 3)
+ }
+
+}
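
Because the optional fields are @BeanProperty vars, a default configuration can be obtained and
then tuned either Scala-style or through the generated setters. A short sketch using only the API
defined in this file:

    import org.apache.spark.mllib.tree.configuration.Algo.Regression
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy

    // Defaults: 100 estimators, SquaredError loss, depth-3 regression trees as weak learners.
    val boostingStrategy = BoostingStrategy.defaultParams(Regression)

    // Scala style: the fields are plain vars.
    boostingStrategy.numEstimators = 50
    boostingStrategy.learningRate = 0.05

    // Java-bean style: @BeanProperty also generates getter/setter pairs.
    boostingStrategy.setSubsamplingRate(0.8)

    // The weak-learner fields are copied at construction time, so later changes to
    // subsamplingRate etc. should also be applied to weakLearnerParams directly if needed.
    boostingStrategy.getWeakLearnerParams.setMaxDepth(5)
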
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
new file mode 100644
index 0000000000000..82889dc00cdad
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/EnsembleCombiningStrategy.scala
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.configuration
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Enum to select ensemble combining strategy for base learners
+ */
+@DeveloperApi
+object EnsembleCombiningStrategy extends Enumeration {
+ type EnsembleCombiningStrategy = Value
+ val Sum, Average = Value
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index caaccbfb8ad16..d09295c507d67 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.tree.configuration
+import scala.beans.BeanProperty
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
@@ -43,7 +44,7 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* for choosing how to split on features at each node.
* More bins give higher granularity.
* @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
- * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
+ * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
@@ -58,19 +59,31 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* this split will not be considered as a valid split.
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
+ *                       maintain a separate RDD of node Ids, one entry per row.
+ * @param checkpointDir Checkpoint directory used to periodically checkpoint the node Id cache
+ *                      (only relevant when useNodeIdCache is true).
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ * E.g. 10 means that the cache will get checkpointed every 10 updates.
*/
@Experimental
class Strategy (
val algo: Algo,
- val impurity: Impurity,
- val maxDepth: Int,
- val numClassesForClassification: Int = 2,
- val maxBins: Int = 32,
- val quantileCalculationStrategy: QuantileStrategy = Sort,
- val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
- val minInstancesPerNode: Int = 1,
- val minInfoGain: Double = 0.0,
- val maxMemoryInMB: Int = 256) extends Serializable {
+ @BeanProperty var impurity: Impurity,
+ @BeanProperty var maxDepth: Int,
+ @BeanProperty var numClassesForClassification: Int = 2,
+ @BeanProperty var maxBins: Int = 32,
+ @BeanProperty var quantileCalculationStrategy: QuantileStrategy = Sort,
+ @BeanProperty var categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+ @BeanProperty var minInstancesPerNode: Int = 1,
+ @BeanProperty var minInfoGain: Double = 0.0,
+ @BeanProperty var maxMemoryInMB: Int = 256,
+ @BeanProperty var subsamplingRate: Double = 1,
+ @BeanProperty var useNodeIdCache: Boolean = false,
+ @BeanProperty var checkpointDir: Option[String] = None,
+ @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
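
The new Strategy fields make node Id caching opt-in. A hedged sketch of enabling it (Gini is the
existing impurity object; the checkpoint path is illustrative):

    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    val strategy = new Strategy(Classification, Gini, maxDepth = 5, numClassesForClassification = 2)

    // Opt in to the node Id cache added by this patch.
    strategy.useNodeIdCache = true
    strategy.checkpointDir = Some("/tmp/spark-nodeid-checkpoints")  // illustrative path
    strategy.checkpointInterval = 10  // checkpoint the cache every 10 updates
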
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
index 937c8a2ac5836..089010c81ffb6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala
@@ -17,18 +17,18 @@
package org.apache.spark.mllib.tree.impl
-import cern.jet.random.Poisson
-import cern.jet.random.engine.DRand
+import org.apache.commons.math3.distribution.PoissonDistribution
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Internal representation of a datapoint which belongs to several subsamples of the same dataset,
* particularly for bagging (e.g., for random forests).
*
* This holds one instance, as well as an array of weights which represent the (weighted)
- * number of times which this instance appears in each subsample.
+ * number of times this instance appears in each subsample.
* E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that
* this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively.
*
@@ -45,27 +45,71 @@ private[tree] object BaggedPoint {
/**
* Convert an input dataset into its BaggedPoint representation,
- * choosing subsample counts for each instance.
- * Each subsample has the same number of instances as the original dataset,
- * and is created by subsampling with replacement.
- * @param input Input dataset.
- * @param numSubsamples Number of subsamples of this RDD to take.
- * @param seed Random seed.
- * @return BaggedPoint dataset representation
+   * choosing subsample counts for each instance.
+   * Each subsample has the same number of instances as the original dataset,
+   * and is created by subsampling with or without replacement.
+ * @param input Input dataset.
+ * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param numSubsamples Number of subsamples of this RDD to take.
+ * @param withReplacement Sampling with/without replacement.
+ * @param seed Random seed.
+ * @return BaggedPoint dataset representation.
*/
- def convertToBaggedRDD[Datum](
+ def convertToBaggedRDD[Datum] (
input: RDD[Datum],
+ subsamplingRate: Double,
numSubsamples: Int,
+ withReplacement: Boolean,
seed: Int = Utils.random.nextInt()): RDD[BaggedPoint[Datum]] = {
+ if (withReplacement) {
+ convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed)
+ } else {
+ if (numSubsamples == 1 && subsamplingRate == 1.0) {
+ convertToBaggedRDDWithoutSampling(input)
+ } else {
+ convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithoutReplacement[Datum] (
+ input: RDD[Datum],
+ subsamplingRate: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
+ input.mapPartitionsWithIndex { (partitionIndex, instances) =>
+ // Use random seed = seed + partitionIndex + 1 to make generation reproducible.
+ val rng = new XORShiftRandom
+ rng.setSeed(seed + partitionIndex + 1)
+ instances.map { instance =>
+ val subsampleWeights = new Array[Double](numSubsamples)
+ var subsampleIndex = 0
+ while (subsampleIndex < numSubsamples) {
+ val x = rng.nextDouble()
+ subsampleWeights(subsampleIndex) = {
+ if (x < subsamplingRate) 1.0 else 0.0
+ }
+ subsampleIndex += 1
+ }
+ new BaggedPoint(instance, subsampleWeights)
+ }
+ }
+ }
+
+ private def convertToBaggedRDDSamplingWithReplacement[Datum] (
+ input: RDD[Datum],
+ subsample: Double,
+ numSubsamples: Int,
+ seed: Int): RDD[BaggedPoint[Datum]] = {
input.mapPartitionsWithIndex { (partitionIndex, instances) =>
- // TODO: Support different sampling rates, and sampling without replacement.
// Use random seed = seed + partitionIndex + 1 to make generation reproducible.
- val poisson = new Poisson(1.0, new DRand(seed + partitionIndex + 1))
+ val poisson = new PoissonDistribution(subsample)
+ poisson.reseedRandomGenerator(seed + partitionIndex + 1)
instances.map { instance =>
val subsampleWeights = new Array[Double](numSubsamples)
var subsampleIndex = 0
while (subsampleIndex < numSubsamples) {
- subsampleWeights(subsampleIndex) = poisson.nextInt()
+ subsampleWeights(subsampleIndex) = poisson.sample()
subsampleIndex += 1
}
new BaggedPoint(instance, subsampleWeights)
@@ -73,7 +117,8 @@ private[tree] object BaggedPoint {
}
}
- def convertToBaggedRDDWithoutSampling[Datum](input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
+ private def convertToBaggedRDDWithoutSampling[Datum] (
+ input: RDD[Datum]): RDD[BaggedPoint[Datum]] = {
input.map(datum => new BaggedPoint(datum, Array(1.0)))
}
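
The two sampling paths differ only in how the per-subsample weight is drawn for each row:
Poisson(subsamplingRate) counts when sampling with replacement, a 0/1 indicator when sampling
without. A standalone sketch of just that weight-drawing step, using the same commons-math3
Poisson class the patch switches to (helper names are illustrative):

    import org.apache.commons.math3.distribution.PoissonDistribution
    import scala.util.Random

    def withReplacementWeights(rate: Double, numSubsamples: Int, seed: Long): Array[Double] = {
      val poisson = new PoissonDistribution(rate)  // mean = subsampling rate
      poisson.reseedRandomGenerator(seed)
      // Expected weight per subsample is `rate`; a row can appear more than once.
      Array.fill(numSubsamples)(poisson.sample().toDouble)
    }

    def withoutReplacementWeights(rate: Double, numSubsamples: Int, seed: Long): Array[Double] = {
      val rng = new Random(seed)
      // Each subsample keeps the row (weight 1.0) with probability `rate`, else drops it (0.0).
      Array.fill(numSubsamples)(if (rng.nextDouble() < rate) 1.0 else 0.0)
    }

    println(withReplacementWeights(1.0, 3, 42L).mkString(", "))    // e.g. 1.0, 0.0, 2.0
    println(withoutReplacementWeights(0.5, 3, 42L).mkString(", ")) // e.g. 1.0, 0.0, 1.0
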
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 772c02670e541..5bc0f2635c6b1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -76,6 +76,17 @@ private[tree] class DecisionTreeMetadata(
numBins(featureIndex) - 1
}
+
+ /**
+ * Set number of splits for a continuous feature.
+ * For a continuous feature, number of bins is number of splits plus 1.
+ */
+ def setNumSplits(featureIndex: Int, numSplits: Int) {
+ require(isContinuous(featureIndex),
+      "Number of splits can only be set for a continuous feature.")
+ numBins(featureIndex) = numSplits + 1
+ }
+
/**
* Indicates if feature subsampling is being used.
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000000..83011b48b7d9b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{Path, FileSystem}
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
+
+/**
+ * :: DeveloperApi ::
+ * This is used by the node id cache to find the child id that a data point would belong to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+@DeveloperApi
+private[tree] case class NodeIndexUpdater(
+ split: Split,
+ nodeIndex: Int) {
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeatures Binned feature values.
+ * @param bins Bin information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
+ if (split.featureType == Continuous) {
+ val featureIndex = split.feature
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ if (featureValueUpperBound <= split.threshold) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ } else {
+ if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * A given TreePoint would belong to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ * (should be an Array of all 1's (meaning the root nodes)).
+ * @param checkpointDir The checkpoint directory where
+ * the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ *                           (how often the cache should be checkpointed).
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointDir: Option[String],
+ val checkpointInterval: Int) {
+
+ // Keep a reference to a previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // If a checkpoint directory is given, and there's no prior checkpoint directory,
+ // then set the checkpoint directory with the given one.
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
+ nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
+ }
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters A map of node index updaters.
+ * The key is the indices of nodes that we want to update.
+ * @param bins Bin information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ bins: Array[Array[Bin]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
+ dataPoint => {
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+ if (nodeIdUpdater != null) {
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeatures = dataPoint._1.datum.binnedFeatures,
+ bins = bins)
+ dataPoint._2(treeId) = newNodeIndex
+ }
+
+ treeId += 1
+ }
+
+ dataPoint._2
+ }
+ }
+
+ // Keep on persisting new ones.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+ (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+ if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+ val old = checkpointQueue.dequeue()
+
+ // Since the old checkpoint is not deleted by Spark,
+ // we'll manually delete it here.
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.size > 0) {
+ val old = checkpointQueue.dequeue()
+ if (old.getCheckpointFile != None) {
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ }
+ }
+ }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create cache for.
+ * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ *                           (how often the cache should be checkpointed).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node Indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointDir: Option[String],
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointDir,
+ checkpointInterval)
+ }
+}
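
NodeIdCache keeps one node index per tree for every training row, and updateNodeIndices routes a
row to a child by comparing its bin's upper bound to the split threshold (continuous features) or
testing category membership (categorical features). A standalone sketch of that routing decision,
assuming the 1-based left = 2*i / right = 2*i + 1 convention behind Node.leftChildIndex and
Node.rightChildIndex (all values below are made up):

    // 1-based binary-heap indexing assumed for tree nodes: root = 1.
    def leftChildIndex(nodeIndex: Int): Int = 2 * nodeIndex
    def rightChildIndex(nodeIndex: Int): Int = 2 * nodeIndex + 1

    // Continuous split: go left iff the bin's upper bound is <= the split threshold.
    def routeContinuous(nodeIndex: Int, binUpperBound: Double, threshold: Double): Int =
      if (binUpperBound <= threshold) leftChildIndex(nodeIndex) else rightChildIndex(nodeIndex)

    // Categorical split: go left iff the feature's category is in the split's category set.
    def routeCategorical(nodeIndex: Int, category: Double, leftCategories: Set[Double]): Int =
      if (leftCategories.contains(category)) leftChildIndex(nodeIndex) else rightChildIndex(nodeIndex)

    println(routeContinuous(nodeIndex = 1, binUpperBound = 0.3, threshold = 0.5))            // 2 (left)
    println(routeCategorical(nodeIndex = 2, category = 1.0, leftCategories = Set(0.0, 2.0))) // 5 (right)
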
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
new file mode 100644
index 0000000000000..d111ffe30ed9e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/AbsoluteError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least absolute error loss calculation.
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: |y - F|
+ * Negative gradient: sign(y - F)
+ */
+@DeveloperApi
+object AbsoluteError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * absolute error calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ if ((point.label - model.predict(point.features)) < 0) 1.0 else -1.0
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean absolute error of the model on the given data
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val sumOfAbsolutes = data.map { y =>
+ val err = model.predict(y.features) - y.label
+ math.abs(err)
+ }.sum()
+ sumOfAbsolutes / data.count()
+ }
+
+}
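
As a sanity check on the sign convention: gradient() above returns d|y - F|/dF = -sign(y - F),
and boost() negates it so the next tree is fit to sign(y - F). A tiny standalone check with plain
doubles in place of the model:

    // Same rule as AbsoluteError.gradient, written against raw numbers.
    def absErrorGradient(label: Double, prediction: Double): Double =
      if (label - prediction < 0) 1.0 else -1.0

    println(absErrorGradient(label = 3.0, prediction = 5.0)) //  1.0: prediction above label, pseudo-residual is -1
    println(absErrorGradient(label = 3.0, prediction = 1.0)) // -1.0: prediction below label, pseudo-residual is +1
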
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
new file mode 100644
index 0000000000000..6f3d4340f0d3b
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/LogLoss.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for log loss calculation (for binary classification).
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: log(1 + exp(-2yF)), y in {-1, 1}
+ * Negative gradient: 2y / ( 1 + exp(2yF))
+ */
+@DeveloperApi
+object LogLoss extends Loss {
+
+ /**
+ * Method to calculate the loss gradients for the gradient boosting calculation for binary
+ * classification
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ val prediction = model.predict(point.features)
+ 1.0 / (1.0 + math.exp(-prediction)) - point.label
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Fraction of misclassified points
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ val wrongPredictions = data.filter(lp => model.predict(lp.features) != lp.label).count()
+    wrongPredictions.toDouble / data.count()
+ }
+}
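
The gradient above is the familiar sigmoid(F) - y form for labels in {0, 1}: the predicted
probability minus the label. A quick numeric check with plain doubles (no model involved):

    // gradient = 1 / (1 + exp(-F)) - y
    def logLossGradient(rawPrediction: Double, label: Double): Double =
      1.0 / (1.0 + math.exp(-rawPrediction)) - label

    println(logLossGradient(0.0, 1.0))  // -0.5: raw score 0 means p = 0.5; label 1 pulls the score up
    println(logLossGradient(2.0, 0.0))  // ~0.88: confident positive score, but label 0 pushes it down
    println(logLossGradient(-4.0, 0.0)) // ~0.018: already nearly correct, so the gradient is small
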
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
new file mode 100644
index 0000000000000..5580866c879e2
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Trait for adding "pluggable" loss functions for the gradient boosting algorithm.
+ */
+@DeveloperApi
+trait Loss extends Serializable {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation.
+ * @param model Model of the weak learner.
+ * @param point Instance of the training dataset.
+ * @return Loss gradient.
+ */
+ def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Measure of the model's error on the data
+ */
+ def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
new file mode 100644
index 0000000000000..42c9ead9884b4
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Losses.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+object Losses {
+
+ def fromString(name: String): Loss = name match {
+ case "leastSquaresError" => SquaredError
+ case "leastAbsoluteError" => AbsoluteError
+ case "logLoss" => LogLoss
+ case _ => throw new IllegalArgumentException(s"Did not recognize Loss name: $name")
+ }
+
+}
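
Losses.fromString maps the user-facing names onto the singleton loss objects and fails fast on
anything else. A short sketch using only the names defined in this file:

    import org.apache.spark.mllib.tree.loss.Losses

    val squared = Losses.fromString("leastSquaresError")    // SquaredError
    val absolute = Losses.fromString("leastAbsoluteError")  // AbsoluteError
    val logistic = Losses.fromString("logLoss")             // LogLoss

    // Any other name throws:
    // Losses.fromString("hinge")  // IllegalArgumentException: Did not recognize Loss name: hinge
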
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
new file mode 100644
index 0000000000000..4349fefef2c74
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/SquaredError.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.loss
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: DeveloperApi ::
+ * Class for least squares error loss calculation.
+ *
+ * The label y is predicted from the features x using the function F.
+ * For each instance:
+ * Loss: (y - F)**2/2
+ * Negative gradient: y - F
+ */
+@DeveloperApi
+object SquaredError extends Loss {
+
+ /**
+ * Method to calculate the gradients for the gradient boosting calculation for least
+ * squares error calculation.
+ * @param model Model of the weak learner
+ * @param point Instance of the training dataset
+ * @return Loss gradient
+ */
+ override def gradient(
+ model: WeightedEnsembleModel,
+ point: LabeledPoint): Double = {
+ model.predict(point.features) - point.label
+ }
+
+ /**
+ * Method to calculate error of the base learner for the gradient boosting calculation.
+ * Note: This method is not used by the gradient boosting algorithm but is useful for debugging
+ * purposes.
+ * @param model Model of the weak learner.
+ * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   * @return Mean squared error of the model on the given data
+ */
+ override def computeError(model: WeightedEnsembleModel, data: RDD[LabeledPoint]): Double = {
+ data.map { y =>
+ val err = model.predict(y.features) - y.label
+ err * err
+ }.mean()
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
deleted file mode 100644
index 6a22e2abe59bd..0000000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/RandomForestModel.scala
+++ /dev/null
@@ -1,115 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.model
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.rdd.RDD
-
-/**
- * :: Experimental ::
- * Random forest model for classification or regression.
- * This model stores a collection of [[DecisionTreeModel]] instances and uses them to make
- * aggregate predictions.
- * @param trees Trees which make up this forest. This cannot be empty.
- * @param algo algorithm type -- classification or regression
- */
-@Experimental
-class RandomForestModel(val trees: Array[DecisionTreeModel], val algo: Algo) extends Serializable {
-
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
-
- /**
- * Predict values for a single data point.
- *
- * @param features array representing a single data point
- * @return Double prediction from the trained model
- */
- def predict(features: Vector): Double = {
- algo match {
- case Classification =>
- val predictionToCount = new mutable.HashMap[Int, Int]()
- trees.foreach { tree =>
- val prediction = tree.predict(features).toInt
- predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
- }
- predictionToCount.maxBy(_._2)._1
- case Regression =>
- trees.map(_.predict(features)).sum / trees.size
- }
- }
-
- /**
- * Predict values for the given data set.
- *
- * @param features RDD representing data points to be predicted
- * @return RDD[Double] where each entry contains the corresponding prediction
- */
- def predict(features: RDD[Vector]): RDD[Double] = {
- features.map(x => predict(x))
- }
-
- /**
- * Get number of trees in forest.
- */
- def numTrees: Int = trees.size
-
- /**
- * Get total number of nodes, summed over all trees in the forest.
- */
- def totalNumNodes: Int = trees.map(tree => tree.numNodes).sum
-
- /**
- * Print a summary of the model.
- */
- override def toString: String = algo match {
- case Classification =>
- s"RandomForestModel classifier with $numTrees trees and $totalNumNodes total nodes"
- case Regression =>
- s"RandomForestModel regressor with $numTrees trees and $totalNumNodes total nodes"
- case _ => throw new IllegalArgumentException(
- s"RandomForestModel given unknown algo parameter: $algo.")
- }
-
- /**
- * Print the full model to a string.
- */
- def toDebugString: String = {
- val header = toString + "\n"
- header + trees.zipWithIndex.map { case (tree, treeIndex) =>
- s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
- }.fold("")(_ + _)
- }
-
-}
-
-private[tree] object RandomForestModel {
-
- def build(trees: Array[DecisionTreeModel]): RandomForestModel = {
- require(trees.size > 0, s"RandomForestModel cannot be created with empty trees collection.")
- val algo: Algo = trees(0).algo
- require(trees.forall(_.algo == algo),
- "RandomForestModel cannot combine trees which have different output types" +
- " (classification/regression).")
- new RandomForestModel(trees, algo)
- }
-
-}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
new file mode 100644
index 0000000000000..7b052d9163a13
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/WeightedEnsembleModel.scala
@@ -0,0 +1,158 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.rdd.RDD
+
+import scala.collection.mutable
+
+@Experimental
+class WeightedEnsembleModel(
+ val weakHypotheses: Array[DecisionTreeModel],
+ val weakHypothesisWeights: Array[Double],
+ val algo: Algo,
+ val combiningStrategy: EnsembleCombiningStrategy) extends Serializable {
+
+ require(numWeakHypotheses > 0, s"WeightedEnsembleModel cannot be created without weakHypotheses" +
+    s". Number of weakHypotheses = ${weakHypotheses.length}")
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ private def predictRaw(features: Vector): Double = {
+ val treePredictions = weakHypotheses.map(learner => learner.predict(features))
+    if (numWeakHypotheses == 1) {
+ treePredictions(0)
+ } else {
+ var prediction = treePredictions(0)
+ var index = 1
+ while (index < numWeakHypotheses) {
+ prediction += weakHypothesisWeights(index) * treePredictions(index)
+ index += 1
+ }
+ prediction
+ }
+ }
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ private def predictBySumming(features: Vector): Double = {
+ algo match {
+ case Regression => predictRaw(features)
+ case Classification => {
+ // TODO: predicted labels are +1 or -1 for GBT. Need a better way to store this info.
+ if (predictRaw(features) > 0 ) 1.0 else 0.0
+ }
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Predict values for a single data point.
+ *
+ * @param features array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ private def predictByAveraging(features: Vector): Double = {
+ algo match {
+ case Classification =>
+ val predictionToCount = new mutable.HashMap[Int, Int]()
+ weakHypotheses.foreach { learner =>
+ val prediction = learner.predict(features).toInt
+ predictionToCount(prediction) = predictionToCount.getOrElse(prediction, 0) + 1
+ }
+ predictionToCount.maxBy(_._2)._1
+ case Regression =>
+ weakHypotheses.map(_.predict(features)).sum / weakHypotheses.size
+ }
+ }
+
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param features array representing a single data point
+ * @return predicted category from the trained model
+ */
+ def predict(features: Vector): Double = {
+ combiningStrategy match {
+ case Sum => predictBySumming(features)
+ case Average => predictByAveraging(features)
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown combining parameter: $combiningStrategy.")
+ }
+ }
+
+ /**
+ * Predict values for the given data set.
+ *
+ * @param features RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(features: RDD[Vector]): RDD[Double] = features.map(x => predict(x))
+
+ /**
+ * Print a summary of the model.
+ */
+ override def toString: String = {
+ algo match {
+ case Classification =>
+ s"WeightedEnsembleModel classifier with $numWeakHypotheses trees\n"
+ case Regression =>
+ s"WeightedEnsembleModel regressor with $numWeakHypotheses trees\n"
+ case _ => throw new IllegalArgumentException(
+ s"WeightedEnsembleModel given unknown algo parameter: $algo.")
+ }
+ }
+
+ /**
+ * Print the full model to a string.
+ */
+ def toDebugString: String = {
+ val header = toString + "\n"
+ header + weakHypotheses.zipWithIndex.map { case (tree, treeIndex) =>
+ s" Tree $treeIndex:\n" + tree.topNode.subtreeToString(4)
+ }.fold("")(_ + _)
+ }
+
+ /**
+   * Get number of weak hypotheses (trees) in the ensemble.
+ */
+ def numWeakHypotheses: Int = weakHypotheses.size
+
+  // TODO: Remove these helper methods once the class is generalized to support any base learning
+ // algorithms.
+
+ /**
+ * Get total number of nodes, summed over all trees in the forest.
+ */
+ def totalNumNodes: Int = weakHypotheses.map(tree => tree.numNodes).sum
+
+}
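
Prediction combining is the only behavioural difference between the two strategies: Sum
accumulates weighted raw scores (gradient boosting), Average takes a majority vote or mean over
trees (random forests). A standalone sketch of the two rules over already-computed per-tree
predictions (the arrays stand in for the DecisionTreeModel calls):

    // Per-tree raw predictions and weights, as weakHypotheses/weakHypothesisWeights would supply.
    val treePredictions = Array(0.4, -0.2, 0.7)
    val treeWeights = Array(1.0, 0.1, 0.1)  // first tree at weight 1.0, later trees shrunk (GBT style)

    // Sum combining (GBT): weighted sum of raw scores; classification thresholds at 0.
    val summed = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
    val gbtClassLabel = if (summed > 0) 1.0 else 0.0

    // Average combining (random forest): majority vote for classification, mean for regression.
    val votes = Array(1.0, 0.0, 1.0)
    val forestClassLabel = votes.groupBy(identity).maxBy(_._2.length)._1
    val forestRegression = treePredictions.sum / treePredictions.length

    println(s"GBT: raw=$summed label=$gbtClassLabel; RF: vote=$forestClassLabel mean=$forestRegression")
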
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index ca35100aa99c6..9353351af72a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -26,7 +26,7 @@ import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.PartitionwiseSampledRDD
-import org.apache.spark.util.random.BernoulliSampler
+import org.apache.spark.util.random.BernoulliCellSampler
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.storage.StorageLevel
@@ -76,7 +76,7 @@ object MLUtils {
.map { line =>
val items = line.split(' ')
val label = items.head.toDouble
- val (indices, values) = items.tail.map { item =>
+ val (indices, values) = items.tail.filter(_.nonEmpty).map { item =>
val indexAndValue = item.split(':')
val index = indexAndValue(0).toInt - 1 // Convert 1-based indices to 0-based.
val value = indexAndValue(1).toDouble
@@ -196,8 +196,8 @@ object MLUtils {
/**
* Load labeled data from a file. The data format used here is
- * , ...
- * where , are feature values in Double and is the corresponding label as Double.
+ * L, f1 f2 ...
+ * where f1, f2 are feature values in Double and L is the corresponding label as Double.
*
* @param sc SparkContext
* @param dir Directory to the input data files.
@@ -219,8 +219,8 @@ object MLUtils {
/**
* Save labeled data to a file. The data format used here is
- * , ...
- * where , are feature values in Double and is the corresponding label as Double.
+ * L, f1 f2 ...
+ * where f1, f2 are feature values in Double and L is the corresponding label as Double.
*
* @param data An RDD of LabeledPoints containing data to be saved.
* @param dir Directory to save the data.
@@ -244,7 +244,7 @@ object MLUtils {
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
- val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
+ val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
complement = false)
val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
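
The kFold change above only swaps in BernoulliCellSampler; the API and semantics are unchanged.
A short usage sketch, assuming a SparkContext named sc:

    import org.apache.spark.mllib.util.MLUtils

    val data = sc.parallelize(1 to 100)  // sc: assumed SparkContext
    val folds = MLUtils.kFold(data, numFolds = 3, seed = 11)

    folds.foreach { case (training, validation) =>
      // Each pair splits the data: roughly 2/3 training and 1/3 validation per fold.
      println(s"train=${training.count()} validation=${validation.count()}")
    }
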
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
new file mode 100644
index 0000000000000..850c9fce507cd
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/StreamingKMeansSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.streaming.TestSuiteBase
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.random.XORShiftRandom
+
+class StreamingKMeansSuite extends FunSuite with TestSuiteBase {
+
+ override def maxWaitTimeMillis = 30000
+
+ test("accuracy for single center and equivalence to grand average") {
+ // set parameters
+ val numBatches = 10
+ val numPoints = 50
+ val k = 1
+ val d = 5
+ val r = 0.1
+
+ // create model with one cluster
+ val model = new StreamingKMeans()
+ .setK(1)
+ .setDecayFactor(1.0)
+ .setInitialCenters(Array(Vectors.dense(0.0, 0.0, 0.0, 0.0, 0.0)), Array(0.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ model.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // estimated center should be close to true center
+ assert(centers(0) ~== model.latestModel().clusterCenters(0) absTol 1E-1)
+
+ // estimated center from streaming should exactly match the arithmetic mean of all data points
+ // because the decay factor is set to 1.0
+ val grandMean =
+ input.flatten.map(x => x.toBreeze).reduce(_+_) / (numBatches * numPoints).toDouble
+ assert(model.latestModel().clusterCenters(0) ~== Vectors.dense(grandMean.toArray) absTol 1E-5)
+ }
+
+ test("accuracy for two centers") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 2
+ val d = 5
+ val r = 0.1
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(2, "batches")
+ .setInitialCenters(
+ Array(Vectors.dense(-0.1, 0.1, -0.2, -0.3, -0.1),
+ Vectors.dense(0.1, -0.2, 0.0, 0.2, 0.1)),
+ Array(5.0, 5.0))
+
+ // generate random data for k-means
+ val (input, centers) = StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42)
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ assert(centers(0) ~== kMeans.latestModel().clusterCenters(0) absTol 1E-1)
+ assert(centers(1) ~== kMeans.latestModel().clusterCenters(1) absTol 1E-1)
+ }
+
+ test("detecting dying clusters") {
+ val numBatches = 10
+ val numPoints = 5
+ val k = 1
+ val d = 1
+ val r = 1.0
+
+ // create model with two clusters
+ val kMeans = new StreamingKMeans()
+ .setK(2)
+ .setHalfLife(0.5, "points")
+ .setInitialCenters(
+ Array(Vectors.dense(0.0), Vectors.dense(1000.0)),
+ Array(1.0, 1.0))
+
+ // new data are all around the first cluster 0.0
+ val (input, _) =
+ StreamingKMeansDataGenerator(numPoints, numBatches, k, d, r, 42, Array(Vectors.dense(0.0)))
+
+ // setup and run the model training
+ val ssc = setupStreams(input, (inputDStream: DStream[Vector]) => {
+ kMeans.trainOn(inputDStream)
+ inputDStream.count()
+ })
+ runStreams(ssc, numBatches, numBatches)
+
+ // check that estimated centers are close to true centers
+ // NOTE exact assignment depends on the initialization!
+ val model = kMeans.latestModel()
+ val c0 = model.clusterCenters(0)(0)
+ val c1 = model.clusterCenters(1)(0)
+
+ assert(c0 * c1 < 0.0, "should have one positive center and one negative center")
+ // 0.8 is the mean of half-normal distribution
+ assert(math.abs(c0) ~== 0.8 absTol 0.6)
+ assert(math.abs(c1) ~== 0.8 absTol 0.6)
+ }
+
+ def StreamingKMeansDataGenerator(
+ numPoints: Int,
+ numBatches: Int,
+ k: Int,
+ d: Int,
+ r: Double,
+ seed: Int,
+ initCenters: Array[Vector] = null): (IndexedSeq[IndexedSeq[Vector]], Array[Vector]) = {
+ val rand = new XORShiftRandom(seed)
+ val centers = initCenters match {
+ case null => Array.fill(k)(Vectors.dense(Array.fill(d)(rand.nextGaussian())))
+ case _ => initCenters
+ }
+ val data = (0 until numBatches).map { i =>
+ (0 until numPoints).map { idx =>
+ val center = centers(idx % k)
+ Vectors.dense(Array.tabulate(d)(x => center(x) + rand.nextGaussian() * r))
+ }
+ }
+ (data, centers)
+ }
+}
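
The decay semantics these tests assert can be summarized with the usual forgetful mean update; this is a sketch of the model the assertions assume, not a quote of the implementation. A center c with effective count n absorbs a batch of m points with batch mean x̄, discounted by the decay factor α:

```latex
c_{t+1} = \frac{\alpha\, n_t\, c_t + m_t\, \bar{x}_t}{\alpha\, n_t + m_t},
\qquad
n_{t+1} = \alpha\, n_t + m_t
```

With α = 1 nothing is forgotten and the estimate reduces to the running mean of every point seen, which is why the first test compares the streaming center against the grand average; `setHalfLife(h, "batches")` corresponds (in the chosen unit) to picking α so that α^h = 1/2.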
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
new file mode 100644
index 0000000000000..342baa0274e9c
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+
+class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Multilabel evaluation metrics") {
+ /*
+ * Documents true labels (5x class0, 3x class1, 4x class2):
+ * doc 0 - predict 0, 1 - class 0, 2
+ * doc 1 - predict 0, 2 - class 0, 1
+ * doc 2 - predict none - class 0
+ * doc 3 - predict 2 - class 2
+ * doc 4 - predict 2, 0 - class 2, 0
+ * doc 5 - predict 0, 1, 2 - class 0, 1
+ * doc 6 - predict 1 - class 1, 2
+ *
+ * predicted classes
+ * class 0 - doc 0, 1, 4, 5 (total 4)
+ * class 1 - doc 0, 5, 6 (total 3)
+ * class 2 - doc 1, 3, 4, 5 (total 4)
+ *
+ * true classes
+ * class 0 - doc 0, 1, 2, 4, 5 (total 5)
+ * class 1 - doc 1, 5, 6 (total 3)
+ * class 2 - doc 0, 3, 4, 6 (total 4)
+ *
+ */
+ val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+ Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+ (Array(0.0, 2.0), Array(0.0, 1.0)),
+ (Array(), Array(0.0)),
+ (Array(2.0), Array(2.0)),
+ (Array(2.0, 0.0), Array(2.0, 0.0)),
+ (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+ (Array(1.0), Array(1.0, 2.0))), 2)
+ val metrics = new MultilabelMetrics(scoreAndLabels)
+ val delta = 0.00001
+ val precision0 = 4.0 / (4 + 0)
+ val precision1 = 2.0 / (2 + 1)
+ val precision2 = 2.0 / (2 + 2)
+ val recall0 = 4.0 / (4 + 1)
+ val recall1 = 2.0 / (2 + 1)
+ val recall2 = 2.0 / (2 + 2)
+ val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+ val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+ val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+ val sumTp = 4 + 2 + 2
+ assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
+ val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
+ val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
+ val microF1MeasureClass = 2.0 * sumTp.toDouble /
+ (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
+ val macroPrecisionDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
+ val macroRecallDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
+ val macroF1MeasureDoc = (1.0 / 7) *
+ 2 * ( 1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
+ 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2) )
+ val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
+ val strictAccuracy = 2.0 / 7
+ val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 / 3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
+ assert(math.abs(metrics.precision(0.0) - precision0) < delta)
+ assert(math.abs(metrics.precision(1.0) - precision1) < delta)
+ assert(math.abs(metrics.precision(2.0) - precision2) < delta)
+ assert(math.abs(metrics.recall(0.0) - recall0) < delta)
+ assert(math.abs(metrics.recall(1.0) - recall1) < delta)
+ assert(math.abs(metrics.recall(2.0) - recall2) < delta)
+ assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
+ assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
+ assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
+ assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
+ assert(math.abs(metrics.microRecall - microRecallClass) < delta)
+ assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
+ assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
+ assert(math.abs(metrics.recall - macroRecallDoc) < delta)
+ assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
+ assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
+ assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
+ assert(math.abs(metrics.accuracy - accuracy) < delta)
+ assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
+ }
+}
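
To make the hand-computed expectations above easier to follow, here is a plain-Scala restatement of the document-averaged (macro) precision and recall; it only re-derives the expected values and does not call `MultilabelMetrics`:

```scala
// Re-derivation of the per-document arithmetic in the comment above.
object MultilabelByHand extends App {
  val docs: Seq[(Set[Double], Set[Double])] = Seq(
    (Set(0.0, 1.0), Set(0.0, 2.0)),
    (Set(0.0, 2.0), Set(0.0, 1.0)),
    (Set.empty[Double], Set(0.0)),
    (Set(2.0), Set(2.0)),
    (Set(2.0, 0.0), Set(2.0, 0.0)),
    (Set(0.0, 1.0, 2.0), Set(0.0, 1.0)),
    (Set(1.0), Set(1.0, 2.0)))

  // Document-averaged precision: |pred ∩ true| / |pred|, counted as 0 when nothing is predicted.
  val precision = docs.map { case (pred, actual) =>
    if (pred.isEmpty) 0.0 else (pred intersect actual).size.toDouble / pred.size
  }.sum / docs.size

  // Document-averaged recall: |pred ∩ true| / |true|.
  val recall = docs.map { case (pred, actual) =>
    (pred intersect actual).size.toDouble / actual.size
  }.sum / docs.size

  println(f"precision=$precision%.5f recall=$recall%.5f") // ≈ 0.66667 and 0.64286
}
```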
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000000..a2d4bb41484b8
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Ranking metrics: map, ndcg") {
+ val predictionAndLabels = sc.parallelize(
+ Seq(
+ (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+ (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
+ (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+ ), 2)
+ val eps: Double = 1E-5
+
+ val metrics = new RankingMetrics(predictionAndLabels)
+ val map = metrics.meanAveragePrecision
+
+ assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
+ assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
+
+ assert(map ~== 0.355026 absTol eps)
+
+ assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
+ assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
+ assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+ }
+}
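
The precision@k expectations can be checked by hand for the first ranked list. The sketch below is a by-hand computation consistent with the asserted values, not a call into `RankingMetrics`:

```scala
// Hand computation of precision@k for the first (prediction, label) pair above.
object PrecisionAtByHand extends App {
  def precisionAt(ranked: Seq[Int], relevant: Set[Int], k: Int): Double =
    ranked.take(k).count(relevant.contains).toDouble / k

  val ranked = Seq(1, 6, 2, 7, 8, 3, 9, 10, 4, 5)
  val relevant = Set(1, 2, 3, 4, 5)
  // Yields 1.0, 0.5, 2/3, 0.5, 0.4 for k = 1..5. Averaging the per-query values over
  // all three queries, with the empty-ground-truth query contributing 0, matches the
  // 1/3-scaled expectations asserted above.
  (1 to 5).foreach(k => println(s"p@$k = ${precisionAt(ranked, relevant, k)}"))
}
```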
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
new file mode 100644
index 0000000000000..5396d7b2b74fa
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class RegressionMetricsSuite extends FunSuite with LocalSparkContext {
+
+ test("regression metrics") {
+ val predictionAndObservations = sc.parallelize(
+ Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 0.95717 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 0.94861 absTol 1E-5, "r2 score mismatch")
+ }
+
+ test("regression metrics with complete fitting") {
+ val predictionAndObservations = sc.parallelize(
+ Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2)
+ val metrics = new RegressionMetrics(predictionAndObservations)
+ assert(metrics.explainedVariance ~== 1.0 absTol 1E-5,
+ "explained variance regression score mismatch")
+ assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch")
+ assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch")
+ assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5,
+ "root mean squared error mismatch")
+ assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch")
+ }
+}
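
The expected values in the first test follow directly from the four (prediction, observation) pairs, whose residuals are −0.5, 0.5, 0.0 and 1.0. Using the standard definitions (the explained-variance score is computed by an analogous variance ratio):

```latex
\begin{aligned}
\text{MAE} &= \tfrac{0.5 + 0.5 + 0 + 1}{4} = 0.5, &
\text{MSE} &= \tfrac{0.25 + 0.25 + 0 + 1}{4} = 0.375, &
\text{RMSE} &= \sqrt{0.375} \approx 0.61237,\\[4pt]
\bar{y} &= \tfrac{3.0 - 0.5 + 2.0 + 7.0}{4} = 2.875, &
SS_{\text{tot}} &= \textstyle\sum_i (y_i - \bar{y})^2 = 29.1875, &
R^2 &= 1 - \tfrac{1.5}{29.1875} \approx 0.94861.
\end{aligned}
```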
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index cd651fe2d2ddf..93a84fe07b32a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -155,4 +155,15 @@ class VectorsSuite extends FunSuite {
throw new RuntimeException(s"copy returned ${dvCopy.getClass} on ${dv.getClass}.")
}
}
+
+ test("VectorUDT") {
+ val dv0 = Vectors.dense(Array.empty[Double])
+ val dv1 = Vectors.dense(1.0, 2.0)
+ val sv0 = Vectors.sparse(2, Array.empty, Array.empty)
+ val sv1 = Vectors.sparse(2, Array(1), Array(2.0))
+ val udt = new VectorUDT()
+ for (v <- Seq(dv0, dv1, sv0, sv1)) {
+ assert(v === udt.deserialize(udt.serialize(v)))
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
index b781a6aed9a8c..82c327bd49fcd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala
@@ -37,6 +37,12 @@ class NNLSSuite extends FunSuite {
(ata, atb)
}
+ /** Compute the objective value */
+ def computeObjectiveValue(ata: DoubleMatrix, atb: DoubleMatrix, x: DoubleMatrix): Double = {
+ val res = (x.transpose().mmul(ata).mmul(x)).mul(0.5).sub(atb.dot(x))
+ res.get(0)
+ }
+
test("NNLS: exact solution cases") {
val n = 20
val rand = new Random(12346)
@@ -79,4 +85,28 @@ class NNLSSuite extends FunSuite {
assert(x(i) >= 0)
}
}
+
+ test("NNLS: objective value test") {
+ val n = 5
+ val ata = new DoubleMatrix(5, 5,
+ 517399.13534, 242529.67289, -153644.98976, 130802.84503, -798452.29283,
+ 242529.67289, 126017.69765, -75944.21743, 81785.36128, -405290.60884,
+ -153644.98976, -75944.21743, 46986.44577, -45401.12659, 247059.51049,
+ 130802.84503, 81785.36128, -45401.12659, 67457.31310, -253747.03819,
+ -798452.29283, -405290.60884, 247059.51049, -253747.03819, 1310939.40814)
+ val atb = new DoubleMatrix(5, 1,
+ -31755.05710, 13047.14813, -20191.24443, 25993.77580, 11963.55017)
+
+ // Reference solution obtained from MATLAB's quadprog function.
+ val refx = new DoubleMatrix(Array(34.90751, 103.96254, 0.00000, 27.82094, 58.79627))
+ val refObj = computeObjectiveValue(ata, atb, refx)
+
+ val ws = NNLS.createWorkspace(n)
+ val x = new DoubleMatrix(NNLS.solve(ata, atb, ws))
+ val obj = computeObjectiveValue(ata, atb, x)
+
+ assert(obj < refObj + 1E-5)
+ }
}
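
`computeObjectiveValue` above evaluates the quadratic objective that NNLS minimizes over the nonnegative orthant, expressed in the normal-equation matrices it receives:

```latex
f(x) \;=\; \tfrac{1}{2}\, x^\top (A^\top A)\, x \;-\; (A^\top b)^\top x,
\qquad \text{minimized subject to } x \ge 0 .
```

The new test therefore only requires the solver's objective to come within 1e-5 of the value attained by the MATLAB `quadprog` reference solution, rather than demanding the same minimizer.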
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 98a72b0c4d750..c579cb58549f5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
@@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}
+ test("find splits for a continuous feature") {
+ // find splits for normal case
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array.fill(200000)(math.random)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 5)
+ assert(fakeMetadata.numSplits(0) === 5)
+ assert(fakeMetadata.numBins(0) === 6)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits should not return identical splits
+ // when there are not enough split candidates, reduce the number of splits in metadata
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(5), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 3)
+ assert(fakeMetadata.numSplits(0) === 3)
+ assert(fakeMetadata.numBins(0) === 4)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits when most samples close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 2)
+ assert(fakeMetadata.numSplits(0) === 2)
+ assert(fakeMetadata.numBins(0) === 3)
+ assert(splits(0) === 2.0)
+ assert(splits(1) === 3.0)
+ }
+
+ // find splits when most samples close to the maximum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 1)
+ assert(fakeMetadata.numSplits(0) === 1)
+ assert(fakeMetadata.numBins(0) === 2)
+ assert(splits(0) === 1.0)
+ }
+ }
+
test("Multiclass classification with unordered categorical features:" +
" split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
@@ -427,7 +493,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(rootNode1.rightNode.nonEmpty)
val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
// Single group second level tree construction.
val nodesForGroup = Map((0, Array(rootNode1.leftNode.get, rootNode1.rightNode.get)))
@@ -720,7 +786,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
@@ -763,7 +829,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
- val baggedInput = BaggedPoint.convertToBaggedRDDWithoutSampling(treeInput)
+ val baggedInput = BaggedPoint.convertToBaggedRDD(treeInput, 1.0, 1, false)
val topNode = Node.emptyNode(nodeIndex = 1)
assert(topNode.predict.predict === Double.MinValue)
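
The intent of the new "find splits for a continuous feature" test can be illustrated with a deliberately simplified sketch. This is not `DecisionTree.findSplitsForContinuousFeature`; it only shows the deduplication idea the assertions pin down:

```scala
// Simplified illustration: take candidate thresholds at roughly equally spaced
// ranks of the sampled feature values, then drop duplicates so that a feature
// dominated by one repeated value yields fewer, but never identical, splits.
def approxDistinctSplits(featureSamples: Array[Double], maxSplits: Int): Array[Double] = {
  val sorted = featureSamples.sorted
  val n = sorted.length
  val candidates = (1 to maxSplits).map(i => sorted(i * n / (maxSplits + 1)))
  candidates.distinct.toArray
}
// The real implementation works from value counts and also shrinks numSplits/numBins
// in the metadata, but the invariant checked above is the same: returned splits are
// distinct, and their number drops when there are not enough distinct values.
```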
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
new file mode 100644
index 0000000000000..effb7b8259ffb
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import scala.collection.mutable
+
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.model.WeightedEnsembleModel
+import org.apache.spark.util.StatCounter
+
+object EnsembleTestHelper {
+
+ /**
+ * Aggregates all values in data, and tests whether the empirical mean and stddev are within
+ * epsilon of the expected values.
+ * @param data Every element of the data should be an i.i.d. sample from some distribution.
+ */
+ def testRandomArrays(
+ data: Array[Array[Double]],
+ numCols: Int,
+ expectedMean: Double,
+ expectedStddev: Double,
+ epsilon: Double) {
+ val values = new mutable.ArrayBuffer[Double]()
+ data.foreach { row =>
+ assert(row.size == numCols)
+ values ++= row
+ }
+ val stats = new StatCounter(values)
+ assert(math.abs(stats.mean - expectedMean) < epsilon)
+ assert(math.abs(stats.stdev - expectedStddev) < epsilon)
+ }
+
+ def validateClassifier(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredAccuracy: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
+ prediction != expected.label
+ }
+ val accuracy = (input.length - numOffPredictions).toDouble / input.length
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
+ }
+
+ def validateRegressor(
+ model: WeightedEnsembleModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ val err = prediction - expected.label
+ err * err
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
+ }
+
+ def generateOrderedLabeledPoints(numFeatures: Int, numInstances: Int): Array[LabeledPoint] = {
+ val arr = new Array[LabeledPoint](numInstances)
+ for (i <- 0 until numInstances) {
+ val label = if (i < numInstances / 10) {
+ 0.0
+ } else if (i < numInstances / 2) {
+ 1.0
+ } else if (i < numInstances * 0.9) {
+ 0.0
+ } else {
+ 1.0
+ }
+ val features = Array.fill[Double](numFeatures)(i.toDouble)
+ arr(i) = new LabeledPoint(label, Vectors.dense(features))
+ }
+ arr
+ }
+
+}
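
A hypothetical usage sketch of this helper, mirroring how the forest and boosting suites above call it; the data sizes, seed and accuracy bound are illustrative, and `sc` is the SparkContext provided by `LocalSparkContext`:

```scala
import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest}
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Hypothetical test body inside a suite that mixes in LocalSparkContext.
def smallForestSanityCheck(): Unit = {
  val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, numInstances = 200)
  val rdd = sc.parallelize(arr)
  val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
    numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty[Int, Int])
  val model = RandomForest.trainClassifier(rdd, strategy, numTrees = 3,
    featureSubsetStrategy = "auto", seed = 42)
  EnsembleTestHelper.validateClassifier(model, arr, requiredAccuracy = 0.7)
}
```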
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
new file mode 100644
index 0000000000000..970fff82215e2
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostingSuite.scala
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
+import org.apache.spark.mllib.tree.impurity.{Variance, Gini}
+import org.apache.spark.mllib.tree.loss.{SquaredError, LogLoss}
+import org.apache.spark.mllib.tree.model.{WeightedEnsembleModel, DecisionTreeModel}
+
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[GradientBoosting]].
+ */
+class GradientBoostingSuite extends FunSuite with LocalSparkContext {
+
+ test("Regression with continuous features: SquaredError") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+ test("Regression with continuous features: Absolute Error") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Regression, numEstimators, SquaredError,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainRegressor(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateRegressor(gbt, arr, 0.02)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+ test("Binary classification with continuous features: Log Loss") {
+
+ GradientBoostingSuite.testCombinations.foreach {
+ case (numEstimators, learningRate, subsamplingRate) =>
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ subsamplingRate = subsamplingRate)
+
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ val boostingStrategy = new BoostingStrategy(Classification, numEstimators, LogLoss,
+ subsamplingRate, learningRate, 1, categoricalFeaturesInfo, treeStrategy)
+
+ val gbt = GradientBoosting.trainClassifier(rdd, boostingStrategy)
+ assert(gbt.weakHypotheses.size === numEstimators)
+ val gbtTree = gbt.weakHypotheses(0)
+
+ EnsembleTestHelper.validateClassifier(gbt, arr, 0.9)
+
+ // Make sure trees are the same.
+ assert(gbtTree.toString == dt.toString)
+ }
+ }
+
+}
+
+object GradientBoostingSuite {
+
+ // Combinations for estimators, learning rates and subsamplingRate
+ val testCombinations
+ = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 1.0, 0.75), (10, 0.1, 0.75))
+
+}
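
As a reminder of the model these combinations parameterize (standard gradient boosting, sketched here rather than quoted from the `GradientBoosting` internals), each stage adds a tree fit against the loss at the current model, scaled by the learning rate ν:

```latex
F_m(x) \;=\; F_{m-1}(x) \;+\; \nu\, h_m(x),
\qquad
h_m \;\approx\; \arg\min_h \sum_i L\big(y_i,\; F_{m-1}(x_i) + h(x_i)\big)
```

The `gbtTree.toString == dt.toString` assertions encode that the first stage is simply a tree fit to the (remapped) data, while `subsamplingRate` only controls the fraction of data each stage sees.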
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index fb44ceb0f57ee..73c4393c3581a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -25,99 +25,91 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata}
+import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
+import org.apache.spark.mllib.tree.model.Node
import org.apache.spark.mllib.util.LocalSparkContext
-import org.apache.spark.util.StatCounter
/**
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends FunSuite with LocalSparkContext {
-
- test("BaggedPoint RDD: without subsampling") {
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
+ def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
- val baggedRDD = BaggedPoint.convertToBaggedRDDWithoutSampling(rdd)
- baggedRDD.collect().foreach { baggedPoint =>
- assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
- }
- }
-
- test("BaggedPoint RDD: with subsampling") {
- val numSubsamples = 100
- val (expectedMean, expectedStddev) = (1.0, 1.0)
-
- val seeds = Array(123, 5354, 230, 349867, 23987)
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 1)
- val rdd = sc.parallelize(arr)
- seeds.foreach { seed =>
- val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, numSubsamples, seed = seed)
- val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
- RandomForestSuite.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
- expectedStddev, epsilon = 0.01)
- }
- }
-
- test("Binary classification with continuous features:" +
- " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
-
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
- val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateClassifier(rf, arr, 0.9)
+ EnsembleTestHelper.validateClassifier(rf, arr, 0.9)
DecisionTreeSuite.validateClassifier(dt, arr, 0.9)
// Make sure trees are the same.
assert(rfTree.toString == dt.toString)
}
- test("Regression with continuous features:" +
+ test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures = 50)
- val rdd = sc.parallelize(arr)
+ test("Binary classification with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
- val numTrees = 1
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
- val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ def regressionTestWithContinuousFeatures(strategy: Strategy) {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
+ val rdd = sc.parallelize(arr)
+ val numTrees = 1
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
- assert(rf.trees.size === 1)
- val rfTree = rf.trees(0)
+ assert(rf.weakHypotheses.size === 1)
+ val rfTree = rf.weakHypotheses(0)
val dt = DecisionTree.train(rdd, strategy)
- RandomForestSuite.validateRegressor(rf, arr, 0.01)
+ EnsembleTestHelper.validateRegressor(rf, arr, 0.01)
DecisionTreeSuite.validateRegressor(dt, arr, 0.01)
// Make sure trees are the same.
assert(rfTree.toString == dt.toString)
}
- test("Binary classification with continuous features: subsampling features") {
- val numFeatures = 50
- val arr = RandomForestSuite.generateOrderedLabeledPoints(numFeatures)
- val rdd = sc.parallelize(arr)
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo)
+ regressionTestWithContinuousFeatures(strategy)
+ }
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ test("Regression with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+ regressionTestWithContinuousFeatures(strategy)
+ }
+
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
+ val numFeatures = 50
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
+ val rdd = sc.parallelize(arr)
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -173,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
+ test("Binary classification with continuous features: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -186,77 +192,8 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = categoricalFeaturesInfo)
val model = RandomForest.trainClassifier(input, strategy, numTrees = 2,
featureSubsetStrategy = "sqrt", seed = 12345)
- RandomForestSuite.validateClassifier(model, arr, 1.0)
+ EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
-
}
-object RandomForestSuite {
- /**
- * Aggregates all values in data, and tests whether the empirical mean and stddev are within
- * epsilon of the expected values.
- * @param data Every element of the data should be an i.i.d. sample from some distribution.
- */
- def testRandomArrays(
- data: Array[Array[Double]],
- numCols: Int,
- expectedMean: Double,
- expectedStddev: Double,
- epsilon: Double) {
- val values = new mutable.ArrayBuffer[Double]()
- data.foreach { row =>
- assert(row.size == numCols)
- values ++= row
- }
- val stats = new StatCounter(values)
- assert(math.abs(stats.mean - expectedMean) < epsilon)
- assert(math.abs(stats.stdev - expectedStddev) < epsilon)
- }
-
- def validateClassifier(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredAccuracy: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val numOffPredictions = predictions.zip(input).count { case (prediction, expected) =>
- prediction != expected.label
- }
- val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy,
- s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
- }
-
- def validateRegressor(
- model: RandomForestModel,
- input: Seq[LabeledPoint],
- requiredMSE: Double) {
- val predictions = input.map(x => model.predict(x.features))
- val squaredError = predictions.zip(input).map { case (prediction, expected) =>
- val err = prediction - expected.label
- err * err
- }.sum
- val mse = squaredError / input.length
- assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
- }
-
- def generateOrderedLabeledPoints(numFeatures: Int): Array[LabeledPoint] = {
- val numInstances = 1000
- val arr = new Array[LabeledPoint](numInstances)
- for (i <- 0 until numInstances) {
- val label = if (i < numInstances / 10) {
- 0.0
- } else if (i < numInstances / 2) {
- 1.0
- } else if (i < numInstances * 0.9) {
- 0.0
- } else {
- 1.0
- }
- val features = Array.fill[Double](numFeatures)(i.toDouble)
- arr(i) = new LabeledPoint(label, Vectors.dense(features))
- }
- arr
- }
-
-}
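
For readers following the `featureSubsetStrategy` assertions, a rough sketch of the mapping the classification tests exercise is below; the authoritative rules live in `RandomForest`/`DecisionTreeMetadata` and also cover regression and `"log2"`, so treat this as an approximation only:

```scala
// Rough sketch of strategy -> features-considered-per-node for classification.
def featuresPerNode(featureSubsetStrategy: String, numFeatures: Int, numTrees: Int): Int =
  featureSubsetStrategy match {
    case "all" => numFeatures
    case "sqrt" => math.sqrt(numFeatures).ceil.toInt
    case "onethird" => (numFeatures / 3.0).ceil.toInt // matches the assertion above
    case "auto" => if (numTrees == 1) numFeatures else math.sqrt(numFeatures).ceil.toInt
    case _ => numFeatures
  }
```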
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
new file mode 100644
index 0000000000000..5cb433232e714
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.tree.EnsembleTestHelper
+import org.apache.spark.mllib.util.LocalSparkContext
+
+/**
+ * Test suite for [[BaggedPoint]].
+ */
+class BaggedPointSuite extends FunSuite with LocalSparkContext {
+
+ test("BaggedPoint RDD: without subsampling") {
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42)
+ baggedRDD.collect().foreach { baggedPoint =>
+ assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 1.0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") {
+ val numSubsamples = 100
+ val (expectedMean, expectedStddev) = (1.0, 0)
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+
+ test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") {
+ val numSubsamples = 100
+ val subsample = 0.5
+ val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample)))
+
+ val seeds = Array(123, 5354, 230, 349867, 23987)
+ val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000)
+ val rdd = sc.parallelize(arr)
+ seeds.foreach { seed =>
+ val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed)
+ val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect()
+ EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean,
+ expectedStddev, epsilon = 0.01)
+ }
+ }
+}
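
The expected means and standard deviations asserted above are exactly the moments of the per-point subsample counts produced by bootstrap-style resampling: a Poisson(f) count when sampling with replacement at fraction f, and a Bernoulli(f) indicator when sampling without replacement:

```latex
\text{with replacement: } w_i \sim \mathrm{Poisson}(f),\;
\mathbb{E}[w_i] = f,\; \sigma(w_i) = \sqrt{f};
\qquad
\text{without replacement: } w_i \sim \mathrm{Bernoulli}(f),\;
\mathbb{E}[w_i] = f,\; \sigma(w_i) = \sqrt{f(1-f)} .
```

This matches (1.0, 1.0) and (0.5, √0.5) in the with-replacement tests, and (1.0, 0) and (0.5, 0.5) in the without-replacement tests.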
diff --git a/network/common/pom.xml b/network/common/pom.xml
new file mode 100644
index 0000000000000..ea887148d98ba
--- /dev/null
+++ b/network/common/pom.xml
@@ -0,0 +1,110 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+  -->
+<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent</artifactId>
+    <version>1.2.0-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-network-common_2.10</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Project Common Network Code</name>
+  <url>http://spark.apache.org/</url>
+  <properties>
+    <sbt.project.name>network-common</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <!-- Core dependencies -->
+    <dependency>
+      <groupId>io.netty</groupId>
+      <artifactId>netty-all</artifactId>
+    </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-api</artifactId>
+    </dependency>
+
+    <!-- Provided dependencies -->
+    <dependency>
+      <groupId>com.google.guava</groupId>
+      <artifactId>guava</artifactId>
+      <scope>provided</scope>
+    </dependency>
+
+    <!-- Test dependencies -->
+    <dependency>
+      <groupId>junit</groupId>
+      <artifactId>junit</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.novocode</groupId>
+      <artifactId>junit-interface</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>log4j</groupId>
+      <artifactId>log4j</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalatest</groupId>
+      <artifactId>scalatest_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-jar-plugin</artifactId>
+        <version>2.2</version>
+        <executions>
+          <execution>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+          <execution>
+            <id>test-jar-on-test-compile</id>
+            <phase>test-compile</phase>
+            <goals>
+              <goal>test-jar</goal>
+            </goals>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/network/common/src/main/java/org/apache/spark/network/TransportContext.java b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
new file mode 100644
index 0000000000000..a271841e4e56c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/TransportContext.java
@@ -0,0 +1,121 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network;
+
+import io.netty.channel.Channel;
+import io.netty.channel.socket.SocketChannel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.client.TransportClient;
+import org.apache.spark.network.client.TransportClientFactory;
+import org.apache.spark.network.client.TransportResponseHandler;
+import org.apache.spark.network.protocol.MessageDecoder;
+import org.apache.spark.network.protocol.MessageEncoder;
+import org.apache.spark.network.server.RpcHandler;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.server.TransportRequestHandler;
+import org.apache.spark.network.server.TransportServer;
+import org.apache.spark.network.server.StreamManager;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Contains the context to create a {@link TransportServer}, {@link TransportClientFactory}, and to
+ * set up Netty Channel pipelines with a {@link org.apache.spark.network.server.TransportChannelHandler}.
+ *
+ * The TransportClient provides two communication protocols: control-plane RPCs and data-plane
+ * "chunk fetching". RPCs are handled outside the scope of the TransportContext (i.e., by a
+ * user-provided handler), and that handler is responsible for setting up streams that can then be
+ * fetched through the data plane in chunks using zero-copy IO.
+ *
+ * The TransportServer and TransportClientFactory both create a TransportChannelHandler for each
+ * channel. As each TransportChannelHandler contains a TransportClient, this enables server
+ * processes to send messages back to the client on an existing channel.
+ */
+public class TransportContext {
+ private final Logger logger = LoggerFactory.getLogger(TransportContext.class);
+
+ private final TransportConf conf;
+ private final RpcHandler rpcHandler;
+
+ private final MessageEncoder encoder;
+ private final MessageDecoder decoder;
+
+ public TransportContext(TransportConf conf, RpcHandler rpcHandler) {
+ this.conf = conf;
+ this.rpcHandler = rpcHandler;
+ this.encoder = new MessageEncoder();
+ this.decoder = new MessageDecoder();
+ }
+
+ public TransportClientFactory createClientFactory() {
+ return new TransportClientFactory(this);
+ }
+
+ /** Create a server which will attempt to bind to a specific port. */
+ public TransportServer createServer(int port) {
+ return new TransportServer(this, port);
+ }
+
+ /** Creates a new server, binding to any available ephemeral port. */
+ public TransportServer createServer() {
+ return new TransportServer(this, 0);
+ }
+
+ /**
+ * Initializes a client or server Netty Channel Pipeline which encodes/decodes messages and
+ * has a {@link org.apache.spark.network.server.TransportChannelHandler} to handle request or
+ * response messages.
+ *
+ * @return Returns the created TransportChannelHandler, which includes a TransportClient that can
+ * be used to communicate on this channel. The TransportClient is directly associated with a
+ * ChannelHandler to ensure all users of the same channel get the same TransportClient object.
+ */
+ public TransportChannelHandler initializePipeline(SocketChannel channel) {
+ try {
+ TransportChannelHandler channelHandler = createChannelHandler(channel);
+ channel.pipeline()
+ .addLast("encoder", encoder)
+ .addLast("frameDecoder", NettyUtils.createFrameDecoder())
+ .addLast("decoder", decoder)
+ // NOTE: Chunks are currently guaranteed to be returned in the order of request, but this
+ // would require more logic to guarantee if this were not part of the same event loop.
+ .addLast("handler", channelHandler);
+ return channelHandler;
+ } catch (RuntimeException e) {
+ logger.error("Error while initializing Netty pipeline", e);
+ throw e;
+ }
+ }
+
+ /**
+ * Creates the server- and client-side handler which is used to handle both RequestMessages and
+ * ResponseMessages. The channel is expected to have been successfully created, though certain
+ * properties (such as the remoteAddress()) may not be available yet.
+ */
+ private TransportChannelHandler createChannelHandler(Channel channel) {
+ TransportResponseHandler responseHandler = new TransportResponseHandler(channel);
+ TransportClient client = new TransportClient(channel, responseHandler);
+ TransportRequestHandler requestHandler = new TransportRequestHandler(channel, client,
+ rpcHandler);
+ return new TransportChannelHandler(client, responseHandler, requestHandler);
+ }
+
+ public TransportConf getConf() { return conf; }
+}
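
A hypothetical wiring sketch from Scala, using only the methods defined above; the `TransportConf` and `RpcHandler` instances are assumed to be supplied by the caller:

```scala
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client.TransportClientFactory
import org.apache.spark.network.server.{RpcHandler, TransportServer}
import org.apache.spark.network.util.TransportConf

// Build one context and derive both ends of the transport from it.
def startServerAndClients(conf: TransportConf, rpcHandler: RpcHandler)
  : (TransportServer, TransportClientFactory) = {
  val context = new TransportContext(conf, rpcHandler)
  val server = context.createServer()               // binds to an ephemeral port
  val clientFactory = context.createClientFactory() // clients share this context's conf
  (server, clientFactory)
}
```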
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
new file mode 100644
index 0000000000000..89ed79bc63903
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.RandomAccessFile;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+
+import com.google.common.base.Objects;
+import com.google.common.io.ByteStreams;
+import io.netty.channel.DefaultFileRegion;
+
+import org.apache.spark.network.util.JavaUtils;
+
+/**
+ * A {@link ManagedBuffer} backed by a segment in a file.
+ */
+public final class FileSegmentManagedBuffer extends ManagedBuffer {
+
+ /**
+ * Memory mapping is expensive and can destabilize the JVM (SPARK-1145, SPARK-3889).
+ * Avoid unless there's a good reason not to.
+ */
+ // TODO: Make this configurable
+ private static final long MIN_MEMORY_MAP_BYTES = 2 * 1024 * 1024;
+
+ private final File file;
+ private final long offset;
+ private final long length;
+
+ public FileSegmentManagedBuffer(File file, long offset, long length) {
+ this.file = file;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public long size() {
+ return length;
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ FileChannel channel = null;
+ try {
+ channel = new RandomAccessFile(file, "r").getChannel();
+ // Just copy the buffer if it's sufficiently small, as memory mapping has a high overhead.
+ if (length < MIN_MEMORY_MAP_BYTES) {
+ ByteBuffer buf = ByteBuffer.allocate((int) length);
+ channel.position(offset);
+ while (buf.remaining() != 0) {
+ if (channel.read(buf) == -1) {
+ throw new IOException(String.format("Reached EOF before filling buffer\n" +
+ "offset=%s\nfile=%s\nbuf.remaining=%s",
+ offset, file.getAbsoluteFile(), buf.remaining()));
+ }
+ }
+ buf.flip();
+ return buf;
+ } else {
+ return channel.map(FileChannel.MapMode.READ_ONLY, offset, length);
+ }
+ } catch (IOException e) {
+ try {
+ if (channel != null) {
+ long size = channel.size();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ }
+ throw new IOException("Error in opening " + this, e);
+ } finally {
+ JavaUtils.closeQuietly(channel);
+ }
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ FileInputStream is = null;
+ try {
+ is = new FileInputStream(file);
+ ByteStreams.skipFully(is, offset);
+ return ByteStreams.limit(is, length);
+ } catch (IOException e) {
+ try {
+ if (is != null) {
+ long size = file.length();
+ throw new IOException("Error in reading " + this + " (actual file length " + size + ")",
+ e);
+ }
+ } catch (IOException ignored) {
+ // ignore
+ } finally {
+ JavaUtils.closeQuietly(is);
+ }
+ throw new IOException("Error in opening " + this, e);
+ } catch (RuntimeException e) {
+ JavaUtils.closeQuietly(is);
+ throw e;
+ }
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ FileChannel fileChannel = new FileInputStream(file).getChannel();
+ return new DefaultFileRegion(fileChannel, offset, length);
+ }
+
+ public File getFile() { return file; }
+
+ public long getOffset() { return offset; }
+
+ public long getLength() { return length; }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("file", file)
+ .add("offset", offset)
+ .add("length", length)
+ .toString();
+ }
+}
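A minimal usage sketch (not part of the patch; the file path and sizes are assumptions): wrapping a slice of an on-disk file. Segments smaller than the 2 MB threshold are copied into a heap ByteBuffer, while larger segments are memory-mapped.

import java.io.File;
import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.FileSegmentManagedBuffer;

public class FileSegmentExample {
  public static void main(String[] args) throws Exception {
    File dataFile = new File(args[0]);                   // assumed to exist
    long offset = 0;
    long length = Math.min(dataFile.length(), 1024);     // well below MIN_MEMORY_MAP_BYTES

    FileSegmentManagedBuffer segment = new FileSegmentManagedBuffer(dataFile, offset, length);
    ByteBuffer buf = segment.nioByteBuffer();            // heap copy for small segments
    System.out.println("Read " + buf.remaining() + " bytes from " + segment);
  }
}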
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
new file mode 100644
index 0000000000000..a415db593a788
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/ManagedBuffer.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+/**
+ * This class provides an immutable view of data in the form of bytes. The concrete
+ * implementation should specify how the data is provided:
+ *
+ * - {@link FileSegmentManagedBuffer}: data backed by part of a file
+ * - {@link NioManagedBuffer}: data backed by a NIO ByteBuffer
+ * - {@link NettyManagedBuffer}: data backed by a Netty ByteBuf
+ *
+ * The concrete buffer implementation might be managed outside the JVM garbage collector.
+ * For example, in the case of {@link NettyManagedBuffer}, the buffers are reference counted.
+ * In that case, if the buffer is going to be passed around to a different thread, retain/release
+ * should be called.
+ */
+public abstract class ManagedBuffer {
+
+ /** Number of bytes of the data. */
+ public abstract long size();
+
+ /**
+ * Exposes this buffer's data as an NIO ByteBuffer. Changing the position and limit of the
+ * returned ByteBuffer should not affect the content of this buffer.
+ */
+ // TODO: Deprecate this, usage may require expensive memory mapping or allocation.
+ public abstract ByteBuffer nioByteBuffer() throws IOException;
+
+ /**
+ * Exposes this buffer's data as an InputStream. The underlying implementation does not
+ * necessarily check for the length of bytes read, so the caller is responsible for making sure
+ * it does not go over the limit.
+ */
+ public abstract InputStream createInputStream() throws IOException;
+
+ /**
+ * Increment the reference count by one if applicable.
+ */
+ public abstract ManagedBuffer retain();
+
+ /**
+ * If applicable, decrement the reference count by one and deallocate the buffer if the
+ * reference count reaches zero.
+ */
+ public abstract ManagedBuffer release();
+
+ /**
+ * Convert the buffer into a Netty object, used to write the data out.
+ */
+ public abstract Object convertToNetty() throws IOException;
+}
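A brief sketch of the retain/release discipline described above, for the case where a buffer is handed to another thread. The class and method names outside this patch are illustrative only.

import org.apache.spark.network.buffer.ManagedBuffer;

public class HandOffExample {
  static void handOff(final ManagedBuffer buffer) {
    final ManagedBuffer retained = buffer.retain();   // bump the refcount before sharing
    new Thread(new Runnable() {
      @Override
      public void run() {
        try {
          System.out.println("buffer size: " + retained.size());
        } finally {
          retained.release();                         // drop our reference when done
        }
      }
    }).start();
  }
}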
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
new file mode 100644
index 0000000000000..c806bfa45bef3
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NettyManagedBuffer.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.ByteBufInputStream;
+
+/**
+ * A {@link ManagedBuffer} backed by a Netty {@link ByteBuf}.
+ */
+public final class NettyManagedBuffer extends ManagedBuffer {
+ private final ByteBuf buf;
+
+ public NettyManagedBuffer(ByteBuf buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.readableBytes();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.nioBuffer();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(buf);
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ buf.retain();
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ buf.release();
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
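A small illustrative sketch of wrapping a Netty ByteBuf; since the underlying ByteBuf is reference counted, release() is called once the data is no longer needed.

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.buffer.NettyManagedBuffer;

public class NettyManagedBufferExample {
  public static void main(String[] args) {
    ByteBuf raw = Unpooled.copiedBuffer(new byte[] { 1, 2, 3, 4 });
    NettyManagedBuffer managed = new NettyManagedBuffer(raw);

    System.out.println(managed.size());   // 4 readable bytes
    managed.release();                    // decrements the ByteBuf's reference count
  }
}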
diff --git a/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
new file mode 100644
index 0000000000000..f55b884bc45ce
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/buffer/NioManagedBuffer.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.buffer;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBufInputStream;
+import io.netty.buffer.Unpooled;
+
+/**
+ * A {@link ManagedBuffer} backed by a {@link ByteBuffer}.
+ */
+public final class NioManagedBuffer extends ManagedBuffer {
+ private final ByteBuffer buf;
+
+ public NioManagedBuffer(ByteBuffer buf) {
+ this.buf = buf;
+ }
+
+ @Override
+ public long size() {
+ return buf.remaining();
+ }
+
+ @Override
+ public ByteBuffer nioByteBuffer() throws IOException {
+ return buf.duplicate();
+ }
+
+ @Override
+ public InputStream createInputStream() throws IOException {
+ return new ByteBufInputStream(Unpooled.wrappedBuffer(buf));
+ }
+
+ @Override
+ public ManagedBuffer retain() {
+ return this;
+ }
+
+ @Override
+ public ManagedBuffer release() {
+ return this;
+ }
+
+ @Override
+ public Object convertToNetty() throws IOException {
+ return Unpooled.wrappedBuffer(buf);
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("buf", buf)
+ .toString();
+ }
+}
+
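An illustrative sketch of the duplicate semantics of nioByteBuffer(): draining the returned view does not move the position of the wrapped ByteBuffer.

import java.nio.ByteBuffer;

import org.apache.spark.network.buffer.NioManagedBuffer;

public class NioManagedBufferExample {
  public static void main(String[] args) throws Exception {
    ByteBuffer data = ByteBuffer.wrap(new byte[] { 1, 2, 3, 4, 5 });
    NioManagedBuffer managed = new NioManagedBuffer(data);

    ByteBuffer view = managed.nioByteBuffer();   // independent position and limit
    view.get(new byte[view.remaining()]);        // drain the duplicate
    System.out.println(managed.size());          // still 5
  }
}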
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
new file mode 100644
index 0000000000000..1fbdcd6780785
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkFetchFailureException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * General exception caused by a remote error while fetching a chunk.
+ */
+public class ChunkFetchFailureException extends RuntimeException {
+ public ChunkFetchFailureException(String errorMsg, Throwable cause) {
+ super(errorMsg, cause);
+ }
+
+ public ChunkFetchFailureException(String errorMsg) {
+ super(errorMsg);
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
new file mode 100644
index 0000000000000..519e6cb470d0d
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/ChunkReceivedCallback.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single chunk fetch. For a single stream, the callbacks are
+ * guaranteed to be called by the same thread in the same order as the requests for chunks were
+ * made.
+ *
+ * Note that if a general stream failure occurs, all outstanding chunk requests may be failed.
+ */
+public interface ChunkReceivedCallback {
+ /**
+ * Called upon receipt of a particular chunk.
+ *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+ * returning.
+ */
+ void onSuccess(int chunkIndex, ManagedBuffer buffer);
+
+ /**
+ * Called upon failure to fetch a particular chunk. Note that this may actually be called due
+ * to failure to fetch a prior chunk in this stream.
+ *
+ * After receiving a failure, the stream may or may not be valid. The client should not assume
+ * that the server's side of the stream has been closed.
+ */
+ void onFailure(int chunkIndex, Throwable e);
+}
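A hypothetical callback implementation (class and queue names are illustrative) that follows the retain() requirement above so the chunk outlives the onSuccess() call.

import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;

public class QueueingChunkCallback implements ChunkReceivedCallback {
  private final BlockingQueue<ManagedBuffer> results = new LinkedBlockingQueue<ManagedBuffer>();

  @Override
  public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
    // The buffer is release()'d once this method returns, so retain it first.
    results.offer(buffer.retain());
  }

  @Override
  public void onFailure(int chunkIndex, Throwable e) {
    System.err.println("chunk " + chunkIndex + " failed: " + e);
  }

  /** Caller must release() the returned buffer when finished with it. */
  public ManagedBuffer take() throws InterruptedException {
    return results.take();
  }
}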
diff --git a/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
new file mode 100644
index 0000000000000..6ec960d795420
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/RpcResponseCallback.java
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+/**
+ * Callback for the result of a single RPC. This will be invoked once with either success or
+ * failure.
+ */
+public interface RpcResponseCallback {
+ /** Successful serialized result from server. */
+ void onSuccess(byte[] response);
+
+ /** Exception either propagated from server or raised on client side. */
+ void onFailure(Throwable e);
+}
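A hypothetical implementation (illustrative only) that bridges the callback into a Guava SettableFuture, similar to what the synchronous RPC helper in TransportClient does internally.

import com.google.common.util.concurrent.SettableFuture;

import org.apache.spark.network.client.RpcResponseCallback;

public class FutureRpcCallback implements RpcResponseCallback {
  private final SettableFuture<byte[]> result = SettableFuture.create();

  @Override
  public void onSuccess(byte[] response) {
    result.set(response);
  }

  @Override
  public void onFailure(Throwable e) {
    result.setException(e);
  }

  public SettableFuture<byte[]> future() {
    return result;
  }
}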
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
new file mode 100644
index 0000000000000..01c143fff423c
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.util.UUID;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
+import com.google.common.util.concurrent.SettableFuture;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelFutureListener;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchRequest;
+import org.apache.spark.network.protocol.RpcRequest;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Client for fetching consecutive chunks of a pre-negotiated stream. This API is intended to allow
+ * efficient transfer of a large amount of data, broken up into chunks with size ranging from
+ * hundreds of KB to a few MB.
+ *
+ * Note that while this client deals with the fetching of chunks from a stream (i.e., data plane),
+ * the actual setup of the streams is done outside the scope of the transport layer. The convenience
+ * method "sendRpc" is provided to enable control plane communication between the client and server
+ * to perform this setup.
+ *
+ * For example, a typical workflow might be:
+ * client.sendRpc(new OpenFile("/foo")) --> returns StreamId = 100
+ * client.fetchChunk(streamId = 100, chunkIndex = 0, callback)
+ * client.fetchChunk(streamId = 100, chunkIndex = 1, callback)
+ * ...
+ * client.sendRpc(new CloseStream(100))
+ *
+ * Construct an instance of TransportClient using {@link TransportClientFactory}. A single
+ * TransportClient may be used for multiple streams, but any given stream must be restricted to a
+ * single client, in order to avoid out-of-order responses.
+ *
+ * NB: This class is used to make requests to the server, while {@link TransportResponseHandler} is
+ * responsible for handling responses from the server.
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+public class TransportClient implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClient.class);
+
+ private final Channel channel;
+ private final TransportResponseHandler handler;
+
+ public TransportClient(Channel channel, TransportResponseHandler handler) {
+ this.channel = Preconditions.checkNotNull(channel);
+ this.handler = Preconditions.checkNotNull(handler);
+ }
+
+ public boolean isActive() {
+ return channel.isOpen() || channel.isActive();
+ }
+
+ /**
+ * Requests a single chunk from the remote side, from the pre-negotiated streamId.
+ *
+ * Chunk indices go from 0 onwards. It is valid to request the same chunk multiple times, though
+ * some streams may not support this.
+ *
+ * Multiple fetchChunk requests may be outstanding simultaneously, and the chunks are guaranteed
+ * to be returned in the same order that they were requested, assuming only a single
+ * TransportClient is used to fetch the chunks.
+ *
+ * @param streamId Identifier that refers to a stream in the remote StreamManager. This should
+ * be agreed upon by client and server beforehand.
+ * @param chunkIndex 0-based index of the chunk to fetch
+ * @param callback Callback invoked upon successful receipt of chunk, or upon any failure.
+ */
+ public void fetchChunk(
+ long streamId,
+ final int chunkIndex,
+ final ChunkReceivedCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.debug("Sending fetch chunk request {} to {}", chunkIndex, serverAddr);
+
+ final StreamChunkId streamChunkId = new StreamChunkId(streamId, chunkIndex);
+ handler.addFetchRequest(streamChunkId, callback);
+
+ channel.writeAndFlush(new ChunkFetchRequest(streamChunkId)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", streamChunkId, serverAddr,
+ timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send request %s to %s: %s", streamChunkId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeFetchRequest(streamChunkId);
+ callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ /**
+ * Sends an opaque message to the RpcHandler on the server-side. The callback will be invoked
+ * with the server's response or upon any failure.
+ */
+ public void sendRpc(byte[] message, final RpcResponseCallback callback) {
+ final String serverAddr = NettyUtils.getRemoteAddress(channel);
+ final long startTime = System.currentTimeMillis();
+ logger.trace("Sending RPC to {}", serverAddr);
+
+ final long requestId = Math.abs(UUID.randomUUID().getLeastSignificantBits());
+ handler.addRpcRequest(requestId, callback);
+
+ channel.writeAndFlush(new RpcRequest(requestId, message)).addListener(
+ new ChannelFutureListener() {
+ @Override
+ public void operationComplete(ChannelFuture future) throws Exception {
+ if (future.isSuccess()) {
+ long timeTaken = System.currentTimeMillis() - startTime;
+ logger.trace("Sending request {} to {} took {} ms", requestId, serverAddr, timeTaken);
+ } else {
+ String errorMsg = String.format("Failed to send RPC %s to %s: %s", requestId,
+ serverAddr, future.cause());
+ logger.error(errorMsg, future.cause());
+ handler.removeRpcRequest(requestId);
+ callback.onFailure(new RuntimeException(errorMsg, future.cause()));
+ channel.close();
+ }
+ }
+ });
+ }
+
+ /**
+ * Synchronously sends an opaque message to the RpcHandler on the server-side, waiting for up to
+ * a specified timeout for a response.
+ */
+ public byte[] sendRpcSync(byte[] message, long timeoutMs) {
+ final SettableFuture<byte[]> result = SettableFuture.create();
+
+ sendRpc(message, new RpcResponseCallback() {
+ @Override
+ public void onSuccess(byte[] response) {
+ result.set(response);
+ }
+
+ @Override
+ public void onFailure(Throwable e) {
+ result.setException(e);
+ }
+ });
+
+ try {
+ return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+ } catch (Exception e) {
+ throw Throwables.propagate(e);
+ }
+ }
+
+ @Override
+ public void close() {
+ // close is a local operation and should finish within milliseconds; timeout just to be safe
+ channel.close().awaitUninterruptibly(10, TimeUnit.SECONDS);
+ }
+}
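A sketch of the workflow from the class comment, assuming an application-level protocol in which the RPC response carries the stream id as a decimal string; the payload format and stream id are assumptions, not part of this patch.

import org.apache.spark.network.buffer.ManagedBuffer;
import org.apache.spark.network.client.ChunkReceivedCallback;
import org.apache.spark.network.client.TransportClient;

public class FetchExample {
  static void fetchTwoChunks(TransportClient client) {
    // Control plane: negotiate a stream id with the server (payload format is assumed).
    byte[] response = client.sendRpcSync("OpenFile:/foo".getBytes(), 10000);
    long streamId = Long.parseLong(new String(response));

    ChunkReceivedCallback callback = new ChunkReceivedCallback() {
      @Override
      public void onSuccess(int chunkIndex, ManagedBuffer buffer) {
        System.out.println("chunk " + chunkIndex + ": " + buffer.size() + " bytes");
      }

      @Override
      public void onFailure(int chunkIndex, Throwable e) {
        System.err.println("chunk " + chunkIndex + " failed: " + e);
      }
    };

    // Data plane: chunks come back in the order they were requested on this client.
    client.fetchChunk(streamId, 0, callback);
    client.fetchChunk(streamId, 1, callback);
  }
}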
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
new file mode 100644
index 0000000000000..0b4a1d8286407
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.io.Closeable;
+import java.lang.reflect.Field;
+import java.net.InetSocketAddress;
+import java.net.SocketAddress;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicReference;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.ChannelOption;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.TransportContext;
+import org.apache.spark.network.server.TransportChannelHandler;
+import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.NettyUtils;
+import org.apache.spark.network.util.TransportConf;
+
+/**
+ * Factory for creating {@link TransportClient}s by using createClient.
+ *
+ * The factory maintains a connection pool to other hosts and should return the same
+ * {@link TransportClient} for the same remote host. It also shares a single worker thread pool for
+ * all {@link TransportClient}s.
+ */
+public class TransportClientFactory implements Closeable {
+ private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
+
+ private final TransportContext context;
+ private final TransportConf conf;
+ private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
+
+ private final Class<? extends Channel> socketChannelClass;
+ private EventLoopGroup workerGroup;
+
+ public TransportClientFactory(TransportContext context) {
+ this.context = context;
+ this.conf = context.getConf();
+ this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
+
+ IOMode ioMode = IOMode.valueOf(conf.ioMode());
+ this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
+ // TODO: Make thread pool name configurable.
+ this.workerGroup = NettyUtils.createEventLoop(ioMode, conf.clientThreads(), "shuffle-client");
+ }
+
+ /**
+ * Create a new {@link TransportClient} connecting to the given remote host / port.
+ *
+ * This blocks until a connection is successfully established.
+ *
+ * Concurrency: This method is safe to call from multiple threads.
+ */
+ public TransportClient createClient(String remoteHost, int remotePort) {
+ // Get connection from the connection pool first.
+ // If it is not found or not active, create a new one.
+ final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
+ TransportClient cachedClient = connectionPool.get(address);
+ if (cachedClient != null) {
+ if (cachedClient.isActive()) {
+ return cachedClient;
+ } else {
+ connectionPool.remove(address, cachedClient); // Remove inactive clients.
+ }
+ }
+
+ logger.debug("Creating new connection to " + address);
+
+ Bootstrap bootstrap = new Bootstrap();
+ bootstrap.group(workerGroup)
+ .channel(socketChannelClass)
+ // Disable Nagle's Algorithm since we don't want packets to wait
+ .option(ChannelOption.TCP_NODELAY, true)
+ .option(ChannelOption.SO_KEEPALIVE, true)
+ .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs());
+
+ // Use pooled buffers to reduce temporary buffer allocation
+ bootstrap.option(ChannelOption.ALLOCATOR, createPooledByteBufAllocator());
+
+ final AtomicReference<TransportClient> client = new AtomicReference<TransportClient>();
+
+ bootstrap.handler(new ChannelInitializer<SocketChannel>() {
+ @Override
+ public void initChannel(SocketChannel ch) {
+ TransportChannelHandler clientHandler = context.initializePipeline(ch);
+ client.set(clientHandler.getClient());
+ }
+ });
+
+ // Connect to the remote server
+ ChannelFuture cf = bootstrap.connect(address);
+ if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
+ throw new RuntimeException(
+ String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
+ } else if (cf.cause() != null) {
+ throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+ }
+
+ // Successful connection -- in the event that two threads raced to create a client, we will
+ // use the first one that was put into the connectionPool and close the one we made here.
+ assert client.get() != null : "Channel future completed successfully with null client";
+ TransportClient oldClient = connectionPool.putIfAbsent(address, client.get());
+ if (oldClient == null) {
+ return client.get();
+ } else {
+ logger.debug("Two clients were created concurrently, second one will be disposed.");
+ client.get().close();
+ return oldClient;
+ }
+ }
+
+ /** Close all connections in the connection pool, and shutdown the worker thread pool. */
+ @Override
+ public void close() {
+ for (TransportClient client : connectionPool.values()) {
+ try {
+ client.close();
+ } catch (RuntimeException e) {
+ logger.warn("Ignoring exception during close", e);
+ }
+ }
+ connectionPool.clear();
+
+ if (workerGroup != null) {
+ workerGroup.shutdownGracefully();
+ workerGroup = null;
+ }
+ }
+
+ /**
+ * Creates a pooled ByteBuf allocator with the thread-local cache disabled. Thread-local caches
+ * are disabled because the ByteBufs are allocated by the event loop thread, but released by the
+ * executor thread rather than the event loop thread. Those thread-local caches actually delay
+ * the recycling of buffers, leading to larger memory usage.
+ */
+ private PooledByteBufAllocator createPooledByteBufAllocator() {
+ return new PooledByteBufAllocator(
+ PlatformDependent.directBufferPreferred(),
+ getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
+ getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
+ getPrivateStaticField("DEFAULT_PAGE_SIZE"),
+ getPrivateStaticField("DEFAULT_MAX_ORDER"),
+ 0, // tinyCacheSize
+ 0, // smallCacheSize
+ 0 // normalCacheSize
+ );
+ }
+
+ /** Used to get defaults from Netty's private static fields. */
+ private int getPrivateStaticField(String name) {
+ try {
+ Field f = PooledByteBufAllocator.DEFAULT.getClass().getDeclaredField(name);
+ f.setAccessible(true);
+ return f.getInt(null);
+ } catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+}
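A short usage sketch; TransportContext construction is elided and the host/port are placeholders. Repeated calls for the same address reuse the pooled client while it remains active.

import org.apache.spark.network.TransportContext;
import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientFactory;

public class ClientFactoryExample {
  static TransportClient connect(TransportContext context, String host, int port) {
    TransportClientFactory factory = new TransportClientFactory(context);
    // Blocks until the connection is established (or times out).
    return factory.createClient(host, port);
  }
}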
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
new file mode 100644
index 0000000000000..d8965590b34da
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import com.google.common.annotations.VisibleForTesting;
+import io.netty.channel.Channel;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.spark.network.protocol.ChunkFetchFailure;
+import org.apache.spark.network.protocol.ChunkFetchSuccess;
+import org.apache.spark.network.protocol.ResponseMessage;
+import org.apache.spark.network.protocol.RpcFailure;
+import org.apache.spark.network.protocol.RpcResponse;
+import org.apache.spark.network.protocol.StreamChunkId;
+import org.apache.spark.network.server.MessageHandler;
+import org.apache.spark.network.util.NettyUtils;
+
+/**
+ * Handler that processes server responses, in response to requests issued from a
+ * {@link TransportClient}. It works by tracking the list of outstanding requests (and their callbacks).
+ *
+ * Concurrency: thread safe and can be called from multiple threads.
+ */
+ public class TransportResponseHandler extends MessageHandler<ResponseMessage> {
+ private final Logger logger = LoggerFactory.getLogger(TransportResponseHandler.class);
+
+ private final Channel channel;
+
+ private final Map<StreamChunkId, ChunkReceivedCallback> outstandingFetches;
+
+ private final Map<Long, RpcResponseCallback> outstandingRpcs;
+
+ public TransportResponseHandler(Channel channel) {
+ this.channel = channel;
+ this.outstandingFetches = new ConcurrentHashMap<StreamChunkId, ChunkReceivedCallback>();
+ this.outstandingRpcs = new ConcurrentHashMap<Long, RpcResponseCallback>();
+ }
+
+ public void addFetchRequest(StreamChunkId streamChunkId, ChunkReceivedCallback callback) {
+ outstandingFetches.put(streamChunkId, callback);
+ }
+
+ public void removeFetchRequest(StreamChunkId streamChunkId) {
+ outstandingFetches.remove(streamChunkId);
+ }
+
+ public void addRpcRequest(long requestId, RpcResponseCallback callback) {
+ outstandingRpcs.put(requestId, callback);
+ }
+
+ public void removeRpcRequest(long requestId) {
+ outstandingRpcs.remove(requestId);
+ }
+
+ /**
+ * Fire the failure callback for all outstanding requests. This is called when we have an
+ * uncaught exception or premature connection termination.
+ */
+ private void failOutstandingRequests(Throwable cause) {
+ for (Map.Entry<StreamChunkId, ChunkReceivedCallback> entry : outstandingFetches.entrySet()) {
+ entry.getValue().onFailure(entry.getKey().chunkIndex, cause);
+ }
+ for (Map.Entry<Long, RpcResponseCallback> entry : outstandingRpcs.entrySet()) {
+ entry.getValue().onFailure(cause);
+ }
+
+ // It's OK if new fetches appear, as they will fail immediately.
+ outstandingFetches.clear();
+ outstandingRpcs.clear();
+ }
+
+ @Override
+ public void channelUnregistered() {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+ }
+ }
+
+ @Override
+ public void exceptionCaught(Throwable cause) {
+ if (numOutstandingRequests() > 0) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ logger.error("Still have {} requests outstanding when connection from {} is closed",
+ numOutstandingRequests(), remoteAddress);
+ failOutstandingRequests(cause);
+ }
+ }
+
+ @Override
+ public void handle(ResponseMessage message) {
+ String remoteAddress = NettyUtils.getRemoteAddress(channel);
+ if (message instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess resp = (ChunkFetchSuccess) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} since it is not outstanding",
+ resp.streamChunkId, remoteAddress);
+ resp.buffer.release();
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onSuccess(resp.streamChunkId.chunkIndex, resp.buffer);
+ resp.buffer.release();
+ }
+ } else if (message instanceof ChunkFetchFailure) {
+ ChunkFetchFailure resp = (ChunkFetchFailure) message;
+ ChunkReceivedCallback listener = outstandingFetches.get(resp.streamChunkId);
+ if (listener == null) {
+ logger.warn("Ignoring response for block {} from {} ({}) since it is not outstanding",
+ resp.streamChunkId, remoteAddress, resp.errorString);
+ } else {
+ outstandingFetches.remove(resp.streamChunkId);
+ listener.onFailure(resp.streamChunkId.chunkIndex, new ChunkFetchFailureException(
+ "Failure while fetching " + resp.streamChunkId + ": " + resp.errorString));
+ }
+ } else if (message instanceof RpcResponse) {
+ RpcResponse resp = (RpcResponse) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({} bytes) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.response.length);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onSuccess(resp.response);
+ }
+ } else if (message instanceof RpcFailure) {
+ RpcFailure resp = (RpcFailure) message;
+ RpcResponseCallback listener = outstandingRpcs.get(resp.requestId);
+ if (listener == null) {
+ logger.warn("Ignoring response for RPC {} from {} ({}) since it is not outstanding",
+ resp.requestId, remoteAddress, resp.errorString);
+ } else {
+ outstandingRpcs.remove(resp.requestId);
+ listener.onFailure(new RuntimeException(resp.errorString));
+ }
+ } else {
+ throw new IllegalStateException("Unknown response type: " + message.type());
+ }
+ }
+
+ /** Returns total number of outstanding requests (fetch requests + rpcs) */
+ @VisibleForTesting
+ public int numOutstandingRequests() {
+ return outstandingFetches.size() + outstandingRpcs.size();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
new file mode 100644
index 0000000000000..152af98ced7ce
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchFailure.java
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Charsets;
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Response to {@link ChunkFetchRequest} when there is an error fetching the chunk.
+ */
+public final class ChunkFetchFailure implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final String errorString;
+
+ public ChunkFetchFailure(StreamChunkId streamChunkId, String errorString) {
+ this.streamChunkId = streamChunkId;
+ this.errorString = errorString;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchFailure; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength() + 4 + errorString.getBytes(Charsets.UTF_8).length;
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ byte[] errorBytes = errorString.getBytes(Charsets.UTF_8);
+ buf.writeInt(errorBytes.length);
+ buf.writeBytes(errorBytes);
+ }
+
+ public static ChunkFetchFailure decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ int numErrorStringBytes = buf.readInt();
+ byte[] errorBytes = new byte[numErrorStringBytes];
+ buf.readBytes(errorBytes);
+ return new ChunkFetchFailure(streamChunkId, new String(errorBytes, Charsets.UTF_8));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchFailure) {
+ ChunkFetchFailure o = (ChunkFetchFailure) other;
+ return streamChunkId.equals(o.streamChunkId) && errorString.equals(o.errorString);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("errorString", errorString)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
new file mode 100644
index 0000000000000..980947cf13f6b
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchRequest.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Request to fetch a single chunk of a stream. This will correspond to a single
+ * {@link org.apache.spark.network.protocol.ResponseMessage} (either success or failure).
+ */
+public final class ChunkFetchRequest implements RequestMessage {
+ public final StreamChunkId streamChunkId;
+
+ public ChunkFetchRequest(StreamChunkId streamChunkId) {
+ this.streamChunkId = streamChunkId;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchRequest; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ public static ChunkFetchRequest decode(ByteBuf buf) {
+ return new ChunkFetchRequest(StreamChunkId.decode(buf));
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchRequest) {
+ ChunkFetchRequest o = (ChunkFetchRequest) other;
+ return streamChunkId.equals(o.streamChunkId);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
new file mode 100644
index 0000000000000..ff4936470c697
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/ChunkFetchSuccess.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+import org.apache.spark.network.buffer.NettyManagedBuffer;
+
+/**
+ * Response to {@link ChunkFetchRequest} when a chunk exists and has been successfully fetched.
+ *
+ * Note that the server-side encoding of this message does NOT include the buffer itself, as this
+ * may be written by Netty in a more efficient manner (i.e., zero-copy write).
+ * Similarly, the client-side decoding will reuse the Netty ByteBuf as the buffer.
+ */
+public final class ChunkFetchSuccess implements ResponseMessage {
+ public final StreamChunkId streamChunkId;
+ public final ManagedBuffer buffer;
+
+ public ChunkFetchSuccess(StreamChunkId streamChunkId, ManagedBuffer buffer) {
+ this.streamChunkId = streamChunkId;
+ this.buffer = buffer;
+ }
+
+ @Override
+ public Type type() { return Type.ChunkFetchSuccess; }
+
+ @Override
+ public int encodedLength() {
+ return streamChunkId.encodedLength();
+ }
+
+ /** Encoding does NOT include 'buffer' itself. See {@link MessageEncoder}. */
+ @Override
+ public void encode(ByteBuf buf) {
+ streamChunkId.encode(buf);
+ }
+
+ /** Decoding uses the given ByteBuf as our data, and will retain() it. */
+ public static ChunkFetchSuccess decode(ByteBuf buf) {
+ StreamChunkId streamChunkId = StreamChunkId.decode(buf);
+ buf.retain();
+ NettyManagedBuffer managedBuf = new NettyManagedBuffer(buf.duplicate());
+ return new ChunkFetchSuccess(streamChunkId, managedBuf);
+ }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other instanceof ChunkFetchSuccess) {
+ ChunkFetchSuccess o = (ChunkFetchSuccess) other;
+ return streamChunkId.equals(o.streamChunkId) && buffer.equals(o.buffer);
+ }
+ return false;
+ }
+
+ @Override
+ public String toString() {
+ return Objects.toStringHelper(this)
+ .add("streamChunkId", streamChunkId)
+ .add("buffer", buffer)
+ .toString();
+ }
+}
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
new file mode 100644
index 0000000000000..b4e299471b41a
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Encodable.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/**
+ * Interface for an object which can be encoded into a ByteBuf. Multiple Encodable objects are
+ * stored in a single, pre-allocated ByteBuf, so Encodables must also provide their length.
+ *
+ * Encodable objects should provide a static "decode(ByteBuf)" method which is invoked by
+ * {@link MessageDecoder}. During decoding, if the object uses the ByteBuf as its data (rather than
+ * just copying data from it), then you must retain() the ByteBuf.
+ *
+ * Additionally, when adding a new Encodable Message, add it to {@link Message.Type}.
+ */
+public interface Encodable {
+ /** Number of bytes of the encoded form of this object. */
+ int encodedLength();
+
+ /**
+ * Serializes this object by writing into the given ByteBuf.
+ * This method must write exactly encodedLength() bytes.
+ */
+ void encode(ByteBuf buf);
+}
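A hypothetical Encodable (not part of the patch) illustrating the contract above: encode() writes exactly encodedLength() bytes and a static decode(ByteBuf) reads them back.

import io.netty.buffer.ByteBuf;

import org.apache.spark.network.protocol.Encodable;

public final class ExampleId implements Encodable {
  public final long id;

  public ExampleId(long id) { this.id = id; }

  @Override
  public int encodedLength() { return 8; }                // one long

  @Override
  public void encode(ByteBuf buf) { buf.writeLong(id); }

  public static ExampleId decode(ByteBuf buf) {
    return new ExampleId(buf.readLong());
  }
}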
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/Message.java b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
new file mode 100644
index 0000000000000..d568370125fd4
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/Message.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import io.netty.buffer.ByteBuf;
+
+/** An on-the-wire transmittable message. */
+public interface Message extends Encodable {
+ /** Used to identify this request type. */
+ Type type();
+
+ /** Preceding every serialized Message is its type, which allows us to deserialize it. */
+ public static enum Type implements Encodable {
+ ChunkFetchRequest(0), ChunkFetchSuccess(1), ChunkFetchFailure(2),
+ RpcRequest(3), RpcResponse(4), RpcFailure(5);
+
+ private final byte id;
+
+ private Type(int id) {
+ assert id < 128 : "Cannot have more than 128 message types";
+ this.id = (byte) id;
+ }
+
+ public byte id() { return id; }
+
+ @Override public int encodedLength() { return 1; }
+
+ @Override public void encode(ByteBuf buf) { buf.writeByte(id); }
+
+ public static Type decode(ByteBuf buf) {
+ byte id = buf.readByte();
+ switch (id) {
+ case 0: return ChunkFetchRequest;
+ case 1: return ChunkFetchSuccess;
+ case 2: return ChunkFetchFailure;
+ case 3: return RpcRequest;
+ case 4: return RpcResponse;
+ case 5: return RpcFailure;
+ default: throw new IllegalArgumentException("Unknown message type: " + id);
+ }
+ }
+ }
+}
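A round-trip sketch of the tagging scheme above: the one-byte type precedes the encoded body, and decoding dispatches on that byte. (Any additional framing added by the encoder/decoder pipeline is omitted here.)

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;

import org.apache.spark.network.protocol.ChunkFetchRequest;
import org.apache.spark.network.protocol.Message;
import org.apache.spark.network.protocol.StreamChunkId;

public class MessageTypeRoundTrip {
  public static void main(String[] args) {
    ChunkFetchRequest request = new ChunkFetchRequest(new StreamChunkId(100L, 0));

    ByteBuf buf = Unpooled.buffer(request.type().encodedLength() + request.encodedLength());
    request.type().encode(buf);   // one-byte type tag
    request.encode(buf);          // body: stream id + chunk index

    Message.Type type = Message.Type.decode(buf);
    ChunkFetchRequest decoded = ChunkFetchRequest.decode(buf);
    System.out.println(type + " -> " + decoded);
    buf.release();
  }
}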
diff --git a/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
new file mode 100644
index 0000000000000..81f8d7f96350f
--- /dev/null
+++ b/network/common/src/main/java/org/apache/spark/network/protocol/MessageDecoder.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import java.util.List;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.handler.codec.MessageToMessageDecoder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Decoder used by the client side to decode server-to-client responses.
+ * This encoder is stateless so it is safe to be shared by multiple threads.
+ */
+@ChannelHandler.Sharable
+ public final class MessageDecoder extends MessageToMessageDecoder<ByteBuf> {
+
+ private final Logger logger = LoggerFactory.getLogger(MessageDecoder.class);
+ @Override
+ public void decode(ChannelHandlerContext ctx, ByteBuf in, List