diff --git a/LICENSE b/LICENSE
index b771bd552b762..150ccc54ec6c2 100644
--- a/LICENSE
+++ b/LICENSE
@@ -222,7 +222,7 @@ Python Software Foundation License
----------------------------------
pyspark/heapq3.py
-
+python/docs/_static/copybutton.js
BSD 3-Clause
------------
@@ -258,4 +258,4 @@ data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg
data/mllib/images/kittens/54893.jpg
data/mllib/images/kittens/DP153539.jpg
data/mllib/images/kittens/DP802813.jpg
-data/mllib/images/multi-channel/chr30.4.184.jpg
\ No newline at end of file
+data/mllib/images/multi-channel/chr30.4.184.jpg
diff --git a/LICENSE-binary b/LICENSE-binary
index 5f57133bef43d..2ff881fac5fb0 100644
--- a/LICENSE-binary
+++ b/LICENSE-binary
@@ -302,7 +302,6 @@ com.google.code.gson:gson
com.google.inject:guice
com.google.inject.extensions:guice-servlet
com.twitter:parquet-hadoop-bundle
-commons-beanutils:commons-beanutils-core
commons-cli:commons-cli
commons-dbcp:commons-dbcp
commons-io:commons-io
@@ -490,7 +489,6 @@ Eclipse Distribution License (EDL) 1.0
org.glassfish.jaxb:jaxb-runtime
jakarta.xml.bind:jakarta.xml.bind-api
com.sun.istack:istack-commons-runtime
-jakarta.activation:jakarta.activation-api
Mozilla Public License (MPL) 1.1
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 0566a47cc8755..3bd1f544d77a5 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -3589,6 +3589,8 @@ setMethod("element_at",
#' @details
#' \code{explode}: Creates a new row for each element in the given array or map column.
+#' Uses the default column name \code{col} for elements in the array and
+#' \code{key} and \code{value} for elements in the map unless specified otherwise.
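+#' For example, for a hypothetical SparkDataFrame \code{df} with a map column \code{m},
+#' \code{select(df, explode(df$m))} yields two columns named \code{key} and \code{value}.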
#'
#' @rdname column_collection_functions
#' @aliases explode explode,Column-method
@@ -3649,7 +3651,9 @@ setMethod("sort_array",
#' @details
#' \code{posexplode}: Creates a new row for each element with position in the given array
-#' or map column.
+#' or map column. Uses the default column name \code{pos} for position, and \code{col}
+#' for elements in the array and \code{key} and \code{value} for elements in the map
+#' unless specified otherwise.
#'
#' @rdname column_collection_functions
#' @aliases posexplode posexplode,Column-method
@@ -3790,7 +3794,8 @@ setMethod("repeat_string",
#' \code{explode_outer}: Creates a new row for each element in the given array or map column.
#' Unlike \code{explode}, if the array/map is \code{null} or empty
#' then \code{null} is produced.
-#'
+#' Uses the default column name \code{col} for elements in the array and
+#' \code{key} and \code{value} for elements in the map unless specified otherwise.
#'
#' @rdname column_collection_functions
#' @aliases explode_outer explode_outer,Column-method
@@ -3815,6 +3820,9 @@ setMethod("explode_outer",
#' \code{posexplode_outer}: Creates a new row for each element with position in the given
#' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty
#' then the row (\code{null}, \code{null}) is produced.
+#' Uses the default column name \code{pos} for position, and \code{col}
+#' for elements in the array and \code{key} and \code{value} for elements in the map
+#' unless specified otherwise.
#'
#' @rdname column_collection_functions
#' @aliases posexplode_outer posexplode_outer,Column-method
diff --git a/README.md b/README.md
index 271f2f5f5b1c3..482c00764380a 100644
--- a/README.md
+++ b/README.md
@@ -1,18 +1,18 @@
# Apache Spark
-[![Jenkins Build](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/badge/icon)](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7)
-[![AppVeyor Build](https://img.shields.io/appveyor/ci/ApacheSoftwareFoundation/spark/master.svg?style=plastic&logo=appveyor)](https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark)
-[![PySpark Coverage](https://img.shields.io/badge/dynamic/xml.svg?label=pyspark%20coverage&url=https%3A%2F%2Fspark-test.github.io%2Fpyspark-coverage-site&query=%2Fhtml%2Fbody%2Fdiv%5B1%5D%2Fdiv%2Fh1%2Fspan&colorB=brightgreen&style=plastic)](https://spark-test.github.io/pyspark-coverage-site)
-
-Spark is a fast and general cluster computing system for Big Data. It provides
+Spark is a unified analytics engine for large-scale data processing. It provides
high-level APIs in Scala, Java, Python, and R, and an optimized engine that
supports general computation graphs for data analysis. It also supports a
rich set of higher-level tools including Spark SQL for SQL and DataFrames,
MLlib for machine learning, GraphX for graph processing,
-and Spark Streaming for stream processing.
+and Structured Streaming for stream processing.
+[![Jenkins Build](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/badge/icon)](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7)
+[![AppVeyor Build](https://img.shields.io/appveyor/ci/ApacheSoftwareFoundation/spark/master.svg?style=plastic&logo=appveyor)](https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark)
+[![PySpark Coverage](https://img.shields.io/badge/dynamic/xml.svg?label=pyspark%20coverage&url=https%3A%2F%2Fspark-test.github.io%2Fpyspark-coverage-site&query=%2Fhtml%2Fbody%2Fdiv%5B1%5D%2Fdiv%2Fh1%2Fspan&colorB=brightgreen&style=plastic)](https://spark-test.github.io/pyspark-coverage-site)
+
## Online Documentation
@@ -41,9 +41,9 @@ The easiest way to start using Spark is through the Scala shell:
./bin/spark-shell
-Try the following command, which should return 1000:
+Try the following command, which should return 1,000,000,000:
- scala> sc.parallelize(1 to 1000).count()
+ scala> spark.range(1000 * 1000 * 1000).count()
## Interactive Python Shell
@@ -51,9 +51,9 @@ Alternatively, if you prefer Python, you can use the Python shell:
./bin/pyspark
-And run the following command, which should also return 1000:
+And run the following command, which should also return 1,000,000,000:
- >>> sc.parallelize(range(1000)).count()
+ >>> spark.range(1000 * 1000 * 1000).count()
## Example Programs
diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh
index 0388e23979dda..68fafbb848001 100755
--- a/bin/docker-image-tool.sh
+++ b/bin/docker-image-tool.sh
@@ -282,7 +282,7 @@ do
if ! minikube status 1>/dev/null; then
error "Cannot contact minikube. Make sure it's running."
fi
- eval $(minikube docker-env)
+ eval $(minikube docker-env --shell bash)
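+      # --shell bash forces bash-compatible output from minikube so the eval above works
+      # even when the user's default shell is not bash.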
;;
u) SPARK_UID=${OPTARG};;
esac
diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
index a530e16734db4..6f90df5f611a9 100644
--- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
+++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java
@@ -46,9 +46,9 @@
*/
public class RetryingBlockFetcherSuite {
- ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
- ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
- ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
+ private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13]));
+ private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7]));
+ private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19]));
@Test
public void testNoFailures() throws IOException, InterruptedException {
@@ -291,7 +291,7 @@ private static void performInteractions(List<? extends Map<String, Object>> inte
}
assertNotNull(stub);
- stub.when(fetchStarter).createAndStart(any(), anyObject());
+ stub.when(fetchStarter).createAndStart(any(), any());
String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]);
new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start();
}
diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml
index 55cdc3140aa08..c642f3b4a1600 100644
--- a/common/network-yarn/pom.xml
+++ b/common/network-yarn/pom.xml
@@ -35,7 +35,7 @@
<hadoop.deps.scope>provided</hadoop.deps.scope>
<shuffle.jar>${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar</shuffle.jar>
- <shade>org/spark_project/</shade>
+ <shade>org/sparkproject/</shade>
@@ -128,6 +128,50 @@
+
+ <plugin>
+   <groupId>org.codehaus.mojo</groupId>
+   <artifactId>build-helper-maven-plugin</artifactId>
+   <executions>
+     <execution>
+       <id>regex-property</id>
+       <goals>
+         <goal>regex-property</goal>
+       </goals>
+       <configuration>
+         <name>spark.shade.native.packageName</name>
+         <value>${spark.shade.packageName}</value>
+         <regex>\.</regex>
+         <replacement>_</replacement>
+         <failIfNoMatch>true</failIfNoMatch>
+       </configuration>
+     </execution>
+   </executions>
+ </plugin>
+ <plugin>
+   <groupId>org.apache.maven.plugins</groupId>
+   <artifactId>maven-antrun-plugin</artifactId>
+   <executions>
+     <execution>
+       <id>unpack</id>
+       <phase>package</phase>
+       <configuration>
+         <target>
+
+
+
+
+         </target>
+       </configuration>
+       <goals>
+         <goal>run</goal>
+       </goals>
+     </execution>
+   </executions>
+ </plugin>
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
index 621f2c6bf3777..e36efa3b0f22b 100644
--- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
+++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java
@@ -18,6 +18,7 @@
package org.apache.spark.unsafe.types;
import java.io.Serializable;
+import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -66,6 +67,10 @@ private static long toLong(String s) {
}
}
+ /**
+ * Convert a string to CalendarInterval. Return null if the input string is not a valid interval.
+ * This method is case-sensitive and all characters in the input string should be in lower case.
+ */
public static CalendarInterval fromString(String s) {
if (s == null) {
return null;
@@ -87,6 +92,26 @@ public static CalendarInterval fromString(String s) {
}
}
+ /**
+ * Convert a string to CalendarInterval. Unlike fromString, this method is case-insensitive and
+ * will throw IllegalArgumentException when the input string is not a valid interval.
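+ * For example, {@code fromCaseInsensitiveString("5 MINUTES")} returns the same interval
+ * as {@code fromString("interval 5 minutes")}.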
+ *
+ * @throws IllegalArgumentException if the string is not a valid interval.
+ */
+ public static CalendarInterval fromCaseInsensitiveString(String s) {
+ if (s == null || s.trim().isEmpty()) {
+ throw new IllegalArgumentException("Interval cannot be null or blank.");
+ }
+ String sInLowerCase = s.trim().toLowerCase(Locale.ROOT);
+ String interval =
+ sInLowerCase.startsWith("interval ") ? sInLowerCase : "interval " + sInLowerCase;
+ CalendarInterval cal = fromString(interval);
+ if (cal == null) {
+ throw new IllegalArgumentException("Invalid interval: " + s);
+ }
+ return cal;
+ }
+
public static long toLongWithRange(String fieldName,
String s, long minValue, long maxValue) throws IllegalArgumentException {
long result = 0;
@@ -319,6 +344,8 @@ public String toString() {
appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond");
rest %= MICROS_PER_MILLI;
appendUnit(sb, rest, "microsecond");
+ } else if (months == 0) {
+ sb.append(" 0 microseconds");
}
return sb.toString();
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
index 9e69e264ff287..994af8f082447 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
@@ -41,6 +41,9 @@ public void equalsTest() {
public void toStringTest() {
CalendarInterval i;
+ i = new CalendarInterval(0, 0);
+ assertEquals("interval 0 microseconds", i.toString());
+
i = new CalendarInterval(34, 0);
assertEquals("interval 2 years 10 months", i.toString());
@@ -101,6 +104,31 @@ public void fromStringTest() {
assertNull(fromString(input));
}
+ @Test
+ public void fromCaseInsensitiveStringTest() {
+ for (String input : new String[]{"5 MINUTES", "5 minutes", "5 Minutes"}) {
+ assertEquals(fromCaseInsensitiveString(input), new CalendarInterval(0, 5L * 60 * 1_000_000));
+ }
+
+ for (String input : new String[]{null, "", " "}) {
+ try {
+ fromCaseInsensitiveString(input);
+ fail("Expected to throw an exception for the invalid input");
+ } catch (IllegalArgumentException e) {
+ assertTrue(e.getMessage().contains("cannot be null or blank"));
+ }
+ }
+
+ for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day"}) {
+ try {
+ fromCaseInsensitiveString(input);
+ fail("Expected to throw an exception for the invalid input");
+ } catch (IllegalArgumentException e) {
+ assertTrue(e.getMessage().contains("Invalid interval"));
+ }
+ }
+ }
+
@Test
public void fromYearMonthStringTest() {
String input;
diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template
index ec1aa187dfb32..e91595dd324b0 100644
--- a/conf/log4j.properties.template
+++ b/conf/log4j.properties.template
@@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:
log4j.logger.org.apache.spark.repl.Main=WARN
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark_project.jetty=WARN
-log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.sparkproject.jetty=WARN
+log4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
log4j.logger.org.apache.parquet=ERROR
diff --git a/core/pom.xml b/core/pom.xml
index 45bda44916f52..9d57028019880 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -347,7 +347,7 @@
<groupId>net.razorvine</groupId>
<artifactId>pyrolite</artifactId>
- <version>4.13</version>
+ <version>4.23</version>
<groupId>net.razorvine</groupId>
diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
index 277010015072a..eb12848900b58 100644
--- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties
+++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties
@@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}:
log4j.logger.org.apache.spark.repl.Main=WARN
# Settings to quiet third party logs that are too verbose
-log4j.logger.org.spark_project.jetty=WARN
-log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR
+log4j.logger.org.sparkproject.jetty=WARN
+log4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR
log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO
log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO
diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
index 2d842b98ead43..a354f44a1be19 100644
--- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
+++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -20,7 +20,6 @@ package org.apache.spark
import java.util.{Properties, Timer, TimerTask}
import scala.concurrent.duration._
-import scala.language.postfixOps
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
@@ -122,7 +121,7 @@ class BarrierTaskContext private[spark] (
barrierEpoch),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
- timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout"))
+ timeout = new RpcTimeout(365.days, "barrierTimeout"))
barrierEpoch += 1
logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " +
"global sync successfully, waited for " +
diff --git a/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala b/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala
new file mode 100644
index 0000000000000..19639420b8b9f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+
+import com.fasterxml.jackson.core.JsonParseException
+import org.json4s.{DefaultFormats, MappingException}
+import org.json4s.JsonAST.JValue
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.util.Utils.executeAndGetOutput
+
+/**
+ * Discovers resources (GPUs/FPGAs/etc). It currently only supports resources that have
+ * addresses.
+ * This class finds resources by running and parsing the output of the user-specified script
+ * from the config spark.{driver/executor}.resource.{resourceType}.discoveryScript.
+ * The script's output is expected to be JSON in the format of the
+ * ResourceInformation class.
+ *
+ * For example: {"name": "gpu", "addresses": ["0","1"]}
+ */
+private[spark] object ResourceDiscoverer extends Logging {
+
+ private implicit val formats = DefaultFormats
+
+ def findResources(sparkConf: SparkConf, isDriver: Boolean): Map[String, ResourceInformation] = {
+ val prefix = if (isDriver) {
+ SPARK_DRIVER_RESOURCE_PREFIX
+ } else {
+ SPARK_EXECUTOR_RESOURCE_PREFIX
+ }
+ // Get the unique resource types by taking the first component of each config key,
+ // e.g. for resourceType.count we keep the resourceType part.
+ val resourceNames = sparkConf.getAllWithPrefix(prefix).map { case (k, _) =>
+ k.split('.').head
+ }.toSet
+ resourceNames.map { rName => {
+ val rInfo = getResourceInfoForType(sparkConf, prefix, rName)
+ (rName -> rInfo)
+ }}.toMap
+ }
+
+ private def getResourceInfoForType(
+ sparkConf: SparkConf,
+ prefix: String,
+ resourceType: String): ResourceInformation = {
+ val discoveryConf = prefix + resourceType + SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX
+ val script = sparkConf.getOption(discoveryConf)
+ val result = if (script.nonEmpty) {
+ val scriptFile = new File(script.get)
+ // check that script exists and try to execute
+ if (scriptFile.exists()) {
+ try {
+ val output = executeAndGetOutput(Seq(script.get), new File("."))
+ val parsedJson = parse(output)
+ val name = (parsedJson \ "name").extract[String]
+ val addresses = (parsedJson \ "addresses").extract[Array[String]].toArray
+ new ResourceInformation(name, addresses)
+ } catch {
+ case e @ (_: SparkException | _: MappingException | _: JsonParseException) =>
+ throw new SparkException(s"Error running the resource discovery script: $scriptFile" +
+ s" for $resourceType", e)
+ }
+ } else {
+ throw new SparkException(s"Resource script: $scriptFile to discover $resourceType" +
+ s" doesn't exist!")
+ }
+ } else {
+ throw new SparkException(s"User is expecting to use $resourceType resources but " +
+ s"didn't specify a script via conf: $discoveryConf, to find them!")
+ }
+ result
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/core/src/main/scala/org/apache/spark/ResourceInformation.scala
similarity index 54%
rename from sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
rename to core/src/main/scala/org/apache/spark/ResourceInformation.scala
index c44b8af2552f0..6a5b725ac21d7 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java
+++ b/core/src/main/scala/org/apache/spark/ResourceInformation.scala
@@ -15,15 +15,23 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution.streaming;
+package org.apache.spark
+
+import org.apache.spark.annotation.Evolving
/**
- * The shared interface between V1 streaming sources and V2 streaming readers.
+ * Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc.
+ * The array of addresses is resource specific and it's up to the user to interpret the addresses.
+ *
+ * One example is GPUs, where the addresses would be the indices of the GPUs.
*
- * This is a temporary interface for compatibility during migration. It should not be implemented
- * directly, and will be removed in future versions.
+ * @param name the name of the resource
+ * @param addresses an array of strings describing the addresses of the resource
*/
-public interface BaseStreamingSource {
- /** Stop this source and free any resources it has allocated. */
- void stop();
+@Evolving
+class ResourceInformation(
+ val name: String,
+ val addresses: Array[String]) extends Serializable {
+
+ override def toString: String = s"[name: ${name}, addresses: ${addresses.mkString(",")}]"
}
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 26b18564be778..5a5c5a403f202 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -214,6 +214,15 @@ private[spark] class SecurityManager(
*/
def aclsEnabled(): Boolean = aclsOn
+ /**
+ * Checks whether the given user is an admin. This gives the user both view and
+ * modify permissions, and also allows the user to impersonate other users when
+ * making UI requests.
+ */
+ def checkAdminPermissions(user: String): Boolean = {
+ isUserInACL(user, adminAcls, adminAclsGroups)
+ }
+
/**
* Checks the given user against the view acl and groups list to see if they have
* authorization to view the UI. If the UI acls are disabled
@@ -227,13 +236,7 @@ private[spark] class SecurityManager(
def checkUIViewPermissions(user: String): Boolean = {
logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" +
viewAcls.mkString(",") + " viewAclsGroups=" + viewAclsGroups.mkString(","))
- if (!aclsEnabled || user == null || viewAcls.contains(user) ||
- viewAcls.contains(WILDCARD_ACL) || viewAclsGroups.contains(WILDCARD_ACL)) {
- return true
- }
- val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user)
- logDebug("userGroups=" + currentUserGroups.mkString(","))
- viewAclsGroups.exists(currentUserGroups.contains(_))
+ isUserInACL(user, viewAcls, viewAclsGroups)
}
/**
@@ -249,13 +252,7 @@ private[spark] class SecurityManager(
def checkModifyPermissions(user: String): Boolean = {
logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" +
modifyAcls.mkString(",") + " modifyAclsGroups=" + modifyAclsGroups.mkString(","))
- if (!aclsEnabled || user == null || modifyAcls.contains(user) ||
- modifyAcls.contains(WILDCARD_ACL) || modifyAclsGroups.contains(WILDCARD_ACL)) {
- return true
- }
- val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user)
- logDebug("userGroups=" + currentUserGroups)
- modifyAclsGroups.exists(currentUserGroups.contains(_))
+ isUserInACL(user, modifyAcls, modifyAclsGroups)
}
/**
@@ -371,6 +368,23 @@ private[spark] class SecurityManager(
}
}
+ private def isUserInACL(
+ user: String,
+ aclUsers: Set[String],
+ aclGroups: Set[String]): Boolean = {
+ if (user == null ||
+ !aclsEnabled ||
+ aclUsers.contains(WILDCARD_ACL) ||
+ aclUsers.contains(user) ||
+ aclGroups.contains(WILDCARD_ACL)) {
+ true
+ } else {
+ val userGroups = Utils.getCurrentUserGroups(sparkConf, user)
+ logDebug(s"user $user is in groups ${userGroups.mkString(",")}")
+ aclGroups.exists(userGroups.contains(_))
+ }
+ }
+
// Default SecurityManager only has a single secret key, so ignore appId.
override def getSaslUser(appId: String): String = getSaslUser()
override def getSecretKey(appId: String): String = getSecretKey()
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 913a1704ad5ce..aa93f42141fc1 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -168,6 +168,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria
}
/** Set multiple parameters together */
+ def setAll(settings: Iterable[(String, String)]): SparkConf = {
+ settings.foreach { case (k, v) => set(k, v) }
+ this
+ }
+
+ /**
+ * Set multiple parameters together
+ */
+ @deprecated("Use setAll(Iterable) instead", "3.0.0")
def setAll(settings: Traversable[(String, String)]): SparkConf = {
settings.foreach { case (k, v) => set(k, v) }
this
@@ -705,7 +714,9 @@ private[spark] object SparkConf extends Logging {
AlternateConfig("spark.yarn.kerberos.relogin.period", "3.0")),
KERBEROS_FILESYSTEMS_TO_ACCESS.key -> Seq(
AlternateConfig("spark.yarn.access.namenodes", "2.2"),
- AlternateConfig("spark.yarn.access.hadoopFileSystems", "3.0"))
+ AlternateConfig("spark.yarn.access.hadoopFileSystems", "3.0")),
+ "spark.kafka.consumer.cache.capacity" -> Seq(
+ AlternateConfig("spark.sql.kafkaConsumerCache.capacity", "3.0"))
)
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b744356daaee..997941080fc6f 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -2556,7 +2556,7 @@ object SparkContext extends Logging {
private[spark] val DRIVER_IDENTIFIER = "driver"
- private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Traversable[T])
+ private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Iterable[T])
: ArrayWritable = {
def anyToWritable[U <: Writable](u: U): Writable = u
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index c2ebd388a2365..c97b10ee63b18 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -192,6 +192,20 @@ private[spark] object TestUtils {
assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did")
}
+ /**
+ * Asserts that the exception message contains the given message. Note that this checks all
+ * exceptions in the cause chain.
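+ * For example, `assertExceptionMsg(new SparkException("outer", new Exception("boom")), "boom")`
+ * succeeds because a cause in the chain contains "boom".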
+ */
+ def assertExceptionMsg(exception: Throwable, msg: String): Unit = {
+ var e = exception
+ var contains = e.getMessage.contains(msg)
+ while (e.getCause != null && !contains) {
+ e = e.getCause
+ contains = e.getMessage.contains(msg)
+ }
+ assert(contains, s"Exception tree doesn't contain the expected message: $msg")
+ }
+
/**
* Test if a command is available.
*/
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 2ab8add63efae..a4817b3cf770d 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -33,18 +33,17 @@ import org.apache.spark.util.{SerializableConfiguration, Utils}
* A trait for use with reading custom classes in PySpark. Implement this trait and add custom
* transformation code by overriding the convert method.
*/
-trait Converter[T, + U] extends Serializable {
+trait Converter[-T, +U] extends Serializable {
def convert(obj: T): U
}
private[python] object Converter extends Logging {
- def getInstance(converterClass: Option[String],
- defaultConverter: Converter[Any, Any]): Converter[Any, Any] = {
+ def getInstance[T, U](converterClass: Option[String],
+ defaultConverter: Converter[_ >: T, _ <: U]): Converter[T, U] = {
converterClass.map { cc =>
Try {
- val c = Utils.classForName(cc).getConstructor().
- newInstance().asInstanceOf[Converter[Any, Any]]
+ val c = Utils.classForName[Converter[T, U]](cc).getConstructor().newInstance()
logInfo(s"Loaded converter: $cc")
c
} match {
@@ -177,8 +176,8 @@ private[python] object PythonHadoopUtil {
* [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa.
*/
def convertRDD[K, V](rdd: RDD[(K, V)],
- keyConverter: Converter[Any, Any],
- valueConverter: Converter[Any, Any]): RDD[(Any, Any)] = {
+ keyConverter: Converter[K, Any],
+ valueConverter: Converter[V, Any]): RDD[(Any, Any)] = {
rdd.map { case (k, v) => (keyConverter.convert(k), valueConverter.convert(v)) }
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 5b492b1f3991e..fe25c3aac81b8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -24,10 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.concurrent.Promise
-import scala.concurrent.duration.Duration
-import scala.language.existentials
-import scala.util.Try
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -90,7 +86,7 @@ private[spark] case class PythonFunction(
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
/** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Exception)
+private[spark] class PythonException(msg: String, cause: Throwable)
extends RuntimeException(msg, cause)
/**
@@ -167,8 +163,63 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}
+ /**
+ * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+ * collected as separate jobs, by order of index. Partition data is first requested by a
+ * non-zero integer to start a collection job. The response is prefaced by an integer with 1
+ * meaning partition data will be served, 0 meaning the local iterator has been consumed,
+ * and -1 meaning an error occurred during collection. This function is used by
+ * pyspark.rdd._local_iterator_from_socket().
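+ * For instance, with two partitions: the client sends a non-zero int and receives 1 followed
+ * by the data of partition 0; a second non-zero request yields 1 and partition 1; the next
+ * request is answered with 0 and the connection is closed.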
+ *
+ * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+ * data collected from these jobs, and the secret for authentication.
+ */
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
- serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+ val (port, secret) = SocketAuthServer.setupOneConnectionServer(
+ authHelper, "serve toLocalIterator") { s =>
+ val out = new DataOutputStream(s.getOutputStream)
+ val in = new DataInputStream(s.getInputStream)
+ Utils.tryWithSafeFinally {
+
+ // Collects a partition on each iteration
+ val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
+ rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+ }
+
+ // Read request for data and send next partition if nonzero
+ var complete = false
+ while (!complete && in.readInt() != 0) {
+ if (collectPartitionIter.hasNext) {
+ try {
+ // Attempt to collect the next partition
+ val partitionArray = collectPartitionIter.next()
+
+ // Send response there is a partition to read
+ out.writeInt(1)
+
+ // Write the next object and signal end of data for this iteration
+ writeIteratorToStream(partitionArray.toIterator, out)
+ out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+ out.flush()
+ } catch {
+ case e: SparkException =>
+ // Send response that an error occurred followed by error message
+ out.writeInt(-1)
+ writeUTF(e.getMessage, out)
+ complete = true
+ }
+ } else {
+ // Send response there are no more partitions to read and close
+ out.writeInt(0)
+ complete = true
+ }
+ }
+ } {
+ out.close()
+ in.close()
+ }
+ }
+ Array(port, secret)
}
def readRDDFromFile(
@@ -228,8 +279,8 @@ private[spark] object PythonRDD extends Logging {
batchSize: Int): JavaRDD[Array[Byte]] = {
val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
- val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
- val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
+ val kc = Utils.classForName[K](keyClass)
+ val vc = Utils.classForName[V](valueClass)
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
@@ -296,9 +347,9 @@ private[spark] object PythonRDD extends Logging {
keyClass: String,
valueClass: String,
conf: Configuration): RDD[(K, V)] = {
- val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
- val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
- val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]]
+ val kc = Utils.classForName[K](keyClass)
+ val vc = Utils.classForName[V](valueClass)
+ val fc = Utils.classForName[F](inputFormatClass)
if (path.isDefined) {
sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf)
} else {
@@ -365,9 +416,9 @@ private[spark] object PythonRDD extends Logging {
keyClass: String,
valueClass: String,
conf: Configuration) = {
- val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]]
- val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]]
- val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]]
+ val kc = Utils.classForName[K](keyClass)
+ val vc = Utils.classForName[V](valueClass)
+ val fc = Utils.classForName[F](inputFormatClass)
if (path.isDefined) {
sc.sc.hadoopFile(path.get, fc, kc, vc)
} else {
@@ -425,29 +476,33 @@ private[spark] object PythonRDD extends Logging {
PythonHadoopUtil.mergeConfs(baseConf, conf)
}
- private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null,
- valueConverterClass: String = null): (Class[_], Class[_]) = {
+ private def inferKeyValueTypes[K, V, KK, VV](rdd: RDD[(K, V)], keyConverterClass: String = null,
+ valueConverterClass: String = null): (Class[_ <: KK], Class[_ <: VV]) = {
// Peek at an element to figure out key/value types. Since Writables are not serializable,
// we cannot call first() on the converted RDD. Instead, we call first() on the original RDD
// and then convert locally.
val (key, value) = rdd.first()
- val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass,
- new JavaToWritableConverter)
+ val (kc, vc) = getKeyValueConverters[K, V, KK, VV](
+ keyConverterClass, valueConverterClass, new JavaToWritableConverter)
(kc.convert(key).getClass, vc.convert(value).getClass)
}
- private def getKeyValueTypes(keyClass: String, valueClass: String):
- Option[(Class[_], Class[_])] = {
+ private def getKeyValueTypes[K, V](keyClass: String, valueClass: String):
+ Option[(Class[K], Class[V])] = {
for {
k <- Option(keyClass)
v <- Option(valueClass)
} yield (Utils.classForName(k), Utils.classForName(v))
}
- private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String,
- defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = {
- val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter)
- val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter)
+ private def getKeyValueConverters[K, V, KK, VV](
+ keyConverterClass: String,
+ valueConverterClass: String,
+ defaultConverter: Converter[_, _]): (Converter[K, KK], Converter[V, VV]) = {
+ val keyConverter = Converter.getInstance(Option(keyConverterClass),
+ defaultConverter.asInstanceOf[Converter[K, KK]])
+ val valueConverter = Converter.getInstance(Option(valueConverterClass),
+ defaultConverter.asInstanceOf[Converter[V, VV]])
(keyConverter, valueConverter)
}
@@ -459,7 +514,7 @@ private[spark] object PythonRDD extends Logging {
keyConverterClass: String,
valueConverterClass: String,
defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = {
- val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass,
+ val (kc, vc) = getKeyValueConverters[K, V, Any, Any](keyConverterClass, valueConverterClass,
defaultConverter)
PythonHadoopUtil.convertRDD(rdd, kc, vc)
}
@@ -470,7 +525,7 @@ private[spark] object PythonRDD extends Logging {
* [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java
* `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system.
*/
- def saveAsSequenceFile[K, V, C <: CompressionCodec](
+ def saveAsSequenceFile[C <: CompressionCodec](
pyRDD: JavaRDD[Array[Byte]],
batchSerialized: Boolean,
path: String,
@@ -489,7 +544,7 @@ private[spark] object PythonRDD extends Logging {
* `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of
* this RDD.
*/
- def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec](
+ def saveAsHadoopFile[F <: OutputFormat[_, _], C <: CompressionCodec](
pyRDD: JavaRDD[Array[Byte]],
batchSerialized: Boolean,
path: String,
@@ -507,7 +562,7 @@ private[spark] object PythonRDD extends Logging {
val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]])
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
new JavaToWritableConverter)
- val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]]
+ val fc = Utils.classForName[F](outputFormatClass)
converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec)
}
@@ -520,7 +575,7 @@ private[spark] object PythonRDD extends Logging {
* `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of
* this RDD.
*/
- def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]](
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]](
pyRDD: JavaRDD[Array[Byte]],
batchSerialized: Boolean,
path: String,
@@ -548,7 +603,7 @@ private[spark] object PythonRDD extends Logging {
* (mapred vs. mapreduce). Keys/values are converted for output using either user specified
* converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]].
*/
- def saveAsHadoopDataset[K, V](
+ def saveAsHadoopDataset(
pyRDD: JavaRDD[Array[Byte]],
batchSerialized: Boolean,
confAsMap: java.util.HashMap[String, String],
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index b7f14e062b437..dca87044513c2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
import org.apache.spark._
import org.apache.spark.internal.Logging
@@ -165,15 +166,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
context: TaskContext)
extends Thread(s"stdout writer for $pythonExec") {
- @volatile private var _exception: Exception = null
+ @volatile private var _exception: Throwable = null
private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
setDaemon(true)
- /** Contains the exception thrown while writing the parent iterator to the Python process. */
- def exception: Option[Exception] = Option(_exception)
+ /** Contains the throwable thrown while writing the parent iterator to the Python process. */
+ def exception: Option[Throwable] = Option(_exception)
/** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */
def shutdownOnTaskCompletion() {
@@ -347,18 +348,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
dataOut.writeInt(SpecialLengths.END_OF_STREAM)
dataOut.flush()
} catch {
- case e: Exception if context.isCompleted || context.isInterrupted =>
- logDebug("Exception thrown after task completion (likely due to cleanup)", e)
- if (!worker.isClosed) {
- Utils.tryLog(worker.shutdownOutput())
- }
-
- case e: Exception =>
- // We must avoid throwing exceptions here, because the thread uncaught exception handler
- // will kill the whole executor (see org.apache.spark.executor.Executor).
- _exception = e
- if (!worker.isClosed) {
- Utils.tryLog(worker.shutdownOutput())
+ case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+ if (context.isCompleted || context.isInterrupted) {
+ logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
+ "cleanup)", t)
+ if (!worker.isClosed) {
+ Utils.tryLog(worker.shutdownOutput())
+ }
+ } else {
+ // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
+ // exception handler will kill the whole executor (see
+ // org.apache.spark.executor.Executor).
+ _exception = t
+ if (!worker.isClosed) {
+ Utils.tryLog(worker.shutdownOutput())
+ }
}
}
}
@@ -511,7 +515,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
if (!context.isCompleted) {
try {
// Mimic the task name used in `Executor` to help the user find out the task to blame.
- val taskName = s"${context.partitionId}.${context.taskAttemptId} " +
+ val taskName = s"${context.partitionId}.${context.attemptNumber} " +
s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker")
env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index 01e64b6972ae2..9462dfd950bab 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -186,6 +186,9 @@ private[spark] object SerDeUtil extends Logging {
val unpickle = new Unpickler
iter.flatMap { row =>
val obj = unpickle.loads(row)
+ // `Opcodes.MEMOIZE` of Protocol 4 (Python 3.4+) will store objects in the internal map
+ // of `Unpickler`. This map is cleared when calling `Unpickler.close()`.
+ unpickle.close()
if (batched) {
obj match {
case array: Array[Any] => array.toSeq
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
index 36b4132088b58..c755dcba6bcea 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -30,7 +30,7 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
import io.netty.handler.timeout.ReadTimeoutHandler
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.R._
@@ -47,7 +47,7 @@ private[spark] class RBackend {
private[r] val jvmObjectTracker = new JVMObjectTracker
def init(): (Int, RAuthHelper) = {
- val conf = new SparkConf()
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT)
bossGroup = new NioEventLoopGroup(conf.get(R_NUM_BACKEND_THREADS))
val workerGroup = bossGroup
@@ -124,7 +124,7 @@ private[spark] object RBackend extends Logging {
val listenPort = serverSocket.getLocalPort()
// Connection timeout is set by socket client. To make it configurable we will pass the
// timeout value to client inside the temp file
- val conf = new SparkConf()
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT)
// tell the R process via temporary file
diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
index 7b74efa41044f..f2f81b11fc813 100644
--- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
+++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -20,13 +20,11 @@ package org.apache.spark.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.TimeUnit
-import scala.language.existentials
-
import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.ChannelHandler.Sharable
import io.netty.handler.timeout.ReadTimeoutException
-import org.apache.spark.SparkConf
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
import org.apache.spark.internal.config.R._
@@ -98,7 +96,7 @@ private[r] class RBackendHandler(server: RBackend)
ctx.write(pingBaos.toByteArray)
}
}
- val conf = new SparkConf()
+ val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf())
val heartBeatInterval = conf.get(R_HEARTBEAT_INTERVAL)
val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT)
val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index a66243012041c..99f841234005e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -26,7 +26,6 @@ import scala.collection.mutable.ListBuffer
import scala.concurrent.{Future, Promise}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
-import scala.language.postfixOps
import scala.sys.process._
import org.json4s._
@@ -112,7 +111,7 @@ private object FaultToleranceTest extends App with Logging {
assertValidClusterState()
killLeader()
- delay(30 seconds)
+ delay(30.seconds)
assertValidClusterState()
createClient()
assertValidClusterState()
@@ -126,12 +125,12 @@ private object FaultToleranceTest extends App with Logging {
killLeader()
addMasters(1)
- delay(30 seconds)
+ delay(30.seconds)
assertValidClusterState()
killLeader()
addMasters(1)
- delay(30 seconds)
+ delay(30.seconds)
assertValidClusterState()
}
@@ -156,7 +155,7 @@ private object FaultToleranceTest extends App with Logging {
killLeader()
workers.foreach(_.kill())
workers.clear()
- delay(30 seconds)
+ delay(30.seconds)
addWorkers(2)
assertValidClusterState()
}
@@ -174,7 +173,7 @@ private object FaultToleranceTest extends App with Logging {
(1 to 3).foreach { _ =>
killLeader()
- delay(30 seconds)
+ delay(30.seconds)
assertValidClusterState()
assertTrue(getLeader == masters.head)
addMasters(1)
@@ -264,7 +263,7 @@ private object FaultToleranceTest extends App with Logging {
}
// Avoid waiting indefinitely (e.g., we could register but get no executors).
- assertTrue(ThreadUtils.awaitResult(f, 120 seconds))
+ assertTrue(ThreadUtils.awaitResult(f, 2.minutes))
}
/**
@@ -317,7 +316,7 @@ private object FaultToleranceTest extends App with Logging {
}
try {
- assertTrue(ThreadUtils.awaitResult(f, 120 seconds))
+ assertTrue(ThreadUtils.awaitResult(f, 2.minutes))
} catch {
case e: TimeoutException =>
logError("Master states: " + masters.map(_.state))
@@ -421,7 +420,7 @@ private object SparkDocker {
}
dockerCmd.run(ProcessLogger(findIpAndLog _))
- val ip = ThreadUtils.awaitResult(ipPromise.future, 30 seconds)
+ val ip = ThreadUtils.awaitResult(ipPromise.future, 30.seconds)
val dockerId = Docker.getLastProcessId
(ip, dockerId, outFile)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 9efaaa765b5d1..49d939539719d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -544,10 +544,14 @@ private[spark] class SparkSubmit extends Logging {
// Yarn only
OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"),
- OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"),
- OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"),
- OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"),
- OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"),
+ OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles",
+ mergeFn = Some(mergeFileLists(_, _))),
+ OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars",
+ mergeFn = Some(mergeFileLists(_, _))),
+ OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files",
+ mergeFn = Some(mergeFileLists(_, _))),
+ OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives",
+ mergeFn = Some(mergeFileLists(_, _))),
// Other options
OptionAssigner(args.numExecutors, YARN | KUBERNETES, ALL_DEPLOY_MODES,
@@ -608,7 +612,13 @@ private[spark] class SparkSubmit extends Logging {
(deployMode & opt.deployMode) != 0 &&
(clusterManager & opt.clusterManager) != 0) {
if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) }
- if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) }
+ if (opt.confKey != null) {
+ if (opt.mergeFn.isDefined && sparkConf.contains(opt.confKey)) {
+ sparkConf.set(opt.confKey, opt.mergeFn.get.apply(sparkConf.get(opt.confKey), opt.value))
+ } else {
+ sparkConf.set(opt.confKey, opt.value)
+ }
+ }
}
}
@@ -1381,7 +1391,8 @@ private case class OptionAssigner(
clusterManager: Int,
deployMode: Int,
clOption: String = null,
- confKey: String = null)
+ confKey: String = null,
+ mergeFn: Option[(String, String) => String] = None)
private[spark] trait SparkSubmitOperation {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index 7c9ce14c652c4..7df36c5aeba07 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -34,7 +34,6 @@ import org.apache.spark.internal.config.History
import org.apache.spark.internal.config.UI._
import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot}
import org.apache.spark.ui.{SparkUI, UIUtils, WebUI}
-import org.apache.spark.ui.JettyUtils._
import org.apache.spark.util.{ShutdownHookManager, SystemClock, Utils}
/**
@@ -274,10 +273,9 @@ object HistoryServer extends Logging {
val providerName = conf.get(History.PROVIDER)
.getOrElse(classOf[FsHistoryProvider].getName())
- val provider = Utils.classForName(providerName)
+ val provider = Utils.classForName[ApplicationHistoryProvider](providerName)
.getConstructor(classOf[SparkConf])
.newInstance(conf)
- .asInstanceOf[ApplicationHistoryProvider]
val port = conf.get(History.HISTORY_SERVER_UI_PORT)
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 9d2301cea9b33..6f1484cee586e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -78,8 +78,8 @@ private[deploy] class ExecutorRunner(
// Shutdown hook that kills actors on shutdown.
shutdownHook = ShutdownHookManager.addShutdownHook { () =>
// It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will
- // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`.
- if (state == ExecutorState.RUNNING) {
+ // be `ExecutorState.LAUNCHING`. In this case, we should set `state` to `FAILED`.
+ if (state == ExecutorState.LAUNCHING) {
state = ExecutorState.FAILED
}
killProcess(Some("Worker shutting down")) }
@@ -183,6 +183,8 @@ private[deploy] class ExecutorRunner(
Files.write(header, stderr, StandardCharsets.UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
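+ // The executor process has launched and its logs are being captured, so transition
+ // from LAUNCHING to RUNNING and report the new state to the worker.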
+ state = ExecutorState.RUNNING
+ worker.send(ExecutorStateChanged(appId, execId, state, None, None))
// Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
// or with nonzero exit code
val exitCode = process.waitFor()
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index eb2add3af8251..f8ec5b6b190c1 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -540,12 +540,12 @@ private[deploy] class Worker(
executorDir,
workerUri,
conf,
- appLocalDirs, ExecutorState.RUNNING)
+ appLocalDirs,
+ ExecutorState.LAUNCHING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
memoryUsed += memory_
- sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None))
} catch {
case e: Exception =>
logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e)
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 645f58716de63..af01e0b23dada 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -17,6 +17,7 @@
package org.apache.spark.executor
+import java.io.{BufferedInputStream, FileInputStream}
import java.net.URL
import java.nio.ByteBuffer
import java.util.Locale
@@ -26,11 +27,18 @@ import scala.collection.mutable
import scala.util.{Failure, Success}
import scala.util.control.NonFatal
+import com.fasterxml.jackson.databind.exc.MismatchedInputException
+import org.json4s.DefaultFormats
+import org.json4s.JsonAST.JArray
+import org.json4s.MappingException
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.deploy.worker.WorkerWatcher
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
import org.apache.spark.rpc._
import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
@@ -44,9 +52,12 @@ private[spark] class CoarseGrainedExecutorBackend(
hostname: String,
cores: Int,
userClassPath: Seq[URL],
- env: SparkEnv)
+ env: SparkEnv,
+ resourcesFile: Option[String])
extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging {
+ private implicit val formats = DefaultFormats
+
private[this] val stopping = new AtomicBoolean(false)
var executor: Executor = null
@volatile var driver: Option[RpcEndpointRef] = None
@@ -57,11 +68,12 @@ private[spark] class CoarseGrainedExecutorBackend(
override def onStart() {
logInfo("Connecting to driver: " + driverUrl)
+ val resources = parseOrFindResources(resourcesFile)
rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
// This is a very fast action so we can use "ThreadUtils.sameThread"
driver = Some(ref)
ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls,
- extractAttributes))
+ extractAttributes, resources))
}(ThreadUtils.sameThread).onComplete {
// This is a very fast action so we can use "ThreadUtils.sameThread"
case Success(msg) =>
@@ -71,6 +83,97 @@ private[spark] class CoarseGrainedExecutorBackend(
}(ThreadUtils.sameThread)
}
+ // Check that the actual resources discovered will satisfy the user specified
+ // requirements and that they match the configs specified by the user to catch
+ // mismatches between what the user requested and what the resource manager gave or
+ // what the discovery script found.
+ private def checkResourcesMeetRequirements(
+ resourceConfigPrefix: String,
+ reqResourcesAndCounts: Array[(String, String)],
+ actualResources: Map[String, ResourceInformation]): Unit = {
+
+ reqResourcesAndCounts.foreach { case (rName, reqCount) =>
+ if (actualResources.contains(rName)) {
+ val resourceInfo = actualResources(rName)
+
+ if (resourceInfo.addresses.size < reqCount.toLong) {
+ throw new SparkException(s"Resource: $rName with addresses: " +
+ s"${resourceInfo.addresses.mkString(",")} doesn't meet the " +
+ s"requirements of needing $reqCount of them")
+ }
+ // also make sure the resource count on start matches the
+ // resource configs specified by user
+ val userCountConfigName =
+ resourceConfigPrefix + rName + SPARK_RESOURCE_COUNT_POSTFIX
+ val userConfigCount = env.conf.getOption(userCountConfigName).
+ getOrElse(throw new SparkException(s"Resource: $rName not specified " +
+ s"via config: $userCountConfigName, but required, " +
+ "please fix your configuration"))
+
+ if (userConfigCount.toLong > resourceInfo.addresses.size) {
+ throw new SparkException(s"Resource: $rName, with addresses: " +
+ s"${resourceInfo.addresses.mkString(",")} " +
+ s"is less than what the user requested for count: $userConfigCount, " +
+ s"via $userCountConfigName")
+ }
+ } else {
+ throw new SparkException(s"Executor resource config missing required task resource: $rName")
+ }
+ }
+ }
+
+ // visible for testing
+ def parseOrFindResources(resourcesFile: Option[String]): Map[String, ResourceInformation] = {
+ // only parse the resources if a task requires them
+ val taskResourceConfigs = env.conf.getAllWithPrefix(SPARK_TASK_RESOURCE_PREFIX)
+ val resourceInfo = if (taskResourceConfigs.nonEmpty) {
+ val execResources = resourcesFile.map { resourceFileStr => {
+ val source = new BufferedInputStream(new FileInputStream(resourceFileStr))
+ val resourceMap = try {
+ val parsedJson = parse(source).asInstanceOf[JArray].arr
+ parsedJson.map { json =>
+ val name = (json \ "name").extract[String]
+ val addresses = (json \ "addresses").extract[Array[String]]
+ new ResourceInformation(name, addresses)
+ }.map(x => (x.name -> x)).toMap
+ } catch {
+ case e @ (_: MappingException | _: MismatchedInputException) =>
+ throw new SparkException(
+ s"Exception parsing the resources in $resourceFileStr", e)
+ } finally {
+ source.close()
+ }
+ resourceMap
+ }}.getOrElse(ResourceDiscoverer.findResources(env.conf, isDriver = false))
+
+ if (execResources.isEmpty) {
+ throw new SparkException("User specified resources per task via: " +
+ s"$SPARK_TASK_RESOURCE_PREFIX, but can't find any resources available on the executor.")
+ }
+ // get just the map of resource name to count
+ val resourcesAndCounts = taskResourceConfigs.
+ withFilter { case (k, v) => k.endsWith(SPARK_RESOURCE_COUNT_POSTFIX)}.
+ map { case (k, v) => (k.dropRight(SPARK_RESOURCE_COUNT_POSTFIX.size), v)}
+
+ checkResourcesMeetRequirements(SPARK_EXECUTOR_RESOURCE_PREFIX, resourcesAndCounts,
+ execResources)
+
+ logInfo("===============================================================================")
+ logInfo(s"Executor $executorId Resources:")
+ execResources.foreach { case (k, v) => logInfo(s"$k -> $v") }
+ logInfo("===============================================================================")
+
+ execResources
+ } else {
+ if (resourcesFile.nonEmpty) {
+ logWarning(s"A resources file was specified but the application is not configured " +
+ s"to use any resources, see the configs with prefix: ${SPARK_TASK_RESOURCE_PREFIX}")
+ }
+ Map.empty[String, ResourceInformation]
+ }
+ resourceInfo
+ }
+
def extractLogUrls: Map[String, String] = {
val prefix = "SPARK_LOG_URL_"
sys.env.filterKeys(_.startsWith(prefix))
@@ -188,13 +291,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
cores: Int,
appId: String,
workerUrl: Option[String],
- userClassPath: mutable.ListBuffer[URL])
+ userClassPath: mutable.ListBuffer[URL],
+ resourcesFile: Option[String])
def main(args: Array[String]): Unit = {
val createFn: (RpcEnv, Arguments, SparkEnv) =>
CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) =>
new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId,
- arguments.hostname, arguments.cores, arguments.userClassPath, env)
+ arguments.hostname, arguments.cores, arguments.userClassPath, env, arguments.resourcesFile)
}
run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn)
System.exit(0)
@@ -239,6 +343,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf)
}
+ driverConf.set(EXECUTOR_ID, arguments.executorId)
val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.hostname,
arguments.cores, cfg.ioEncryptionKey, isLocal = false)
@@ -255,6 +360,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
var executorId: String = null
var hostname: String = null
var cores: Int = 0
+ var resourcesFile: Option[String] = None
var appId: String = null
var workerUrl: Option[String] = None
val userClassPath = new mutable.ListBuffer[URL]()
@@ -274,6 +380,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
case ("--cores") :: value :: tail =>
cores = value.toInt
argv = tail
+ case ("--resourcesFile") :: value :: tail =>
+ resourcesFile = Some(value)
+ argv = tail
case ("--app-id") :: value :: tail =>
appId = value
argv = tail
@@ -299,7 +408,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
Arguments(driverUrl, executorId, hostname, cores, appId, workerUrl,
- userClassPath)
+ userClassPath, resourcesFile)
}
private def printUsageAndExit(classNameForEntry: String): Unit = {
@@ -313,6 +422,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
| --executor-id
| --hostname
| --cores
+ | --resourcesFile
| --app-id
| --worker-url
| --user-class-path
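
For reference, parseOrFindResources above expects the file passed via --resourcesFile to be a JSON array of objects with "name" and "addresses" fields. Below is a minimal, self-contained sketch of that format and of the json4s calls the method relies on; the "gpu" resource and its addresses are purely illustrative.

import org.json4s.DefaultFormats
import org.json4s.JsonAST.JArray
import org.json4s.jackson.JsonMethods._

object ResourcesFileSketch {
  def main(args: Array[String]): Unit = {
    implicit val formats = DefaultFormats
    // Illustrative contents of a file passed via --resourcesFile: a JSON array of
    // entries carrying a resource "name" and its "addresses".
    val json = """[{"name": "gpu", "addresses": ["0", "1"]}]"""
    val resources = parse(json).asInstanceOf[JArray].arr.map { entry =>
      val name = (entry \ "name").extract[String]
      val addresses = (entry \ "addresses").extract[Array[String]]
      name -> addresses
    }.toMap
    resources.foreach { case (name, addrs) => println(s"$name -> ${addrs.mkString(",")}") }
  }
}
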
diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala
deleted file mode 100644
index e91ddd3e9741a..0000000000000
--- a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala
+++ /dev/null
@@ -1,91 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.internal.config
-
-private[spark] object Kafka {
-
- val BOOTSTRAP_SERVERS =
- ConfigBuilder("spark.kafka.bootstrap.servers")
- .doc("A list of coma separated host/port pairs to use for establishing the initial " +
- "connection to the Kafka cluster. For further details please see kafka documentation. " +
- "Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val SECURITY_PROTOCOL =
- ConfigBuilder("spark.kafka.security.protocol")
- .doc("Protocol used to communicate with brokers. For further details please see kafka " +
- "documentation. Only used to obtain delegation token.")
- .stringConf
- .createWithDefault("SASL_SSL")
-
- val KERBEROS_SERVICE_NAME =
- ConfigBuilder("spark.kafka.sasl.kerberos.service.name")
- .doc("The Kerberos principal name that Kafka runs as. This can be defined either in " +
- "Kafka's JAAS config or in Kafka's config. For further details please see kafka " +
- "documentation. Only used to obtain delegation token.")
- .stringConf
- .createWithDefault("kafka")
-
- val TRUSTSTORE_LOCATION =
- ConfigBuilder("spark.kafka.ssl.truststore.location")
- .doc("The location of the trust store file. For further details please see kafka " +
- "documentation. Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val TRUSTSTORE_PASSWORD =
- ConfigBuilder("spark.kafka.ssl.truststore.password")
- .doc("The store password for the trust store file. This is optional for client and only " +
- "needed if ssl.truststore.location is configured. For further details please see kafka " +
- "documentation. Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val KEYSTORE_LOCATION =
- ConfigBuilder("spark.kafka.ssl.keystore.location")
- .doc("The location of the key store file. This is optional for client and can be used for " +
- "two-way authentication for client. For further details please see kafka documentation. " +
- "Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val KEYSTORE_PASSWORD =
- ConfigBuilder("spark.kafka.ssl.keystore.password")
- .doc("The store password for the key store file. This is optional for client and only " +
- "needed if ssl.keystore.location is configured. For further details please see kafka " +
- "documentation. Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val KEY_PASSWORD =
- ConfigBuilder("spark.kafka.ssl.key.password")
- .doc("The password of the private key in the key store file. This is optional for client. " +
- "For further details please see kafka documentation. Only used to obtain delegation token.")
- .stringConf
- .createOptional
-
- val TOKEN_SASL_MECHANISM =
- ConfigBuilder("spark.kafka.sasl.token.mechanism")
- .doc("SASL mechanism used for client connections with delegation token. Because SCRAM " +
- "login module used for authentication a compatible mechanism has to be set here. " +
- "For further details please see kafka documentation (sasl.mechanism). Only used to " +
- "authenticate against Kafka broker with delegation token.")
- .stringConf
- .createWithDefault("SCRAM-SHA-512")
-}
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 0bd46bef35d27..0aed1af023f83 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -30,6 +30,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.MAX_
package object config {
+ private[spark] val SPARK_DRIVER_RESOURCE_PREFIX = "spark.driver.resource."
+ private[spark] val SPARK_EXECUTOR_RESOURCE_PREFIX = "spark.executor.resource."
+ private[spark] val SPARK_TASK_RESOURCE_PREFIX = "spark.task.resource."
+
+ private[spark] val SPARK_RESOURCE_COUNT_POSTFIX = ".count"
+ private[spark] val SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX = ".discoveryScript"
+
private[spark] val DRIVER_CLASS_PATH =
ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional
@@ -1303,4 +1310,10 @@ package object config {
.doc("Staging directory used while submitting applications.")
.stringConf
.createOptional
+
+ private[spark] val BUFFER_PAGESIZE = ConfigBuilder("spark.buffer.pageSize")
+ .doc("The amount of memory used per page in bytes")
+ .bytesConf(ByteUnit.BYTE)
+ .createOptional
+
}
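
The new prefixes and postfixes compose into per-resource keys: the executor backend above builds the count key by concatenating SPARK_EXECUTOR_RESOURCE_PREFIX, the resource name, and SPARK_RESOURCE_COUNT_POSTFIX, and reads task requirements with getAllWithPrefix. A minimal sketch of how such keys might be assembled on a bare SparkConf; the "gpu" name and the script path are hypothetical.

import org.apache.spark.SparkConf

object ResourceConfSketch {
  // Mirrors the constants added above.
  val SPARK_EXECUTOR_RESOURCE_PREFIX = "spark.executor.resource."
  val SPARK_TASK_RESOURCE_PREFIX = "spark.task.resource."
  val SPARK_RESOURCE_COUNT_POSTFIX = ".count"
  val SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX = ".discoveryScript"

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      // Hypothetical resource name and script path, for illustration only.
      .set(SPARK_EXECUTOR_RESOURCE_PREFIX + "gpu" + SPARK_RESOURCE_COUNT_POSTFIX, "2")
      .set(SPARK_EXECUTOR_RESOURCE_PREFIX + "gpu" + SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX,
        "/opt/spark/findGpus.sh")
      .set(SPARK_TASK_RESOURCE_PREFIX + "gpu" + SPARK_RESOURCE_COUNT_POSTFIX, "1")
    // getAllWithPrefix is the lookup used by parseOrFindResources above.
    conf.getAllWithPrefix(SPARK_TASK_RESOURCE_PREFIX).foreach(println)
  }
}
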
diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
index e6e9c9e328853..854093851f5d0 100644
--- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala
@@ -149,7 +149,7 @@ object FileCommitProtocol extends Logging {
logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" +
s" dynamic=$dynamicPartitionOverwrite")
- val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]]
+ val clazz = Utils.classForName[FileCommitProtocol](className)
// First try the constructor with arguments (jobId: String, outputPath: String,
// dynamicPartitionOverwrite: Boolean).
// If that doesn't exist, try the one with (jobId: string, outputPath: String).
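
This is the first of several call sites in this patch that switch from casting the result of Utils.classForName to passing a type parameter, so the unchecked cast lives in one place. A simplified stand-in (not the actual Spark implementation) showing why the call sites no longer need asInstanceOf; the Greeter classes are hypothetical.

trait Greeter { def greet(): String }
class HelloGreeter extends Greeter { def greet(): String = "hello" }

object ClassForNameSketch {
  // The single unchecked cast: callers receive a properly typed Class[T].
  def classForName[T](className: String): Class[T] =
    Class.forName(className).asInstanceOf[Class[T]]

  def main(args: Array[String]): Unit = {
    // HelloGreeter is used only for this sketch; the instance comes back typed as Greeter.
    val greeter: Greeter = classForName[Greeter]("HelloGreeter").getConstructor().newInstance()
    println(greeter.greet())
  }
}
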
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 288c0d18191c3..065f05e87cf34 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -77,8 +77,9 @@ private[spark] object CompressionCodec {
val codecClass =
shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName)
val codec = try {
- val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf])
- Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec])
+ val ctor =
+ Utils.classForName[CompressionCodec](codecClass).getConstructor(classOf[SparkConf])
+ Some(ctor.newInstance(conf))
} catch {
case _: ClassNotFoundException | _: IllegalArgumentException => None
}
diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
index ff6d84b57ebcb..c08b47f99dda3 100644
--- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala
@@ -255,7 +255,7 @@ private[spark] abstract class MemoryManager(
}
val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor)
val default = math.min(maxPageSize, math.max(minPageSize, size))
- conf.getSizeAsBytes("spark.buffer.pageSize", default)
+ conf.get(BUFFER_PAGESIZE).getOrElse(default)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 8dad42b6096a9..c96640a6fab3f 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -179,8 +179,8 @@ private[spark] class MetricsSystem private (
sourceConfigs.foreach { kv =>
val classPath = kv._2.getProperty("class")
try {
- val source = Utils.classForName(classPath).getConstructor().newInstance()
- registerSource(source.asInstanceOf[Source])
+ val source = Utils.classForName[Source](classPath).getConstructor().newInstance()
+ registerSource(source)
} catch {
case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e)
}
@@ -195,13 +195,18 @@ private[spark] class MetricsSystem private (
val classPath = kv._2.getProperty("class")
if (null != classPath) {
try {
- val sink = Utils.classForName(classPath)
- .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
- .newInstance(kv._2, registry, securityMgr)
if (kv._1 == "servlet") {
- metricsServlet = Some(sink.asInstanceOf[MetricsServlet])
+ val servlet = Utils.classForName[MetricsServlet](classPath)
+ .getConstructor(
+ classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
+ metricsServlet = Some(servlet)
} else {
- sinks += sink.asInstanceOf[Sink]
+ val sink = Utils.classForName[Sink](classPath)
+ .getConstructor(
+ classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager])
+ .newInstance(kv._2, registry, securityMgr)
+ sinks += sink
}
} catch {
case e: Exception =>
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 27f4f94ea55f8..55e7109f935c1 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -20,7 +20,6 @@ package org.apache.spark.network.netty
import java.nio.ByteBuffer
import scala.collection.JavaConverters._
-import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark.internal.Logging
@@ -66,12 +65,7 @@ class NettyBlockRpcServer(
case uploadBlock: UploadBlock =>
// StorageLevel and ClassTag are serialized as bytes using our JavaSerializer.
- val (level: StorageLevel, classTag: ClassTag[_]) = {
- serializer
- .newInstance()
- .deserialize(ByteBuffer.wrap(uploadBlock.metadata))
- .asInstanceOf[(StorageLevel, ClassTag[_])]
- }
+ val (level, classTag) = deserializeMetadata(uploadBlock.metadata)
val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
val blockId = BlockId(uploadBlock.blockId)
logDebug(s"Receiving replicated block $blockId with level ${level} " +
@@ -87,12 +81,7 @@ class NettyBlockRpcServer(
responseContext: RpcResponseCallback): StreamCallbackWithID = {
val message =
BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream]
- val (level: StorageLevel, classTag: ClassTag[_]) = {
- serializer
- .newInstance()
- .deserialize(ByteBuffer.wrap(message.metadata))
- .asInstanceOf[(StorageLevel, ClassTag[_])]
- }
+ val (level, classTag) = deserializeMetadata(message.metadata)
val blockId = BlockId(message.blockId)
logDebug(s"Receiving replicated block $blockId with level ${level} as stream " +
s"from ${client.getSocketAddress}")
@@ -101,5 +90,12 @@ class NettyBlockRpcServer(
blockManager.putBlockDataAsStream(blockId, level, classTag)
}
+ private def deserializeMetadata[T](metadata: Array[Byte]): (StorageLevel, ClassTag[T]) = {
+ serializer
+ .newInstance()
+ .deserialize(ByteBuffer.wrap(metadata))
+ .asInstanceOf[(StorageLevel, ClassTag[T])]
+ }
+
override def getStreamManager(): StreamManager = streamManager
}
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
index 3ba0a0a750f97..c9103045260f2 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -36,16 +36,26 @@ object SparkTransportConf {
* @param numUsableCores if nonzero, this will restrict the server and client threads to only
* use the given number of cores, rather than all of the machine's cores.
* This restriction will only occur if these properties are not already set.
+ * @param role optional role; it can be driver, executor, worker or master. Defaults to
+ * [[None]], meaning no role-specific configuration is applied.
*/
- def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = {
+ def fromSparkConf(
+ _conf: SparkConf,
+ module: String,
+ numUsableCores: Int = 0,
+ role: Option[String] = None): TransportConf = {
val conf = _conf.clone
-
- // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
- // assuming we have all the machine's cores).
- // NB: Only set if serverThreads/clientThreads not already set.
+ // specify default thread configuration based on our JVM's allocation of cores (rather than
+ // necessarily assuming we have all the machine's cores).
val numThreads = NettyUtils.defaultNumThreads(numUsableCores)
- conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString)
- conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString)
+ // override thread configurations with role-specific values if specified;
+ // config precedence is role > module > default
+ Seq("serverThreads", "clientThreads").foreach { suffix =>
+ val value = role.flatMap { r => conf.getOption(s"spark.$r.$module.io.$suffix") }
+ .getOrElse(
+ conf.get(s"spark.$module.io.$suffix", numThreads.toString))
+ conf.set(s"spark.$module.io.$suffix", value)
+ }
new TransportConf(module, new ConfigProvider {
override def get(name: String): String = conf.get(name)
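
With the role parameter, a role-qualified key such as spark.driver.rpc.io.serverThreads takes precedence over spark.rpc.io.serverThreads, which in turn overrides the core-count default. A minimal sketch of that three-level lookup on a plain SparkConf, with illustrative values.

import org.apache.spark.SparkConf

object RoleConfSketch {
  // Resolution order implemented above: role-specific key > module key > computed default.
  def resolveThreads(conf: SparkConf, module: String, role: Option[String],
      suffix: String, default: Int): String = {
    role.flatMap(r => conf.getOption(s"spark.$r.$module.io.$suffix"))
      .getOrElse(conf.get(s"spark.$module.io.$suffix", default.toString))
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.rpc.io.serverThreads", "8")         // module-level setting
      .set("spark.driver.rpc.io.serverThreads", "16") // role-specific override (illustrative)
    println(resolveThreads(conf, "rpc", Some("driver"), "serverThreads", default = 4))   // 16
    println(resolveThreads(conf, "rpc", Some("executor"), "serverThreads", default = 4)) // 8
  }
}
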
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index 7e76731f5e454..909f58512153b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -20,7 +20,6 @@ package org.apache.spark.rdd
import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 44bc40b5babbd..836d3e231269d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -21,7 +21,6 @@ import java.io.{IOException, ObjectOutputStream}
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.language.existentials
import scala.reflect.ClassTag
import org.apache.spark._
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
index 117f51c5b8f2a..f6b20593462cd 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala
@@ -24,7 +24,7 @@ package org.apache.spark.rpc
private[spark] trait RpcCallContext {
/**
- * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]]
+ * Reply a message to the sender. If the sender is [[RpcEndpoint]], its `RpcEndpoint.receive`
* will be called.
*/
def reply(response: Any): Unit
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
index 3dc41f7f12798..770ae2f1dd22f 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala
@@ -52,7 +52,7 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S
*
* @note This can be used in the recover callback of a Future to add to a TimeoutException
* Example:
- * val timeout = new RpcTimeout(5 millis, "short timeout")
+ * val timeout = new RpcTimeout(5.milliseconds, "short timeout")
* Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout)
*/
def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = {
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
index ce238a256cfb4..2f923d7902b05 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala
@@ -24,8 +24,9 @@ import scala.collection.JavaConverters._
import scala.concurrent.Promise
import scala.util.control.NonFatal
-import org.apache.spark.SparkException
+import org.apache.spark.{SparkConf, SparkContext, SparkException}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS
import org.apache.spark.network.client.RpcResponseCallback
import org.apache.spark.rpc._
@@ -194,12 +195,22 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
endpoints.containsKey(name)
}
- /** Thread pool used for dispatching messages. */
- private val threadpool: ThreadPoolExecutor = {
+ private def getNumOfThreads(conf: SparkConf): Int = {
val availableCores =
if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
- val numThreads = nettyEnv.conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
+
+ val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS)
.getOrElse(math.max(2, availableCores))
+
+ conf.get(EXECUTOR_ID).map { id =>
+ val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
+ conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
+ }.getOrElse(modNumThreads)
+ }
+
+ /** Thread pool used for dispatching messages. */
+ private val threadpool: ThreadPoolExecutor = {
+ val numThreads = getNumOfThreads(nettyEnv.conf)
val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop")
for (i <- 0 until numThreads) {
pool.execute(new MessageLoop)
@@ -225,7 +236,15 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte
}
}
} catch {
- case ie: InterruptedException => // exit
+ case _: InterruptedException => // exit
+ case t: Throwable =>
+ try {
+ // Re-submit a MessageLoop so that Dispatcher will still work if
+ // UncaughtExceptionHandler decides to not kill JVM.
+ threadpool.execute(new MessageLoop)
+ } finally {
+ throw t
+ }
}
}
}
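
The new catch-all branch above keeps the dispatcher alive by submitting a replacement MessageLoop before rethrowing. A small self-contained sketch of that resubmit-on-fatal-error pattern, independent of the actual Dispatcher internals.

import java.util.concurrent.Executors

object ResubmitLoopSketch {
  private val pool = Executors.newFixedThreadPool(2)

  // On a non-interrupt Throwable, submit a replacement loop before rethrowing, so the
  // pool keeps working if the uncaught-exception handler decides not to kill the JVM.
  class Loop extends Runnable {
    override def run(): Unit = {
      try {
        while (true) {
          // ... take a message from a queue and process it ...
          Thread.sleep(100)
        }
      } catch {
        case _: InterruptedException => // exit quietly
        case t: Throwable =>
          try {
            pool.execute(new Loop) // keep the same number of live loops
          } finally {
            throw t
          }
      }
    }
  }

  def main(args: Array[String]): Unit = {
    (1 to 2).foreach(_ => pool.execute(new Loop))
    Thread.sleep(300)
    pool.shutdownNow() // interrupt the loops so the sketch terminates
  }
}
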
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
index d32eba64e13e9..44d2622a42f58 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala
@@ -106,7 +106,7 @@ private[netty] class Inbox(
throw new SparkException(s"Unsupported message $message from ${_sender}")
})
} catch {
- case NonFatal(e) =>
+ case e: Throwable =>
context.sendFailure(e)
// Throw the exception -- this exception will be caught by the safelyCall function.
// The endpoint's onError function will be called.
diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
index 472db45490e95..5dce43b7523d9 100644
--- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
@@ -29,8 +29,9 @@ import scala.reflect.ClassTag
import scala.util.{DynamicVariable, Failure, Success, Try}
import scala.util.control.NonFatal
-import org.apache.spark.{SecurityManager, SparkConf}
+import org.apache.spark.{SecurityManager, SparkConf, SparkContext}
import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.EXECUTOR_ID
import org.apache.spark.internal.config.Network._
import org.apache.spark.network.TransportContext
import org.apache.spark.network.client._
@@ -47,11 +48,15 @@ private[netty] class NettyRpcEnv(
host: String,
securityManager: SecurityManager,
numUsableCores: Int) extends RpcEnv(conf) with Logging {
+ val role = conf.get(EXECUTOR_ID).map { id =>
+ if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
+ }
private[netty] val transportConf = SparkTransportConf.fromSparkConf(
conf.clone.set(RPC_IO_NUM_CONNECTIONS_PER_PEER, 1),
"rpc",
- conf.get(RPC_IO_THREADS).getOrElse(numUsableCores))
+ conf.get(RPC_IO_THREADS).getOrElse(numUsableCores),
+ role)
private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index d967d38c52631..de57807639f3b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -24,10 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.annotation.tailrec
import scala.collection.Map
-import scala.collection.mutable.{ArrayStack, HashMap, HashSet}
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import scala.concurrent.duration._
-import scala.language.existentials
-import scala.language.postfixOps
import scala.util.control.NonFatal
import org.apache.commons.lang3.SerializationUtils
@@ -270,7 +268,7 @@ private[spark] class DAGScheduler(
listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates,
Some(executorUpdates)))
blockManagerMaster.driverEndpoint.askSync[Boolean](
- BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat"))
+ BlockManagerHeartbeat(blockManagerId), new RpcTimeout(10.minutes, "BlockManagerHeartbeat"))
}
/**
@@ -383,7 +381,8 @@ private[spark] class DAGScheduler(
* locations that are still available from the previous shuffle to avoid unnecessarily
* regenerating data.
*/
- def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = {
+ def createShuffleMapStage[K, V, C](
+ shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = {
val rdd = shuffleDep.rdd
checkBarrierStageWithDynamicAllocation(rdd)
checkBarrierStageWithNumSlots(rdd)
@@ -469,21 +468,21 @@ private[spark] class DAGScheduler(
/** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */
private def getMissingAncestorShuffleDependencies(
- rdd: RDD[_]): ArrayStack[ShuffleDependency[_, _, _]] = {
- val ancestors = new ArrayStack[ShuffleDependency[_, _, _]]
+ rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = {
+ val ancestors = new ListBuffer[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
- val waitingForVisit = new ArrayStack[RDD[_]]
- waitingForVisit.push(rdd)
+ val waitingForVisit = new ListBuffer[RDD[_]]
+ waitingForVisit += rdd
while (waitingForVisit.nonEmpty) {
- val toVisit = waitingForVisit.pop()
+ val toVisit = waitingForVisit.remove(0)
if (!visited(toVisit)) {
visited += toVisit
getShuffleDependencies(toVisit).foreach { shuffleDep =>
if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) {
- ancestors.push(shuffleDep)
- waitingForVisit.push(shuffleDep.rdd)
+ ancestors.prepend(shuffleDep)
+ waitingForVisit.prepend(shuffleDep.rdd)
} // Otherwise, the dependency and its ancestors have already been registered.
}
}
@@ -507,17 +506,17 @@ private[spark] class DAGScheduler(
rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = {
val parents = new HashSet[ShuffleDependency[_, _, _]]
val visited = new HashSet[RDD[_]]
- val waitingForVisit = new ArrayStack[RDD[_]]
- waitingForVisit.push(rdd)
+ val waitingForVisit = new ListBuffer[RDD[_]]
+ waitingForVisit += rdd
while (waitingForVisit.nonEmpty) {
- val toVisit = waitingForVisit.pop()
+ val toVisit = waitingForVisit.remove(0)
if (!visited(toVisit)) {
visited += toVisit
toVisit.dependencies.foreach {
case shuffleDep: ShuffleDependency[_, _, _] =>
parents += shuffleDep
case dependency =>
- waitingForVisit.push(dependency.rdd)
+ waitingForVisit.prepend(dependency.rdd)
}
}
}
@@ -530,10 +529,10 @@ private[spark] class DAGScheduler(
*/
private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = {
val visited = new HashSet[RDD[_]]
- val waitingForVisit = new ArrayStack[RDD[_]]
- waitingForVisit.push(rdd)
+ val waitingForVisit = new ListBuffer[RDD[_]]
+ waitingForVisit += rdd
while (waitingForVisit.nonEmpty) {
- val toVisit = waitingForVisit.pop()
+ val toVisit = waitingForVisit.remove(0)
if (!visited(toVisit)) {
if (!predicate(toVisit)) {
return false
@@ -543,7 +542,7 @@ private[spark] class DAGScheduler(
case _: ShuffleDependency[_, _, _] =>
// Not within the same stage with current rdd, do nothing.
case dependency =>
- waitingForVisit.push(dependency.rdd)
+ waitingForVisit.prepend(dependency.rdd)
}
}
}
@@ -555,7 +554,8 @@ private[spark] class DAGScheduler(
val visited = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
- val waitingForVisit = new ArrayStack[RDD[_]]
+ val waitingForVisit = new ListBuffer[RDD[_]]
+ waitingForVisit += stage.rdd
def visit(rdd: RDD[_]) {
if (!visited(rdd)) {
visited += rdd
@@ -569,15 +569,14 @@ private[spark] class DAGScheduler(
missing += mapStage
}
case narrowDep: NarrowDependency[_] =>
- waitingForVisit.push(narrowDep.rdd)
+ waitingForVisit.prepend(narrowDep.rdd)
}
}
}
}
}
- waitingForVisit.push(stage.rdd)
while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
+ visit(waitingForVisit.remove(0))
}
missing.toList
}
@@ -1390,6 +1389,14 @@ private[spark] class DAGScheduler(
event.reason match {
case Success =>
+ // An earlier attempt of a stage (which is now a zombie) may still have running tasks. If
+ // these tasks complete, they still count and we can mark the corresponding partitions as
+ // finished. Here we notify the task scheduler to skip running tasks for the same partition,
+ // to save resources.
+ if (task.stageAttemptId < stage.latestInfo.attemptNumber()) {
+ taskScheduler.notifyPartitionCompletion(stageId, task.partitionId)
+ }
+
task match {
case rt: ResultTask[_, _] =>
// Cast to ResultStage here because it's part of the ResultTask
@@ -1993,7 +2000,8 @@ private[spark] class DAGScheduler(
val visitedRdds = new HashSet[RDD[_]]
// We are manually maintaining a stack here to prevent StackOverflowError
// caused by recursively visiting
- val waitingForVisit = new ArrayStack[RDD[_]]
+ val waitingForVisit = new ListBuffer[RDD[_]]
+ waitingForVisit += stage.rdd
def visit(rdd: RDD[_]) {
if (!visitedRdds(rdd)) {
visitedRdds += rdd
@@ -2002,17 +2010,16 @@ private[spark] class DAGScheduler(
case shufDep: ShuffleDependency[_, _, _] =>
val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
if (!mapStage.isAvailable) {
- waitingForVisit.push(mapStage.rdd)
+ waitingForVisit.prepend(mapStage.rdd)
} // Otherwise there's no need to follow the dependency back
case narrowDep: NarrowDependency[_] =>
- waitingForVisit.push(narrowDep.rdd)
+ waitingForVisit.prepend(narrowDep.rdd)
}
}
}
}
- waitingForVisit.push(stage.rdd)
while (waitingForVisit.nonEmpty) {
- visit(waitingForVisit.pop())
+ visit(waitingForVisit.remove(0))
}
visitedRdds.contains(target.rdd)
}
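
The traversals above replace ArrayStack (dropped in newer Scala collections) with a ListBuffer driven as a LIFO stack: prepend plays the role of push and remove(0) the role of pop. A tiny sketch of that equivalence.

import scala.collection.mutable.ListBuffer

object ListBufferStackSketch {
  def main(args: Array[String]): Unit = {
    // ListBuffer used as a LIFO stack, as in the DAGScheduler traversals above.
    val stack = new ListBuffer[Int]
    stack += 1              // seed the traversal
    stack.prepend(2)        // push
    stack.prepend(3)        // push
    while (stack.nonEmpty) {
      val top = stack.remove(0) // pop
      println(top)              // prints 3, 2, 1
    }
  }
}
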
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 54ab8f8b3e1d8..b514c2e7056f4 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -19,8 +19,6 @@ package org.apache.spark.scheduler
import java.util.Properties
-import scala.language.existentials
-
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.{AccumulatorV2, CallSite}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 189e35ee83119..710f5eb211dde 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -21,13 +21,10 @@ import java.lang.management.ManagementFactory
import java.nio.ByteBuffer
import java.util.Properties
-import scala.language.existentials
-
import org.apache.spark._
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
-import org.apache.spark.shuffle.ShuffleWriter
/**
* A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
@@ -85,13 +82,15 @@ private[spark] class ShuffleMapTask(
threadMXBean.getCurrentThreadCpuTime
} else 0L
val ser = SparkEnv.get.closureSerializer.newInstance()
- val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
+ val rddAndDep = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])](
ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader)
_executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs
_executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) {
threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime
} else 0L
+ val rdd = rddAndDep._1
+ val dep = rddAndDep._2
dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index c6dedaaa9554a..9b7f901c55e00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -155,6 +155,15 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
}
}
+ // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want
+ // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's
+ // synchronized and may hurt the throughput of the scheduler.
+ def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = {
+ getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions {
+ scheduler.handlePartitionCompleted(stageId, partitionId)
+ })
+ }
+
def stop() {
getTaskResultExecutor.shutdownNow()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 94221eb0d5515..bfdbf0217210a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -68,6 +68,10 @@ private[spark] trait TaskScheduler {
// Throw UnsupportedOperationException if the backend doesn't support kill tasks.
def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit
+ // Notify the corresponding `TaskSetManager`s of the stage that a partition has already completed,
+ // so they can skip running tasks for it.
+ def notifyPartitionCompletion(stageId: Int, partitionId: Int)
+
// Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
def setDAGScheduler(dagScheduler: DAGScheduler): Unit
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index e401c395a0486..532eb322769aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -22,7 +22,7 @@ import java.util.{Locale, Timer, TimerTask}
import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
import java.util.concurrent.atomic.AtomicLong
-import scala.collection.mutable.{ArrayBuffer, BitSet, HashMap, HashSet}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.util.Random
import org.apache.spark._
@@ -101,9 +101,6 @@ private[spark] class TaskSchedulerImpl(
// Protected by `this`
val taskIdToExecutorId = new HashMap[Long, String]
- // Protected by `this`
- private[scheduler] val stageIdToFinishedPartitions = new HashMap[Int, BitSet]
-
@volatile private var hasReceivedTask = false
@volatile private var hasLaunchedTask = false
private val starvationTimer = new Timer(true)
@@ -252,20 +249,7 @@ private[spark] class TaskSchedulerImpl(
private[scheduler] def createTaskSetManager(
taskSet: TaskSet,
maxTaskFailures: Int): TaskSetManager = {
- // only create a BitSet once for a certain stage since we only remove
- // that stage when an active TaskSetManager succeed.
- stageIdToFinishedPartitions.getOrElseUpdate(taskSet.stageId, new BitSet)
- val tsm = new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
- // TaskSet got submitted by DAGScheduler may have some already completed
- // tasks since DAGScheduler does not always know all the tasks that have
- // been completed by other tasksets when completing a stage, so we mark
- // those tasks as finished here to avoid launching duplicate tasks, while
- // holding the TaskSchedulerImpl lock.
- // See SPARK-25250 and `markPartitionCompletedInAllTaskSets()`
- stageIdToFinishedPartitions.get(taskSet.stageId).foreach {
- finishedPartitions => finishedPartitions.foreach(tsm.markPartitionCompleted(_, None))
- }
- tsm
+ new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
}
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -317,6 +301,10 @@ private[spark] class TaskSchedulerImpl(
}
}
+ override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = {
+ taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId)
+ }
+
/**
* Called to indicate that all task attempts (including speculated tasks) associated with the
* given TaskSetManager have completed, so state associated with the TaskSetManager should be
@@ -653,6 +641,21 @@ private[spark] class TaskSchedulerImpl(
}
}
+ /**
+ * Marks the task as completed in the active TaskSetManager for the given stage.
+ *
+ * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
+ * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt
+ * to skip submitting and running the task for the same partition, to save resources. That also
+ * means that a task completion from an earlier zombie attempt can lead to the entire stage
+ * getting marked as successful.
+ */
+ private[scheduler] def handlePartitionCompleted(stageId: Int, partitionId: Int) = synchronized {
+ taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm =>
+ tsm.markPartitionCompleted(partitionId)
+ })
+ }
+
def error(message: String) {
synchronized {
if (taskSetsByStageIdAndAttempt.nonEmpty) {
@@ -884,36 +887,6 @@ private[spark] class TaskSchedulerImpl(
manager
}
}
-
- /**
- * Marks the task has completed in all TaskSetManagers(active / zombie) for the given stage.
- *
- * After stage failure and retry, there may be multiple TaskSetManagers for the stage.
- * If an earlier attempt of a stage completes a task, we should ensure that the later attempts
- * do not also submit those same tasks. That also means that a task completion from an earlier
- * attempt can lead to the entire stage getting marked as successful.
- * And there is also the possibility that the DAGScheduler submits another taskset at the same
- * time as we're marking a task completed here -- that taskset would have a task for a partition
- * that was already completed. We maintain the set of finished partitions in
- * stageIdToFinishedPartitions, protected by this, so we can detect those tasks when the taskset
- * is submitted. See SPARK-25250 for more details.
- *
- * note: this method must be called with a lock on this.
- */
- private[scheduler] def markPartitionCompletedInAllTaskSets(
- stageId: Int,
- partitionId: Int,
- taskInfo: TaskInfo) = {
- // if we do not find a BitSet for this stage, which means an active TaskSetManager
- // has already succeeded and removed the stage.
- stageIdToFinishedPartitions.get(stageId).foreach{
- finishedPartitions => finishedPartitions += partitionId
- }
- taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm =>
- tsm.markPartitionCompleted(partitionId, Some(taskInfo))
- }
- }
-
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 144422022c22f..52323b3331d7e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -21,7 +21,7 @@ import java.io.NotSerializableException
import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentLinkedQueue
-import scala.collection.mutable.{ArrayBuffer, BitSet, HashMap, HashSet}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
import scala.math.max
import scala.util.control.NonFatal
@@ -62,14 +62,8 @@ private[spark] class TaskSetManager(
private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*)
private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*)
- // Quantile of tasks at which to start speculation
- val speculationQuantile = conf.get(SPECULATION_QUANTILE)
- val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
-
val maxResultSize = conf.get(config.MAX_RESULT_SIZE)
- val speculationEnabled = conf.get(SPECULATION_ENABLED)
-
// Serializer for closures and tasks.
val env = SparkEnv.get
val ser = env.closureSerializer.newInstance()
@@ -80,6 +74,12 @@ private[spark] class TaskSetManager(
val numTasks = tasks.length
val copiesRunning = new Array[Int](numTasks)
+ val speculationEnabled = conf.get(SPECULATION_ENABLED)
+ // Quantile of tasks at which to start speculation
+ val speculationQuantile = conf.get(SPECULATION_QUANTILE)
+ val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER)
+ val minFinishedForSpeculation = math.max((speculationQuantile * numTasks).floor.toInt, 1)
+
// For each task, tracks whether a copy of the task has succeeded. A task will also be
// marked as "succeeded" if it failed with a fetch failure, in which case it should not
// be re-run because the missing map data needs to be regenerated first.
@@ -800,19 +800,12 @@ private[spark] class TaskSetManager(
// Mark successful and stop if all the tasks have succeeded.
successful(index) = true
if (tasksSuccessful == numTasks) {
- // clean up finished partitions for the stage when the active TaskSetManager succeed
- if (!isZombie) {
- sched.stageIdToFinishedPartitions -= stageId
- isZombie = true
- }
+ isZombie = true
}
} else {
logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id +
" because task " + index + " has already completed successfully")
}
- // There may be multiple tasksets for this stage -- we let all of them know that the partition
- // was completed. This may result in some of the tasksets getting completed.
- sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
// This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the
// "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not
// "deserialize" the value when holding a lock to avoid blocking other threads. So we call
@@ -823,21 +816,13 @@ private[spark] class TaskSetManager(
maybeFinishTaskSet()
}
- private[scheduler] def markPartitionCompleted(
- partitionId: Int,
- taskInfo: Option[TaskInfo]): Unit = {
+ private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = {
partitionToIndex.get(partitionId).foreach { index =>
if (!successful(index)) {
- if (speculationEnabled && !isZombie) {
- taskInfo.foreach { info => successfulTaskDurations.insert(info.duration) }
- }
tasksSuccessful += 1
successful(index) = true
if (tasksSuccessful == numTasks) {
- if (!isZombie) {
- sched.stageIdToFinishedPartitions -= stageId
- isZombie = true
- }
+ isZombie = true
}
maybeFinishTaskSet()
}
@@ -1047,10 +1032,13 @@ private[spark] class TaskSetManager(
return false
}
var foundTasks = false
- val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
- if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
+ // It's possible that a task is marked as completed by the scheduler, in which case the size
+ // of `successfulTaskDurations` may not equal `tasksSuccessful`. Here we should only count
+ // the tasks that were submitted by this `TaskSetManager` and completed successfully.
+ val numSuccessfulTasks = successfulTaskDurations.size()
+ if (numSuccessfulTasks >= minFinishedForSpeculation) {
val time = clock.getTimeMillis()
val medianDuration = successfulTaskDurations.median
val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation)
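
The speculation threshold is now clamped to at least one task and compared against the number of durations recorded by this TaskSetManager, since markPartitionCompleted can bump tasksSuccessful without recording a duration. A quick worked sketch of the threshold arithmetic, using an illustrative quantile.

object SpeculationThresholdSketch {
  def main(args: Array[String]): Unit = {
    val speculationQuantile = 0.75 // illustrative value of the speculation quantile config
    // Before the clamp, a tiny task set could yield a threshold of 0, which is why the old
    // code also had to check "tasksSuccessful > 0".
    for (numTasks <- Seq(1, 4, 10)) {
      val raw = (speculationQuantile * numTasks).floor.toInt
      val minFinishedForSpeculation = math.max(raw, 1) // clamped as in the diff
      println(s"numTasks=$numTasks raw=$raw clamped=$minFinishedForSpeculation")
      // numTasks=1  -> raw=0, clamped=1
      // numTasks=4  -> raw=3, clamped=3
      // numTasks=10 -> raw=7, clamped=7
    }
  }
}
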
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index afb48a31754f9..89425e702677a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster
import java.nio.ByteBuffer
+import org.apache.spark.ResourceInformation
import org.apache.spark.TaskState.TaskState
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.scheduler.ExecutorLossReason
@@ -64,7 +65,8 @@ private[spark] object CoarseGrainedClusterMessages {
hostname: String,
cores: Int,
logUrls: Map[String, String],
- attributes: Map[String, String])
+ attributes: Map[String, String],
+ resources: Map[String, ResourceInformation])
extends CoarseGrainedClusterMessage
case class StatusUpdate(executorId: String, taskId: Long, state: TaskState,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 4830d0e6f8008..f7cf212d0bfe1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -185,7 +185,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp
override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
- case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls, attributes) =>
+ case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls,
+ attributes, resources) =>
if (executorDataMap.contains(executorId)) {
executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId))
context.reply(true)
diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
index c65c8fd6e3e1c..e616d239ce8d5 100644
--- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
+++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala
@@ -21,7 +21,6 @@ import java.net.{InetAddress, ServerSocket, Socket}
import scala.concurrent.Promise
import scala.concurrent.duration.Duration
-import scala.language.existentials
import scala.util.Try
import org.apache.spark.SparkEnv
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index eef19973e8d77..39691069bf5f6 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -164,8 +164,8 @@ class KryoSerializer(conf: SparkConf)
}
// Allow the user to register their own classes by setting spark.kryo.registrator.
userRegistrators
- .map(Utils.classForName(_, noSparkClassLoader = true).getConstructor().
- newInstance().asInstanceOf[KryoRegistrator])
+ .map(Utils.classForName[KryoRegistrator](_, noSparkClassLoader = true).
+ getConstructor().newInstance())
.foreach { reg => reg.registerClasses(kryo) }
} catch {
case e: Exception =>
@@ -213,6 +213,10 @@ class KryoSerializer(conf: SparkConf)
// We can't load those class directly in order to avoid unnecessary jar dependencies.
// We load them safely, ignore it if the class not found.
Seq(
+ "org.apache.spark.sql.catalyst.expressions.UnsafeRow",
+ "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData",
+ "org.apache.spark.sql.catalyst.expressions.UnsafeMapData",
+
"org.apache.spark.ml.attribute.Attribute",
"org.apache.spark.ml.attribute.AttributeGroup",
"org.apache.spark.ml.attribute.BinaryAttribute",
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
index 84c2ad48f1f27..83f76db7e89da 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala
@@ -77,7 +77,7 @@ private[spark] trait UIRoot {
/**
* Runs some code with the current SparkUI instance for the app / attempt.
*
- * @throws NoSuchElementException If the app / attempt pair does not exist.
+ * @throws java.util.NoSuchElementException If the app / attempt pair does not exist.
*/
def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T
@@ -85,8 +85,8 @@ private[spark] trait UIRoot {
def getApplicationInfo(appId: String): Option[ApplicationInfo]
/**
- * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is
- * [[None]], event logs for all attempts of this application will be written out.
+ * Write the event logs for the given app to the `ZipOutputStream` instance. If attemptId is
+ * `None`, event logs for all attempts of this application will be written out.
*/
def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = {
Response.serverError()
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala
index 9d1d66a0e15a4..db53a400ed62f 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala
@@ -24,8 +24,9 @@ import org.apache.spark.SparkException
import org.apache.spark.scheduler.StageInfo
import org.apache.spark.status.api.v1.StageStatus._
import org.apache.spark.status.api.v1.TaskSorting._
-import org.apache.spark.ui.SparkUI
+import org.apache.spark.ui.{SparkUI, UIUtils}
import org.apache.spark.ui.jobs.ApiHelper._
+import org.apache.spark.util.Utils
@Produces(Array(MediaType.APPLICATION_JSON))
private[v1] class StagesResource extends BaseAppResource {
@@ -189,32 +190,42 @@ private[v1] class StagesResource extends BaseAppResource {
val taskMetricsContainsValue = (task: TaskData) => task.taskMetrics match {
case None => false
case Some(metrics) =>
- (containsValue(task.taskMetrics.get.executorDeserializeTime)
- || containsValue(task.taskMetrics.get.executorRunTime)
- || containsValue(task.taskMetrics.get.jvmGcTime)
- || containsValue(task.taskMetrics.get.resultSerializationTime)
- || containsValue(task.taskMetrics.get.memoryBytesSpilled)
- || containsValue(task.taskMetrics.get.diskBytesSpilled)
- || containsValue(task.taskMetrics.get.peakExecutionMemory)
- || containsValue(task.taskMetrics.get.inputMetrics.bytesRead)
+ (containsValue(UIUtils.formatDuration(task.taskMetrics.get.executorDeserializeTime))
+ || containsValue(UIUtils.formatDuration(task.taskMetrics.get.executorRunTime))
+ || containsValue(UIUtils.formatDuration(task.taskMetrics.get.jvmGcTime))
+ || containsValue(UIUtils.formatDuration(task.taskMetrics.get.resultSerializationTime))
+ || containsValue(Utils.bytesToString(task.taskMetrics.get.memoryBytesSpilled))
+ || containsValue(Utils.bytesToString(task.taskMetrics.get.diskBytesSpilled))
+ || containsValue(Utils.bytesToString(task.taskMetrics.get.peakExecutionMemory))
+ || containsValue(Utils.bytesToString(task.taskMetrics.get.inputMetrics.bytesRead))
|| containsValue(task.taskMetrics.get.inputMetrics.recordsRead)
- || containsValue(task.taskMetrics.get.outputMetrics.bytesWritten)
+ || containsValue(Utils.bytesToString(
+ task.taskMetrics.get.outputMetrics.bytesWritten))
|| containsValue(task.taskMetrics.get.outputMetrics.recordsWritten)
- || containsValue(task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime)
+ || containsValue(UIUtils.formatDuration(
+ task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime))
+ || containsValue(Utils.bytesToString(
+ task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead))
+ || containsValue(Utils.bytesToString(
+ task.taskMetrics.get.shuffleReadMetrics.localBytesRead +
+ task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead))
|| containsValue(task.taskMetrics.get.shuffleReadMetrics.recordsRead)
- || containsValue(task.taskMetrics.get.shuffleWriteMetrics.bytesWritten)
+ || containsValue(Utils.bytesToString(
+ task.taskMetrics.get.shuffleWriteMetrics.bytesWritten))
|| containsValue(task.taskMetrics.get.shuffleWriteMetrics.recordsWritten)
- || containsValue(task.taskMetrics.get.shuffleWriteMetrics.writeTime))
+ || containsValue(UIUtils.formatDuration(
+ task.taskMetrics.get.shuffleWriteMetrics.writeTime / 1000000)))
}
val filteredTaskDataSequence: Seq[TaskData] = taskDataList.filter(f =>
(containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt)
- || containsValue(f.launchTime)
+ || containsValue(UIUtils.formatDate(f.launchTime))
|| containsValue(f.resultFetchStart.getOrElse(defaultOptionString))
|| containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status)
|| containsValue(f.taskLocality) || containsValue(f.speculative)
|| containsValue(f.errorMessage.getOrElse(defaultOptionString))
|| taskMetricsContainsValue(f)
- || containsValue(f.schedulerDelay) || containsValue(f.gettingResultTime)))
+ || containsValue(UIUtils.formatDuration(f.schedulerDelay))
+ || containsValue(UIUtils.formatDuration(f.gettingResultTime))))
filteredTaskDataSequence
}
diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
index fc9b50f14a083..1c0dd7dee2228 100644
--- a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
+++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala
@@ -53,9 +53,24 @@ private class HttpSecurityFilter(
val hres = res.asInstanceOf[HttpServletResponse]
hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate")
- if (!securityMgr.checkUIViewPermissions(hreq.getRemoteUser())) {
+ val requestUser = hreq.getRemoteUser()
+
+ // The doAs parameter allows proxy servers (e.g. Knox) to impersonate other users. For
+ // that to be allowed, the authenticated user needs to be an admin.
+ val effectiveUser = Option(hreq.getParameter("doAs"))
+ .map { proxy =>
+ if (requestUser != proxy && !securityMgr.checkAdminPermissions(requestUser)) {
+ hres.sendError(HttpServletResponse.SC_FORBIDDEN,
+ s"User $requestUser is not allowed to impersonate others.")
+ return
+ }
+ proxy
+ }
+ .getOrElse(requestUser)
+
+ if (!securityMgr.checkUIViewPermissions(effectiveUser)) {
hres.sendError(HttpServletResponse.SC_FORBIDDEN,
- "User is not authorized to access this page.")
+ s"User $effectiveUser is not authorized to access this page.")
return
}
@@ -77,12 +92,13 @@ private class HttpSecurityFilter(
hres.setHeader("Strict-Transport-Security", _))
}
- chain.doFilter(new XssSafeRequest(hreq), res)
+ chain.doFilter(new XssSafeRequest(hreq, effectiveUser), res)
}
}
-private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequestWrapper(req) {
+private class XssSafeRequest(req: HttpServletRequest, effectiveUser: String)
+ extends HttpServletRequestWrapper(req) {
private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r
@@ -92,6 +108,8 @@ private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequest
}.toMap
}
+ override def getRemoteUser(): String = effectiveUser
+
override def getParameterMap(): JMap[String, Array[String]] = parameterMap.asJava
override def getParameterNames(): Enumeration[String] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index d7bda8b80c24f..11647c0c7c623 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -109,12 +109,12 @@ private[spark] object UIUtils extends Logging {
}
}
// if time is more than a year
- return s"$yearString $weekString $dayString"
+ s"$yearString $weekString $dayString"
} catch {
case e: Exception =>
logError("Error converting time to string", e)
// if there is some error, return blank string
- return ""
+ ""
}
}
@@ -336,7 +336,7 @@ private[spark] object UIUtils extends Logging {
def getHeaderContent(header: String): Seq[Node] = {
if (newlinesInHeader) {
-        { header.split("\n").map { case t => <li> {t} </li> } }
+        { header.split("\n").map(t => <li> {t} </li>) }
} else {
Text(header)
@@ -446,7 +446,7 @@ private[spark] object UIUtils extends Logging {
* the whole string will rendered as a simple escaped text.
*
* Note: In terms of security, only anchor tags with root relative links are supported. So any
- * attempts to embed links outside Spark UI, or other tags like {@code