diff --git a/LICENSE b/LICENSE index b771bd552b762..150ccc54ec6c2 100644 --- a/LICENSE +++ b/LICENSE @@ -222,7 +222,7 @@ Python Software Foundation License ---------------------------------- pyspark/heapq3.py - +python/docs/_static/copybutton.js BSD 3-Clause ------------ @@ -258,4 +258,4 @@ data/mllib/images/kittens/29.5.a_b_EGDP022204.jpg data/mllib/images/kittens/54893.jpg data/mllib/images/kittens/DP153539.jpg data/mllib/images/kittens/DP802813.jpg -data/mllib/images/multi-channel/chr30.4.184.jpg \ No newline at end of file +data/mllib/images/multi-channel/chr30.4.184.jpg diff --git a/LICENSE-binary b/LICENSE-binary index 5f57133bef43d..2ff881fac5fb0 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -302,7 +302,6 @@ com.google.code.gson:gson com.google.inject:guice com.google.inject.extensions:guice-servlet com.twitter:parquet-hadoop-bundle -commons-beanutils:commons-beanutils-core commons-cli:commons-cli commons-dbcp:commons-dbcp commons-io:commons-io @@ -490,7 +489,6 @@ Eclipse Distribution License (EDL) 1.0 org.glassfish.jaxb:jaxb-runtime jakarta.xml.bind:jakarta.xml.bind-api com.sun.istack:istack-commons-runtime -jakarta.activation:jakarta.activation-api Mozilla Public License (MPL) 1.1 diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 0566a47cc8755..3bd1f544d77a5 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -3589,6 +3589,8 @@ setMethod("element_at", #' @details #' \code{explode}: Creates a new row for each element in the given array or map column. +#' Uses the default column name \code{col} for elements in the array and +#' \code{key} and \code{value} for elements in the map unless specified otherwise. #' #' @rdname column_collection_functions #' @aliases explode explode,Column-method @@ -3649,7 +3651,9 @@ setMethod("sort_array", #' @details #' \code{posexplode}: Creates a new row for each element with position in the given array -#' or map column. +#' or map column. Uses the default column name \code{pos} for position, and \code{col} +#' for elements in the array and \code{key} and \code{value} for elements in the map +#' unless specified otherwise. #' #' @rdname column_collection_functions #' @aliases posexplode posexplode,Column-method @@ -3790,7 +3794,8 @@ setMethod("repeat_string", #' \code{explode}: Creates a new row for each element in the given array or map column. #' Unlike \code{explode}, if the array/map is \code{null} or empty #' then \code{null} is produced. -#' +#' Uses the default column name \code{col} for elements in the array and +#' \code{key} and \code{value} for elements in the map unless specified otherwise. #' #' @rdname column_collection_functions #' @aliases explode_outer explode_outer,Column-method @@ -3815,6 +3820,9 @@ setMethod("explode_outer", #' \code{posexplode_outer}: Creates a new row for each element with position in the given #' array or map column. Unlike \code{posexplode}, if the array/map is \code{null} or empty #' then the row (\code{null}, \code{null}) is produced. +#' Uses the default column name \code{pos} for position, and \code{col} +#' for elements in the array and \code{key} and \code{value} for elements in the map +#' unless specified otherwise. 
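Reviewer note: the default column names documented in these roxygen updates mirror the Scala API. A minimal sketch (not part of this patch) that prints them:

```scala
// Illustrative only: shows the default output column names described in the R docs above.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{explode, posexplode}

val spark = SparkSession.builder().master("local[*]").appName("explode-defaults").getOrCreate()
import spark.implicits._

val df = Seq((Seq("a", "b"), Map("k" -> 1))).toDF("arr", "m")
df.select(explode($"arr")).printSchema()    // single column named "col"
df.select(explode($"m")).printSchema()      // columns named "key" and "value"
df.select(posexplode($"arr")).printSchema() // columns named "pos" and "col"
```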
#' #' @rdname column_collection_functions #' @aliases posexplode_outer posexplode_outer,Column-method diff --git a/README.md b/README.md index 271f2f5f5b1c3..482c00764380a 100644 --- a/README.md +++ b/README.md @@ -1,18 +1,18 @@ # Apache Spark -[![Jenkins Build](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/badge/icon)](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7) -[![AppVeyor Build](https://img.shields.io/appveyor/ci/ApacheSoftwareFoundation/spark/master.svg?style=plastic&logo=appveyor)](https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark) -[![PySpark Coverage](https://img.shields.io/badge/dynamic/xml.svg?label=pyspark%20coverage&url=https%3A%2F%2Fspark-test.github.io%2Fpyspark-coverage-site&query=%2Fhtml%2Fbody%2Fdiv%5B1%5D%2Fdiv%2Fh1%2Fspan&colorB=brightgreen&style=plastic)](https://spark-test.github.io/pyspark-coverage-site) - -Spark is a fast and general cluster computing system for Big Data. It provides +Spark is a unified analytics engine for large-scale data processing. It provides high-level APIs in Scala, Java, Python, and R, and an optimized engine that supports general computation graphs for data analysis. It also supports a rich set of higher-level tools including Spark SQL for SQL and DataFrames, MLlib for machine learning, GraphX for graph processing, -and Spark Streaming for stream processing. +and Structured Streaming for stream processing. +[![Jenkins Build](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7/badge/icon)](https://amplab.cs.berkeley.edu/jenkins/job/spark-master-test-sbt-hadoop-2.7) +[![AppVeyor Build](https://img.shields.io/appveyor/ci/ApacheSoftwareFoundation/spark/master.svg?style=plastic&logo=appveyor)](https://ci.appveyor.com/project/ApacheSoftwareFoundation/spark) +[![PySpark Coverage](https://img.shields.io/badge/dynamic/xml.svg?label=pyspark%20coverage&url=https%3A%2F%2Fspark-test.github.io%2Fpyspark-coverage-site&query=%2Fhtml%2Fbody%2Fdiv%5B1%5D%2Fdiv%2Fh1%2Fspan&colorB=brightgreen&style=plastic)](https://spark-test.github.io/pyspark-coverage-site) + ## Online Documentation @@ -41,9 +41,9 @@ The easiest way to start using Spark is through the Scala shell: ./bin/spark-shell -Try the following command, which should return 1000: +Try the following command, which should return 1,000,000,000: - scala> sc.parallelize(1 to 1000).count() + scala> spark.range(1000 * 1000 * 1000).count() ## Interactive Python Shell @@ -51,9 +51,9 @@ Alternatively, if you prefer Python, you can use the Python shell: ./bin/pyspark -And run the following command, which should also return 1000: +And run the following command, which should also return 1,000,000,000: - >>> sc.parallelize(range(1000)).count() + >>> spark.range(1000 * 1000 * 1000).count() ## Example Programs diff --git a/bin/docker-image-tool.sh b/bin/docker-image-tool.sh index 0388e23979dda..68fafbb848001 100755 --- a/bin/docker-image-tool.sh +++ b/bin/docker-image-tool.sh @@ -282,7 +282,7 @@ do if ! minikube status 1>/dev/null; then error "Cannot contact minikube. Make sure it's running." 
fi - eval $(minikube docker-env) + eval $(minikube docker-env --shell bash) ;; u) SPARK_UID=${OPTARG};; esac diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java index a530e16734db4..6f90df5f611a9 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RetryingBlockFetcherSuite.java @@ -46,9 +46,9 @@ */ public class RetryingBlockFetcherSuite { - ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); - ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); - ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); + private final ManagedBuffer block0 = new NioManagedBuffer(ByteBuffer.wrap(new byte[13])); + private final ManagedBuffer block1 = new NioManagedBuffer(ByteBuffer.wrap(new byte[7])); + private final ManagedBuffer block2 = new NioManagedBuffer(ByteBuffer.wrap(new byte[19])); @Test public void testNoFailures() throws IOException, InterruptedException { @@ -291,7 +291,7 @@ private static void performInteractions(List> inte } assertNotNull(stub); - stub.when(fetchStarter).createAndStart(any(), anyObject()); + stub.when(fetchStarter).createAndStart(any(), any()); String[] blockIdArray = blockIds.toArray(new String[blockIds.size()]); new RetryingBlockFetcher(conf, fetchStarter, blockIdArray, listener).start(); } diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 55cdc3140aa08..c642f3b4a1600 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -35,7 +35,7 @@ provided ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar - org/spark_project/ + org/sparkproject/ @@ -128,6 +128,50 @@ + + + org.codehaus.mojo + build-helper-maven-plugin + + + regex-property + + regex-property + + + spark.shade.native.packageName + ${spark.shade.packageName} + \. + _ + true + + + + + + org.apache.maven.plugins + maven-antrun-plugin + + + unpack + package + + + + + + + + + + + run + + + + diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java index 621f2c6bf3777..e36efa3b0f22b 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/CalendarInterval.java @@ -18,6 +18,7 @@ package org.apache.spark.unsafe.types; import java.io.Serializable; +import java.util.Locale; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -66,6 +67,10 @@ private static long toLong(String s) { } } + /** + * Convert a string to CalendarInterval. Return null if the input string is not a valid interval. + * This method is case-sensitive and all characters in the input string should be in lower case. + */ public static CalendarInterval fromString(String s) { if (s == null) { return null; @@ -87,6 +92,26 @@ public static CalendarInterval fromString(String s) { } } + /** + * Convert a string to CalendarInterval. Unlike fromString, this method is case-insensitive and + * will throw IllegalArgumentException when the input string is not a valid interval. + * + * @throws IllegalArgumentException if the string is not a valid internal. 
+ */ + public static CalendarInterval fromCaseInsensitiveString(String s) { + if (s == null || s.trim().isEmpty()) { + throw new IllegalArgumentException("Interval cannot be null or blank."); + } + String sInLowerCase = s.trim().toLowerCase(Locale.ROOT); + String interval = + sInLowerCase.startsWith("interval ") ? sInLowerCase : "interval " + sInLowerCase; + CalendarInterval cal = fromString(interval); + if (cal == null) { + throw new IllegalArgumentException("Invalid interval: " + s); + } + return cal; + } + public static long toLongWithRange(String fieldName, String s, long minValue, long maxValue) throws IllegalArgumentException { long result = 0; @@ -319,6 +344,8 @@ public String toString() { appendUnit(sb, rest / MICROS_PER_MILLI, "millisecond"); rest %= MICROS_PER_MILLI; appendUnit(sb, rest, "microsecond"); + } else if (months == 0) { + sb.append(" 0 microseconds"); } return sb.toString(); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java index 9e69e264ff287..994af8f082447 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java @@ -41,6 +41,9 @@ public void equalsTest() { public void toStringTest() { CalendarInterval i; + i = new CalendarInterval(0, 0); + assertEquals("interval 0 microseconds", i.toString()); + i = new CalendarInterval(34, 0); assertEquals("interval 2 years 10 months", i.toString()); @@ -101,6 +104,31 @@ public void fromStringTest() { assertNull(fromString(input)); } + @Test + public void fromCaseInsensitiveStringTest() { + for (String input : new String[]{"5 MINUTES", "5 minutes", "5 Minutes"}) { + assertEquals(fromCaseInsensitiveString(input), new CalendarInterval(0, 5L * 60 * 1_000_000)); + } + + for (String input : new String[]{null, "", " "}) { + try { + fromCaseInsensitiveString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("cannot be null or blank")); + } + } + + for (String input : new String[]{"interval", "interval1 day", "foo", "foo 1 day"}) { + try { + fromCaseInsensitiveString(input); + fail("Expected to throw an exception for the invalid input"); + } catch (IllegalArgumentException e) { + assertTrue(e.getMessage().contains("Invalid interval")); + } + } + } + @Test public void fromYearMonthStringTest() { String input; diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index ec1aa187dfb32..e91595dd324b0 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark_project.jetty=WARN -log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.sparkproject.jetty=WARN +log4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.parquet=ERROR diff --git a/core/pom.xml b/core/pom.xml index 45bda44916f52..9d57028019880 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -347,7 +347,7 @@ net.razorvine pyrolite - 4.13 + 4.23 
net.razorvine diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 277010015072a..eb12848900b58 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark_project.jetty=WARN -log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.sparkproject.jetty=WARN +log4j.logger.org.sparkproject.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala index 2d842b98ead43..a354f44a1be19 100644 --- a/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala +++ b/core/src/main/scala/org/apache/spark/BarrierTaskContext.scala @@ -20,7 +20,6 @@ package org.apache.spark import java.util.{Properties, Timer, TimerTask} import scala.concurrent.duration._ -import scala.language.postfixOps import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.executor.TaskMetrics @@ -122,7 +121,7 @@ class BarrierTaskContext private[spark] ( barrierEpoch), // Set a fixed timeout for RPC here, so users shall get a SparkException thrown by // BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework. - timeout = new RpcTimeout(31536000 /* = 3600 * 24 * 365 */ seconds, "barrierTimeout")) + timeout = new RpcTimeout(365.days, "barrierTimeout")) barrierEpoch += 1 logInfo(s"Task $taskAttemptId from Stage $stageId(Attempt $stageAttemptNumber) finished " + "global sync successfully, waited for " + diff --git a/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala b/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala new file mode 100644 index 0000000000000..19639420b8b9f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/ResourceDiscoverer.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import java.io.File + +import com.fasterxml.jackson.core.JsonParseException +import org.json4s.{DefaultFormats, MappingException} +import org.json4s.JsonAST.JValue +import org.json4s.jackson.JsonMethods._ + +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ +import org.apache.spark.util.Utils.executeAndGetOutput + +/** + * Discovers resources (GPUs/FPGAs/etc). It currently only supports resources that have + * addresses. + * This class finds resources by running and parsing the output of the user specified script + * from the config spark.{driver/executor}.resource.{resourceType}.discoveryScript. + * The output of the script it runs is expected to be JSON in the format of the + * ResourceInformation class. + * + * For example: {"name": "gpu", "addresses": ["0","1"]} + */ +private[spark] object ResourceDiscoverer extends Logging { + + private implicit val formats = DefaultFormats + + def findResources(sparkConf: SparkConf, isDriver: Boolean): Map[String, ResourceInformation] = { + val prefix = if (isDriver) { + SPARK_DRIVER_RESOURCE_PREFIX + } else { + SPARK_EXECUTOR_RESOURCE_PREFIX + } + // get unique resource types by grabbing first part config with multiple periods, + // ie resourceType.count, grab resourceType part + val resourceNames = sparkConf.getAllWithPrefix(prefix).map { case (k, _) => + k.split('.').head + }.toSet + resourceNames.map { rName => { + val rInfo = getResourceInfoForType(sparkConf, prefix, rName) + (rName -> rInfo) + }}.toMap + } + + private def getResourceInfoForType( + sparkConf: SparkConf, + prefix: String, + resourceType: String): ResourceInformation = { + val discoveryConf = prefix + resourceType + SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX + val script = sparkConf.getOption(discoveryConf) + val result = if (script.nonEmpty) { + val scriptFile = new File(script.get) + // check that script exists and try to execute + if (scriptFile.exists()) { + try { + val output = executeAndGetOutput(Seq(script.get), new File(".")) + val parsedJson = parse(output) + val name = (parsedJson \ "name").extract[String] + val addresses = (parsedJson \ "addresses").extract[Array[String]].toArray + new ResourceInformation(name, addresses) + } catch { + case e @ (_: SparkException | _: MappingException | _: JsonParseException) => + throw new SparkException(s"Error running the resource discovery script: $scriptFile" + + s" for $resourceType", e) + } + } else { + throw new SparkException(s"Resource script: $scriptFile to discover $resourceType" + + s" doesn't exist!") + } + } else { + throw new SparkException(s"User is expecting to use $resourceType resources but " + + s"didn't specify a script via conf: $discoveryConf, to find them!") + } + result + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/core/src/main/scala/org/apache/spark/ResourceInformation.scala similarity index 54% rename from sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java rename to core/src/main/scala/org/apache/spark/ResourceInformation.scala index c44b8af2552f0..6a5b725ac21d7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java +++ b/core/src/main/scala/org/apache/spark/ResourceInformation.scala @@ -15,15 +15,23 @@ * limitations under the License. 
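Reviewer note: a hedged configuration sketch for the discovery flow implemented in `ResourceDiscoverer` above. The exact config postfixes and the script path are assumptions for illustration, not defined in this excerpt:

```scala
// Assumed usage sketch: the resource config postfixes are taken to resolve to ".count"
// and ".discoveryScript", and /opt/spark/bin/find-gpus.sh is a hypothetical user script.
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .set("spark.executor.resource.gpu.count", "2")
  .set("spark.executor.resource.gpu.discoveryScript", "/opt/spark/bin/find-gpus.sh")
// The discovery script is expected to print ResourceInformation-shaped JSON on stdout, e.g.:
//   {"name": "gpu", "addresses": ["0", "1"]}
```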
*/ -package org.apache.spark.sql.execution.streaming; +package org.apache.spark + +import org.apache.spark.annotation.Evolving /** - * The shared interface between V1 streaming sources and V2 streaming readers. + * Class to hold information about a type of Resource. A resource could be a GPU, FPGA, etc. + * The array of addresses are resource specific and its up to the user to interpret the address. + * + * One example is GPUs, where the addresses would be the indices of the GPUs * - * This is a temporary interface for compatibility during migration. It should not be implemented - * directly, and will be removed in future versions. + * @param name the name of the resource + * @param addresses an array of strings describing the addresses of the resource */ -public interface BaseStreamingSource { - /** Stop this source and free any resources it has allocated. */ - void stop(); +@Evolving +class ResourceInformation( + val name: String, + val addresses: Array[String]) extends Serializable { + + override def toString: String = s"[name: ${name}, addresses: ${addresses.mkString(",")}]" } diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala index 26b18564be778..5a5c5a403f202 100644 --- a/core/src/main/scala/org/apache/spark/SecurityManager.scala +++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala @@ -214,6 +214,15 @@ private[spark] class SecurityManager( */ def aclsEnabled(): Boolean = aclsOn + /** + * Checks whether the given user is an admin. This gives the user both view and + * modify permissions, and also allows the user to impersonate other users when + * making UI requests. + */ + def checkAdminPermissions(user: String): Boolean = { + isUserInACL(user, adminAcls, adminAclsGroups) + } + /** * Checks the given user against the view acl and groups list to see if they have * authorization to view the UI. 
If the UI acls are disabled @@ -227,13 +236,7 @@ private[spark] class SecurityManager( def checkUIViewPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " viewAcls=" + viewAcls.mkString(",") + " viewAclsGroups=" + viewAclsGroups.mkString(",")) - if (!aclsEnabled || user == null || viewAcls.contains(user) || - viewAcls.contains(WILDCARD_ACL) || viewAclsGroups.contains(WILDCARD_ACL)) { - return true - } - val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) - logDebug("userGroups=" + currentUserGroups.mkString(",")) - viewAclsGroups.exists(currentUserGroups.contains(_)) + isUserInACL(user, viewAcls, viewAclsGroups) } /** @@ -249,13 +252,7 @@ private[spark] class SecurityManager( def checkModifyPermissions(user: String): Boolean = { logDebug("user=" + user + " aclsEnabled=" + aclsEnabled() + " modifyAcls=" + modifyAcls.mkString(",") + " modifyAclsGroups=" + modifyAclsGroups.mkString(",")) - if (!aclsEnabled || user == null || modifyAcls.contains(user) || - modifyAcls.contains(WILDCARD_ACL) || modifyAclsGroups.contains(WILDCARD_ACL)) { - return true - } - val currentUserGroups = Utils.getCurrentUserGroups(sparkConf, user) - logDebug("userGroups=" + currentUserGroups) - modifyAclsGroups.exists(currentUserGroups.contains(_)) + isUserInACL(user, modifyAcls, modifyAclsGroups) } /** @@ -371,6 +368,23 @@ private[spark] class SecurityManager( } } + private def isUserInACL( + user: String, + aclUsers: Set[String], + aclGroups: Set[String]): Boolean = { + if (user == null || + !aclsEnabled || + aclUsers.contains(WILDCARD_ACL) || + aclUsers.contains(user) || + aclGroups.contains(WILDCARD_ACL)) { + true + } else { + val userGroups = Utils.getCurrentUserGroups(sparkConf, user) + logDebug(s"user $user is in groups ${userGroups.mkString(",")}") + aclGroups.exists(userGroups.contains(_)) + } + } + // Default SecurityManager only has a single secret key, so ignore appId. 
override def getSaslUser(appId: String): String = getSaslUser() override def getSecretKey(appId: String): String = getSecretKey() diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 913a1704ad5ce..aa93f42141fc1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -168,6 +168,15 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging with Seria } /** Set multiple parameters together */ + def setAll(settings: Iterable[(String, String)]): SparkConf = { + settings.foreach { case (k, v) => set(k, v) } + this + } + + /** + * Set multiple parameters together + */ + @deprecated("Use setAll(Iterable) instead", "3.0.0") def setAll(settings: Traversable[(String, String)]): SparkConf = { settings.foreach { case (k, v) => set(k, v) } this @@ -705,7 +714,9 @@ private[spark] object SparkConf extends Logging { AlternateConfig("spark.yarn.kerberos.relogin.period", "3.0")), KERBEROS_FILESYSTEMS_TO_ACCESS.key -> Seq( AlternateConfig("spark.yarn.access.namenodes", "2.2"), - AlternateConfig("spark.yarn.access.hadoopFileSystems", "3.0")) + AlternateConfig("spark.yarn.access.hadoopFileSystems", "3.0")), + "spark.kafka.consumer.cache.capacity" -> Seq( + AlternateConfig("spark.sql.kafkaConsumerCache.capacity", "3.0")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 8b744356daaee..997941080fc6f 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2556,7 +2556,7 @@ object SparkContext extends Logging { private[spark] val DRIVER_IDENTIFIER = "driver" - private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Traversable[T]) + private implicit def arrayToArrayWritable[T <: Writable : ClassTag](arr: Iterable[T]) : ArrayWritable = { def anyToWritable[U <: Writable](u: U): Writable = u diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index c2ebd388a2365..c97b10ee63b18 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -192,6 +192,20 @@ private[spark] object TestUtils { assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } + /** + * Asserts that exception message contains the message. Please note this checks all + * exceptions in the tree. + */ + def assertExceptionMsg(exception: Throwable, msg: String): Unit = { + var e = exception + var contains = e.getMessage.contains(msg) + while (e.getCause != null && !contains) { + e = e.getCause + contains = e.getMessage.contains(msg) + } + assert(contains, s"Exception tree doesn't contain the expected message: $msg") + } + /** * Test if a command is available. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 2ab8add63efae..a4817b3cf770d 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -33,18 +33,17 @@ import org.apache.spark.util.{SerializableConfiguration, Utils} * A trait for use with reading custom classes in PySpark. Implement this trait and add custom * transformation code by overriding the convert method. 
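Reviewer note: since `Converter` is the public extension point this scaladoc refers to, a minimal illustrative implementation (hypothetical class name) consistent with the new `Converter[-T, +U]` signature might look like:

```scala
// Illustrative only: a custom converter turning Hadoop Text values into plain Strings.
import org.apache.hadoop.io.Text
import org.apache.spark.api.python.Converter

class TextToStringConverter extends Converter[Text, String] {
  override def convert(obj: Text): String = obj.toString
}
```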
*/ -trait Converter[T, + U] extends Serializable { +trait Converter[-T, +U] extends Serializable { def convert(obj: T): U } private[python] object Converter extends Logging { - def getInstance(converterClass: Option[String], - defaultConverter: Converter[Any, Any]): Converter[Any, Any] = { + def getInstance[T, U](converterClass: Option[String], + defaultConverter: Converter[_ >: T, _ <: U]): Converter[T, U] = { converterClass.map { cc => Try { - val c = Utils.classForName(cc).getConstructor(). - newInstance().asInstanceOf[Converter[Any, Any]] + val c = Utils.classForName[Converter[T, U]](cc).getConstructor().newInstance() logInfo(s"Loaded converter: $cc") c } match { @@ -177,8 +176,8 @@ private[python] object PythonHadoopUtil { * [[org.apache.hadoop.io.Writable]], into an RDD of base types, or vice versa. */ def convertRDD[K, V](rdd: RDD[(K, V)], - keyConverter: Converter[Any, Any], - valueConverter: Converter[Any, Any]): RDD[(Any, Any)] = { + keyConverter: Converter[K, Any], + valueConverter: Converter[V, Any]): RDD[(Any, Any)] = { rdd.map { case (k, v) => (keyConverter.convert(k), valueConverter.convert(v)) } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 5b492b1f3991e..fe25c3aac81b8 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -24,10 +24,6 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable -import scala.concurrent.Promise -import scala.concurrent.duration.Duration -import scala.language.existentials -import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec @@ -90,7 +86,7 @@ private[spark] case class PythonFunction( private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) /** Thrown for exceptions in user Python code. */ -private[spark] class PythonException(msg: String, cause: Exception) +private[spark] class PythonException(msg: String, cause: Throwable) extends RuntimeException(msg, cause) /** @@ -167,8 +163,63 @@ private[spark] object PythonRDD extends Logging { serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}") } + /** + * A helper function to create a local RDD iterator and serve it via socket. Partitions are + * are collected as separate jobs, by order of index. Partition data is first requested by a + * non-zero integer to start a collection job. The response is prefaced by an integer with 1 + * meaning partition data will be served, 0 meaning the local iterator has been consumed, + * and -1 meaining an error occurred during collection. This function is used by + * pyspark.rdd._local_iterator_from_socket(). + * + * @return 2-tuple (as a Java array) with the port number of a local socket which serves the + * data collected from these jobs, and the secret for authentication. 
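Reviewer note: to make the request/response handshake described above concrete, a hedged client-side sketch of a single exchange (the real consumer is `pyspark.rdd._local_iterator_from_socket`; authentication with the returned secret and reading of the serialized partition bytes are omitted):

```scala
// Protocol sketch only: write a non-zero int to request a partition, then read a status int.
// 1 = partition data follows (terminated by END_OF_DATA_SECTION), 0 = iterator consumed,
// -1 = an error occurred and a UTF-encoded message follows.
import java.io.{DataInputStream, DataOutputStream}
import java.net.Socket

def probeLocalIteratorServer(port: Int): Int = {
  val sock = new Socket("localhost", port)
  val out = new DataOutputStream(sock.getOutputStream)
  val in = new DataInputStream(sock.getInputStream)
  try {
    out.writeInt(1) // any non-zero value requests the first partition
    out.flush()
    in.readInt()    // status code as described above
  } finally {
    sock.close()
  }
}
```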
+ */ def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = { - serveIterator(rdd.toLocalIterator, s"serve toLocalIterator") + val (port, secret) = SocketAuthServer.setupOneConnectionServer( + authHelper, "serve toLocalIterator") { s => + val out = new DataOutputStream(s.getOutputStream) + val in = new DataInputStream(s.getInputStream) + Utils.tryWithSafeFinally { + + // Collects a partition on each iteration + val collectPartitionIter = rdd.partitions.indices.iterator.map { i => + rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head + } + + // Read request for data and send next partition if nonzero + var complete = false + while (!complete && in.readInt() != 0) { + if (collectPartitionIter.hasNext) { + try { + // Attempt to collect the next partition + val partitionArray = collectPartitionIter.next() + + // Send response there is a partition to read + out.writeInt(1) + + // Write the next object and signal end of data for this iteration + writeIteratorToStream(partitionArray.toIterator, out) + out.writeInt(SpecialLengths.END_OF_DATA_SECTION) + out.flush() + } catch { + case e: SparkException => + // Send response that an error occurred followed by error message + out.writeInt(-1) + writeUTF(e.getMessage, out) + complete = true + } + } else { + // Send response there are no more partitions to read and close + out.writeInt(0) + complete = true + } + } + } { + out.close() + in.close() + } + } + Array(port, secret) } def readRDDFromFile( @@ -228,8 +279,8 @@ private[spark] object PythonRDD extends Logging { batchSize: Int): JavaRDD[Array[Byte]] = { val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") - val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] - val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] + val kc = Utils.classForName[K](keyClass) + val vc = Utils.classForName[V](valueClass) val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) val confBroadcasted = sc.sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration())) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, @@ -296,9 +347,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration): RDD[(K, V)] = { - val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] - val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] - val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName[K](keyClass) + val vc = Utils.classForName[V](valueClass) + val fc = Utils.classForName[F](inputFormatClass) if (path.isDefined) { sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) } else { @@ -365,9 +416,9 @@ private[spark] object PythonRDD extends Logging { keyClass: String, valueClass: String, conf: Configuration) = { - val kc = Utils.classForName(keyClass).asInstanceOf[Class[K]] - val vc = Utils.classForName(valueClass).asInstanceOf[Class[V]] - val fc = Utils.classForName(inputFormatClass).asInstanceOf[Class[F]] + val kc = Utils.classForName[K](keyClass) + val vc = Utils.classForName[V](valueClass) + val fc = Utils.classForName[F](inputFormatClass) if (path.isDefined) { sc.sc.hadoopFile(path.get, fc, kc, vc) } else { @@ -425,29 +476,33 @@ private[spark] object PythonRDD extends Logging { PythonHadoopUtil.mergeConfs(baseConf, conf) } - private def inferKeyValueTypes[K, V](rdd: RDD[(K, V)], keyConverterClass: String = null, - 
valueConverterClass: String = null): (Class[_], Class[_]) = { + private def inferKeyValueTypes[K, V, KK, VV](rdd: RDD[(K, V)], keyConverterClass: String = null, + valueConverterClass: String = null): (Class[_ <: KK], Class[_ <: VV]) = { // Peek at an element to figure out key/value types. Since Writables are not serializable, // we cannot call first() on the converted RDD. Instead, we call first() on the original RDD // and then convert locally. val (key, value) = rdd.first() - val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, - new JavaToWritableConverter) + val (kc, vc) = getKeyValueConverters[K, V, KK, VV]( + keyConverterClass, valueConverterClass, new JavaToWritableConverter) (kc.convert(key).getClass, vc.convert(value).getClass) } - private def getKeyValueTypes(keyClass: String, valueClass: String): - Option[(Class[_], Class[_])] = { + private def getKeyValueTypes[K, V](keyClass: String, valueClass: String): + Option[(Class[K], Class[V])] = { for { k <- Option(keyClass) v <- Option(valueClass) } yield (Utils.classForName(k), Utils.classForName(v)) } - private def getKeyValueConverters(keyConverterClass: String, valueConverterClass: String, - defaultConverter: Converter[Any, Any]): (Converter[Any, Any], Converter[Any, Any]) = { - val keyConverter = Converter.getInstance(Option(keyConverterClass), defaultConverter) - val valueConverter = Converter.getInstance(Option(valueConverterClass), defaultConverter) + private def getKeyValueConverters[K, V, KK, VV]( + keyConverterClass: String, + valueConverterClass: String, + defaultConverter: Converter[_, _]): (Converter[K, KK], Converter[V, VV]) = { + val keyConverter = Converter.getInstance(Option(keyConverterClass), + defaultConverter.asInstanceOf[Converter[K, KK]]) + val valueConverter = Converter.getInstance(Option(valueConverterClass), + defaultConverter.asInstanceOf[Converter[V, VV]]) (keyConverter, valueConverter) } @@ -459,7 +514,7 @@ private[spark] object PythonRDD extends Logging { keyConverterClass: String, valueConverterClass: String, defaultConverter: Converter[Any, Any]): RDD[(Any, Any)] = { - val (kc, vc) = getKeyValueConverters(keyConverterClass, valueConverterClass, + val (kc, vc) = getKeyValueConverters[K, V, Any, Any](keyConverterClass, valueConverterClass, defaultConverter) PythonHadoopUtil.convertRDD(rdd, kc, vc) } @@ -470,7 +525,7 @@ private[spark] object PythonRDD extends Logging { * [[org.apache.hadoop.io.Writable]] types already, since Writables are not Java * `Serializable` and we can't peek at them. The `path` can be on any Hadoop file system. */ - def saveAsSequenceFile[K, V, C <: CompressionCodec]( + def saveAsSequenceFile[C <: CompressionCodec]( pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, @@ -489,7 +544,7 @@ private[spark] object PythonRDD extends Logging { * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of * this RDD. 
*/ - def saveAsHadoopFile[K, V, F <: OutputFormat[_, _], C <: CompressionCodec]( + def saveAsHadoopFile[F <: OutputFormat[_, _], C <: CompressionCodec]( pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, @@ -507,7 +562,7 @@ private[spark] object PythonRDD extends Logging { val codec = Option(compressionCodecClass).map(Utils.classForName(_).asInstanceOf[Class[C]]) val converted = convertRDD(rdd, keyConverterClass, valueConverterClass, new JavaToWritableConverter) - val fc = Utils.classForName(outputFormatClass).asInstanceOf[Class[F]] + val fc = Utils.classForName[F](outputFormatClass) converted.saveAsHadoopFile(path, kc, vc, fc, new JobConf(mergedConf), codec = codec) } @@ -520,7 +575,7 @@ private[spark] object PythonRDD extends Logging { * `confAsMap` is merged with the default Hadoop conf associated with the SparkContext of * this RDD. */ - def saveAsNewAPIHadoopFile[K, V, F <: NewOutputFormat[_, _]]( + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[_, _]]( pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, path: String, @@ -548,7 +603,7 @@ private[spark] object PythonRDD extends Logging { * (mapred vs. mapreduce). Keys/values are converted for output using either user specified * converters or, by default, [[org.apache.spark.api.python.JavaToWritableConverter]]. */ - def saveAsHadoopDataset[K, V]( + def saveAsHadoopDataset( pyRDD: JavaRDD[Array[Byte]], batchSerialized: Boolean, confAsMap: java.util.HashMap[String, String], diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala index b7f14e062b437..dca87044513c2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.internal.Logging @@ -165,15 +166,15 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( context: TaskContext) extends Thread(s"stdout writer for $pythonExec") { - @volatile private var _exception: Exception = null + @volatile private var _exception: Throwable = null private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) setDaemon(true) - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) + /** Contains the throwable thrown while writing the parent iterator to the Python process. */ + def exception: Option[Throwable] = Option(_exception) /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ def shutdownOnTaskCompletion() { @@ -347,18 +348,21 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( dataOut.writeInt(SpecialLengths.END_OF_STREAM) dataOut.flush() } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). 
- _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) + case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + if (context.isCompleted || context.isInterrupted) { + logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + + "cleanup)", t) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } else { + // We must avoid throwing exceptions/NonFatals here, because the thread uncaught + // exception handler will kill the whole executor (see + // org.apache.spark.executor.Executor). + _exception = t + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } } } } @@ -511,7 +515,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT]( if (!context.isCompleted) { try { // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.taskAttemptId} " + + val taskName = s"${context.partitionId}.${context.attemptNumber} " + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index 01e64b6972ae2..9462dfd950bab 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -186,6 +186,9 @@ private[spark] object SerDeUtil extends Logging { val unpickle = new Unpickler iter.flatMap { row => val obj = unpickle.loads(row) + // `Opcodes.MEMOIZE` of Protocol 4 (Python 3.4+) will store objects in internal map + // of `Unpickler`. This map is cleared when calling `Unpickler.close()`. + unpickle.close() if (batched) { obj match { case array: Array[Any] => array.toSeq diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala index 36b4132088b58..c755dcba6bcea 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackend.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackend.scala @@ -30,7 +30,7 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder} import io.netty.handler.timeout.ReadTimeoutHandler -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config.R._ @@ -47,7 +47,7 @@ private[spark] class RBackend { private[r] val jvmObjectTracker = new JVMObjectTracker def init(): (Int, RAuthHelper) = { - val conf = new SparkConf() + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) bossGroup = new NioEventLoopGroup(conf.get(R_NUM_BACKEND_THREADS)) val workerGroup = bossGroup @@ -124,7 +124,7 @@ private[spark] object RBackend extends Logging { val listenPort = serverSocket.getLocalPort() // Connection timeout is set by socket client. 
To make it configurable we will pass the // timeout value to client inside the temp file - val conf = new SparkConf() + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) // tell the R process via temporary file diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 7b74efa41044f..f2f81b11fc813 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -20,13 +20,11 @@ package org.apache.spark.api.r import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} import java.util.concurrent.TimeUnit -import scala.language.existentials - import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.channel.ChannelHandler.Sharable import io.netty.handler.timeout.ReadTimeoutException -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.api.r.SerDe._ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.R._ @@ -98,7 +96,7 @@ private[r] class RBackendHandler(server: RBackend) ctx.write(pingBaos.toByteArray) } } - val conf = new SparkConf() + val conf = Option(SparkEnv.get).map(_.conf).getOrElse(new SparkConf()) val heartBeatInterval = conf.get(R_HEARTBEAT_INTERVAL) val backendConnectionTimeout = conf.get(R_BACKEND_CONNECTION_TIMEOUT) val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index a66243012041c..99f841234005e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -26,7 +26,6 @@ import scala.collection.mutable.ListBuffer import scala.concurrent.{Future, Promise} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ -import scala.language.postfixOps import scala.sys.process._ import org.json4s._ @@ -112,7 +111,7 @@ private object FaultToleranceTest extends App with Logging { assertValidClusterState() killLeader() - delay(30 seconds) + delay(30.seconds) assertValidClusterState() createClient() assertValidClusterState() @@ -126,12 +125,12 @@ private object FaultToleranceTest extends App with Logging { killLeader() addMasters(1) - delay(30 seconds) + delay(30.seconds) assertValidClusterState() killLeader() addMasters(1) - delay(30 seconds) + delay(30.seconds) assertValidClusterState() } @@ -156,7 +155,7 @@ private object FaultToleranceTest extends App with Logging { killLeader() workers.foreach(_.kill()) workers.clear() - delay(30 seconds) + delay(30.seconds) addWorkers(2) assertValidClusterState() } @@ -174,7 +173,7 @@ private object FaultToleranceTest extends App with Logging { (1 to 3).foreach { _ => killLeader() - delay(30 seconds) + delay(30.seconds) assertValidClusterState() assertTrue(getLeader == masters.head) addMasters(1) @@ -264,7 +263,7 @@ private object FaultToleranceTest extends App with Logging { } // Avoid waiting indefinitely (e.g., we could register but get no executors). 
- assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 2.minutes)) } /** @@ -317,7 +316,7 @@ private object FaultToleranceTest extends App with Logging { } try { - assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 2.minutes)) } catch { case e: TimeoutException => logError("Master states: " + masters.map(_.state)) @@ -421,7 +420,7 @@ private object SparkDocker { } dockerCmd.run(ProcessLogger(findIpAndLog _)) - val ip = ThreadUtils.awaitResult(ipPromise.future, 30 seconds) + val ip = ThreadUtils.awaitResult(ipPromise.future, 30.seconds) val dockerId = Docker.getLastProcessId (ip, dockerId, outFile) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 9efaaa765b5d1..49d939539719d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -544,10 +544,14 @@ private[spark] class SparkSubmit extends Logging { // Yarn only OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), - OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), - OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), - OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files", + mergeFn = Some(mergeFileLists(_, _))), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives", + mergeFn = Some(mergeFileLists(_, _))), // Other options OptionAssigner(args.numExecutors, YARN | KUBERNETES, ALL_DEPLOY_MODES, @@ -608,7 +612,13 @@ private[spark] class SparkSubmit extends Logging { (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) } + if (opt.confKey != null) { + if (opt.mergeFn.isDefined && sparkConf.contains(opt.confKey)) { + sparkConf.set(opt.confKey, opt.mergeFn.get.apply(sparkConf.get(opt.confKey), opt.value)) + } else { + sparkConf.set(opt.confKey, opt.value) + } + } } } @@ -1381,7 +1391,8 @@ private case class OptionAssigner( clusterManager: Int, deployMode: Int, clOption: String = null, - confKey: String = null) + confKey: String = null, + mergeFn: Option[(String, String) => String] = None) private[spark] trait SparkSubmitOperation { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 7c9ce14c652c4..7df36c5aeba07 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -34,7 +34,6 @@ import org.apache.spark.internal.config.History import org.apache.spark.internal.config.UI._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} import org.apache.spark.ui.{SparkUI, 
UIUtils, WebUI} -import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.{ShutdownHookManager, SystemClock, Utils} /** @@ -274,10 +273,9 @@ object HistoryServer extends Logging { val providerName = conf.get(History.PROVIDER) .getOrElse(classOf[FsHistoryProvider].getName()) - val provider = Utils.classForName(providerName) + val provider = Utils.classForName[ApplicationHistoryProvider](providerName) .getConstructor(classOf[SparkConf]) .newInstance(conf) - .asInstanceOf[ApplicationHistoryProvider] val port = conf.get(History.HISTORY_SERVER_UI_PORT) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 9d2301cea9b33..6f1484cee586e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -78,8 +78,8 @@ private[deploy] class ExecutorRunner( // Shutdown hook that kills actors on shutdown. shutdownHook = ShutdownHookManager.addShutdownHook { () => // It's possible that we arrive here before calling `fetchAndRunExecutor`, then `state` will - // be `ExecutorState.RUNNING`. In this case, we should set `state` to `FAILED`. - if (state == ExecutorState.RUNNING) { + // be `ExecutorState.LAUNCHING`. In this case, we should set `state` to `FAILED`. + if (state == ExecutorState.LAUNCHING) { state = ExecutorState.FAILED } killProcess(Some("Worker shutting down")) } @@ -183,6 +183,8 @@ private[deploy] class ExecutorRunner( Files.write(header, stderr, StandardCharsets.UTF_8) stderrAppender = FileAppender(process.getErrorStream, stderr, conf) + state = ExecutorState.RUNNING + worker.send(ExecutorStateChanged(appId, execId, state, None, None)) // Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown) // or with nonzero exit code val exitCode = process.waitFor() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index eb2add3af8251..f8ec5b6b190c1 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -540,12 +540,12 @@ private[deploy] class Worker( executorDir, workerUri, conf, - appLocalDirs, ExecutorState.RUNNING) + appLocalDirs, + ExecutorState.LAUNCHING) executors(appId + "/" + execId) = manager manager.start() coresUsed += cores_ memoryUsed += memory_ - sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 645f58716de63..af01e0b23dada 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -17,6 +17,7 @@ package org.apache.spark.executor +import java.io.{BufferedInputStream, FileInputStream} import java.net.URL import java.nio.ByteBuffer import java.util.Locale @@ -26,11 +27,18 @@ import scala.collection.mutable import scala.util.{Failure, Success} import scala.util.control.NonFatal +import com.fasterxml.jackson.databind.exc.MismatchedInputException +import org.json4s.DefaultFormats +import org.json4s.JsonAST.JArray +import 
org.json4s.MappingException +import org.json4s.jackson.JsonMethods._ + import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.rpc._ import org.apache.spark.scheduler.{ExecutorLossReason, TaskDescription} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ @@ -44,9 +52,12 @@ private[spark] class CoarseGrainedExecutorBackend( hostname: String, cores: Int, userClassPath: Seq[URL], - env: SparkEnv) + env: SparkEnv, + resourcesFile: Option[String]) extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + private implicit val formats = DefaultFormats + private[this] val stopping = new AtomicBoolean(false) var executor: Executor = null @volatile var driver: Option[RpcEndpointRef] = None @@ -57,11 +68,12 @@ private[spark] class CoarseGrainedExecutorBackend( override def onStart() { logInfo("Connecting to driver: " + driverUrl) + val resources = parseOrFindResources(resourcesFile) rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref => // This is a very fast action so we can use "ThreadUtils.sameThread" driver = Some(ref) ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, extractLogUrls, - extractAttributes)) + extractAttributes, resources)) }(ThreadUtils.sameThread).onComplete { // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => @@ -71,6 +83,97 @@ private[spark] class CoarseGrainedExecutorBackend( }(ThreadUtils.sameThread) } + // Check that the actual resources discovered will satisfy the user specified + // requirements and that they match the configs specified by the user to catch + // mismatches between what the user requested and what resource manager gave or + // what the discovery script found. + private def checkResourcesMeetRequirements( + resourceConfigPrefix: String, + reqResourcesAndCounts: Array[(String, String)], + actualResources: Map[String, ResourceInformation]): Unit = { + + reqResourcesAndCounts.foreach { case (rName, reqCount) => + if (actualResources.contains(rName)) { + val resourceInfo = actualResources(rName) + + if (resourceInfo.addresses.size < reqCount.toLong) { + throw new SparkException(s"Resource: $rName with addresses: " + + s"${resourceInfo.addresses.mkString(",")} doesn't meet the " + + s"requirements of needing $reqCount of them") + } + // also make sure the resource count on start matches the + // resource configs specified by user + val userCountConfigName = + resourceConfigPrefix + rName + SPARK_RESOURCE_COUNT_POSTFIX + val userConfigCount = env.conf.getOption(userCountConfigName). 
+ getOrElse(throw new SparkException(s"Resource: $rName not specified " + + s"via config: $userCountConfigName, but required, " + + "please fix your configuration")) + + if (userConfigCount.toLong > resourceInfo.addresses.size) { + throw new SparkException(s"Resource: $rName, with addresses: " + + s"${resourceInfo.addresses.mkString(",")} " + + s"is less than what the user requested for count: $userConfigCount, " + + s"via $userCountConfigName") + } + } else { + throw new SparkException(s"Executor resource config missing required task resource: $rName") + } + } + } + + // visible for testing + def parseOrFindResources(resourcesFile: Option[String]): Map[String, ResourceInformation] = { + // only parse the resources if a task requires them + val taskResourceConfigs = env.conf.getAllWithPrefix(SPARK_TASK_RESOURCE_PREFIX) + val resourceInfo = if (taskResourceConfigs.nonEmpty) { + val execResources = resourcesFile.map { resourceFileStr => { + val source = new BufferedInputStream(new FileInputStream(resourceFileStr)) + val resourceMap = try { + val parsedJson = parse(source).asInstanceOf[JArray].arr + parsedJson.map { json => + val name = (json \ "name").extract[String] + val addresses = (json \ "addresses").extract[Array[String]] + new ResourceInformation(name, addresses) + }.map(x => (x.name -> x)).toMap + } catch { + case e @ (_: MappingException | _: MismatchedInputException) => + throw new SparkException( + s"Exception parsing the resources in $resourceFileStr", e) + } finally { + source.close() + } + resourceMap + }}.getOrElse(ResourceDiscoverer.findResources(env.conf, isDriver = false)) + + if (execResources.isEmpty) { + throw new SparkException("User specified resources per task via: " + + s"$SPARK_TASK_RESOURCE_PREFIX, but can't find any resources available on the executor.") + } + // get just the map of resource name to count + val resourcesAndCounts = taskResourceConfigs. + withFilter { case (k, v) => k.endsWith(SPARK_RESOURCE_COUNT_POSTFIX)}. 
+ map { case (k, v) => (k.dropRight(SPARK_RESOURCE_COUNT_POSTFIX.size), v)} + + checkResourcesMeetRequirements(SPARK_EXECUTOR_RESOURCE_PREFIX, resourcesAndCounts, + execResources) + + logInfo("===============================================================================") + logInfo(s"Executor $executorId Resources:") + execResources.foreach { case (k, v) => logInfo(s"$k -> $v") } + logInfo("===============================================================================") + + execResources + } else { + if (resourcesFile.nonEmpty) { + logWarning(s"A resources file was specified but the application is not configured " + + s"to use any resources, see the configs with prefix: ${SPARK_TASK_RESOURCE_PREFIX}") + } + Map.empty[String, ResourceInformation] + } + resourceInfo + } + def extractLogUrls: Map[String, String] = { val prefix = "SPARK_LOG_URL_" sys.env.filterKeys(_.startsWith(prefix)) @@ -188,13 +291,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { cores: Int, appId: String, workerUrl: Option[String], - userClassPath: mutable.ListBuffer[URL]) + userClassPath: mutable.ListBuffer[URL], + resourcesFile: Option[String]) def main(args: Array[String]): Unit = { val createFn: (RpcEnv, Arguments, SparkEnv) => CoarseGrainedExecutorBackend = { case (rpcEnv, arguments, env) => new CoarseGrainedExecutorBackend(rpcEnv, arguments.driverUrl, arguments.executorId, - arguments.hostname, arguments.cores, arguments.userClassPath, env) + arguments.hostname, arguments.cores, arguments.userClassPath, env, arguments.resourcesFile) } run(parseArguments(args, this.getClass.getCanonicalName.stripSuffix("$")), createFn) System.exit(0) @@ -239,6 +343,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { SparkHadoopUtil.get.addDelegationTokens(tokens, driverConf) } + driverConf.set(EXECUTOR_ID, arguments.executorId) val env = SparkEnv.createExecutorEnv(driverConf, arguments.executorId, arguments.hostname, arguments.cores, cfg.ioEncryptionKey, isLocal = false) @@ -255,6 +360,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { var executorId: String = null var hostname: String = null var cores: Int = 0 + var resourcesFile: Option[String] = None var appId: String = null var workerUrl: Option[String] = None val userClassPath = new mutable.ListBuffer[URL]() @@ -274,6 +380,9 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { case ("--cores") :: value :: tail => cores = value.toInt argv = tail + case ("--resourcesFile") :: value :: tail => + resourcesFile = Some(value) + argv = tail case ("--app-id") :: value :: tail => appId = value argv = tail @@ -299,7 +408,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { } Arguments(driverUrl, executorId, hostname, cores, appId, workerUrl, - userClassPath) + userClassPath, resourcesFile) } private def printUsageAndExit(classNameForEntry: String): Unit = { @@ -313,6 +422,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { | --executor-id | --hostname | --cores + | --resourcesFile | --app-id | --worker-url | --user-class-path diff --git a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala b/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala deleted file mode 100644 index e91ddd3e9741a..0000000000000 --- a/core/src/main/scala/org/apache/spark/internal/config/Kafka.scala +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.internal.config - -private[spark] object Kafka { - - val BOOTSTRAP_SERVERS = - ConfigBuilder("spark.kafka.bootstrap.servers") - .doc("A list of coma separated host/port pairs to use for establishing the initial " + - "connection to the Kafka cluster. For further details please see kafka documentation. " + - "Only used to obtain delegation token.") - .stringConf - .createOptional - - val SECURITY_PROTOCOL = - ConfigBuilder("spark.kafka.security.protocol") - .doc("Protocol used to communicate with brokers. For further details please see kafka " + - "documentation. Only used to obtain delegation token.") - .stringConf - .createWithDefault("SASL_SSL") - - val KERBEROS_SERVICE_NAME = - ConfigBuilder("spark.kafka.sasl.kerberos.service.name") - .doc("The Kerberos principal name that Kafka runs as. This can be defined either in " + - "Kafka's JAAS config or in Kafka's config. For further details please see kafka " + - "documentation. Only used to obtain delegation token.") - .stringConf - .createWithDefault("kafka") - - val TRUSTSTORE_LOCATION = - ConfigBuilder("spark.kafka.ssl.truststore.location") - .doc("The location of the trust store file. For further details please see kafka " + - "documentation. Only used to obtain delegation token.") - .stringConf - .createOptional - - val TRUSTSTORE_PASSWORD = - ConfigBuilder("spark.kafka.ssl.truststore.password") - .doc("The store password for the trust store file. This is optional for client and only " + - "needed if ssl.truststore.location is configured. For further details please see kafka " + - "documentation. Only used to obtain delegation token.") - .stringConf - .createOptional - - val KEYSTORE_LOCATION = - ConfigBuilder("spark.kafka.ssl.keystore.location") - .doc("The location of the key store file. This is optional for client and can be used for " + - "two-way authentication for client. For further details please see kafka documentation. " + - "Only used to obtain delegation token.") - .stringConf - .createOptional - - val KEYSTORE_PASSWORD = - ConfigBuilder("spark.kafka.ssl.keystore.password") - .doc("The store password for the key store file. This is optional for client and only " + - "needed if ssl.keystore.location is configured. For further details please see kafka " + - "documentation. Only used to obtain delegation token.") - .stringConf - .createOptional - - val KEY_PASSWORD = - ConfigBuilder("spark.kafka.ssl.key.password") - .doc("The password of the private key in the key store file. This is optional for client. " + - "For further details please see kafka documentation. Only used to obtain delegation token.") - .stringConf - .createOptional - - val TOKEN_SASL_MECHANISM = - ConfigBuilder("spark.kafka.sasl.token.mechanism") - .doc("SASL mechanism used for client connections with delegation token. 
Because SCRAM " + - "login module used for authentication a compatible mechanism has to be set here. " + - "For further details please see kafka documentation (sasl.mechanism). Only used to " + - "authenticate against Kafka broker with delegation token.") - .stringConf - .createWithDefault("SCRAM-SHA-512") -} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 0bd46bef35d27..0aed1af023f83 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -30,6 +30,13 @@ import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterSpillReader.MAX_ package object config { + private[spark] val SPARK_DRIVER_RESOURCE_PREFIX = "spark.driver.resource." + private[spark] val SPARK_EXECUTOR_RESOURCE_PREFIX = "spark.executor.resource." + private[spark] val SPARK_TASK_RESOURCE_PREFIX = "spark.task.resource." + + private[spark] val SPARK_RESOURCE_COUNT_POSTFIX = ".count" + private[spark] val SPARK_RESOURCE_DISCOVERY_SCRIPT_POSTFIX = ".discoveryScript" + private[spark] val DRIVER_CLASS_PATH = ConfigBuilder(SparkLauncher.DRIVER_EXTRA_CLASSPATH).stringConf.createOptional @@ -1303,4 +1310,10 @@ package object config { .doc("Staging directory used while submitting applications.") .stringConf .createOptional + + private[spark] val BUFFER_PAGESIZE = ConfigBuilder("spark.buffer.pageSize") + .doc("The amount of memory used per page in bytes") + .bytesConf(ByteUnit.BYTE) + .createOptional + } diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index e6e9c9e328853..854093851f5d0 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -149,7 +149,7 @@ object FileCommitProtocol extends Logging { logDebug(s"Creating committer $className; job $jobId; output=$outputPath;" + s" dynamic=$dynamicPartitionOverwrite") - val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] + val clazz = Utils.classForName[FileCommitProtocol](className) // First try the constructor with arguments (jobId: String, outputPath: String, // dynamicPartitionOverwrite: Boolean). // If that doesn't exist, try the one with (jobId: string, outputPath: String). 
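[Editor's note] The hunks above add executor-side resource discovery: `CoarseGrainedExecutorBackend.parseOrFindResources` reads an optional `--resourcesFile` (a JSON array of objects with a `name` and an `addresses` list), checks what was found against the counts declared under the new `spark.executor.resource.*` / `spark.task.resource.*` prefixes, and forwards the result in `RegisterExecutor`. The sketch below only illustrates how those pieces fit together under assumptions of my own: the `gpu` resource name, the discovery-script path, the object name, and the standalone json4s parsing are illustrative, not part of this patch; the JSON shape and config keys mirror the code above.

```scala
import org.json4s.{DefaultFormats, JArray}
import org.json4s.jackson.JsonMethods.parse

import org.apache.spark.SparkConf

object ResourceConfigSketch {
  def main(args: Array[String]): Unit = {
    // Resource counts use the new prefixes plus the ".count" postfix; the
    // ".discoveryScript" key echoes the other new postfix constant. The "gpu"
    // name and script path are made-up values for illustration only.
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.executor.resource.gpu.count", "2")
      .set("spark.executor.resource.gpu.discoveryScript", "/opt/spark/find-gpus.sh")
      .set("spark.task.resource.gpu.count", "1")

    // Shape of the file handed to the executor via --resourcesFile: a JSON array
    // of objects, each carrying a resource "name" and the "addresses" assigned
    // to this executor.
    val resourcesJson = """[{"name": "gpu", "addresses": ["0", "1"]}]"""

    implicit val formats: DefaultFormats.type = DefaultFormats
    val byName = parse(resourcesJson).asInstanceOf[JArray].arr.map { json =>
      (json \ "name").extract[String] -> (json \ "addresses").extract[Array[String]]
    }.toMap

    // With two addresses discovered and spark.executor.resource.gpu.count = 2,
    // the requirement check in the backend would pass; a count of 3 would throw.
    byName.foreach { case (name, addrs) =>
      println(s"$name -> ${addrs.mkString(",")} " +
        s"(per-task need: ${conf.get("spark.task.resource.gpu.count")})")
    }
  }
}
```

Under these assumed settings the patched backend would log the two gpu addresses at registration and accept a task requirement of one gpu per task.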
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 288c0d18191c3..065f05e87cf34 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -77,8 +77,9 @@ private[spark] object CompressionCodec { val codecClass = shortCompressionCodecNames.getOrElse(codecName.toLowerCase(Locale.ROOT), codecName) val codec = try { - val ctor = Utils.classForName(codecClass).getConstructor(classOf[SparkConf]) - Some(ctor.newInstance(conf).asInstanceOf[CompressionCodec]) + val ctor = + Utils.classForName[CompressionCodec](codecClass).getConstructor(classOf[SparkConf]) + Some(ctor.newInstance(conf)) } catch { case _: ClassNotFoundException | _: IllegalArgumentException => None } diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index ff6d84b57ebcb..c08b47f99dda3 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -255,7 +255,7 @@ private[spark] abstract class MemoryManager( } val size = ByteArrayMethods.nextPowerOf2(maxTungstenMemory / cores / safetyFactor) val default = math.min(maxPageSize, math.max(minPageSize, size)) - conf.getSizeAsBytes("spark.buffer.pageSize", default) + conf.get(BUFFER_PAGESIZE).getOrElse(default) } /** diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 8dad42b6096a9..c96640a6fab3f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -179,8 +179,8 @@ private[spark] class MetricsSystem private ( sourceConfigs.foreach { kv => val classPath = kv._2.getProperty("class") try { - val source = Utils.classForName(classPath).getConstructor().newInstance() - registerSource(source.asInstanceOf[Source]) + val source = Utils.classForName[Source](classPath).getConstructor().newInstance() + registerSource(source) } catch { case e: Exception => logError("Source class " + classPath + " cannot be instantiated", e) } @@ -195,13 +195,18 @@ private[spark] class MetricsSystem private ( val classPath = kv._2.getProperty("class") if (null != classPath) { try { - val sink = Utils.classForName(classPath) - .getConstructor(classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) - .newInstance(kv._2, registry, securityMgr) if (kv._1 == "servlet") { - metricsServlet = Some(sink.asInstanceOf[MetricsServlet]) + val servlet = Utils.classForName[MetricsServlet](classPath) + .getConstructor( + classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) + metricsServlet = Some(servlet) } else { - sinks += sink.asInstanceOf[Sink] + val sink = Utils.classForName[Sink](classPath) + .getConstructor( + classOf[Properties], classOf[MetricRegistry], classOf[SecurityManager]) + .newInstance(kv._2, registry, securityMgr) + sinks += sink } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 27f4f94ea55f8..55e7109f935c1 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ 
b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -20,7 +20,6 @@ package org.apache.spark.network.netty import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark.internal.Logging @@ -66,12 +65,7 @@ class NettyBlockRpcServer( case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. - val (level: StorageLevel, classTag: ClassTag[_]) = { - serializer - .newInstance() - .deserialize(ByteBuffer.wrap(uploadBlock.metadata)) - .asInstanceOf[(StorageLevel, ClassTag[_])] - } + val (level, classTag) = deserializeMetadata(uploadBlock.metadata) val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData)) val blockId = BlockId(uploadBlock.blockId) logDebug(s"Receiving replicated block $blockId with level ${level} " + @@ -87,12 +81,7 @@ class NettyBlockRpcServer( responseContext: RpcResponseCallback): StreamCallbackWithID = { val message = BlockTransferMessage.Decoder.fromByteBuffer(messageHeader).asInstanceOf[UploadBlockStream] - val (level: StorageLevel, classTag: ClassTag[_]) = { - serializer - .newInstance() - .deserialize(ByteBuffer.wrap(message.metadata)) - .asInstanceOf[(StorageLevel, ClassTag[_])] - } + val (level, classTag) = deserializeMetadata(message.metadata) val blockId = BlockId(message.blockId) logDebug(s"Receiving replicated block $blockId with level ${level} as stream " + s"from ${client.getSocketAddress}") @@ -101,5 +90,12 @@ class NettyBlockRpcServer( blockManager.putBlockDataAsStream(blockId, level, classTag) } + private def deserializeMetadata[T](metadata: Array[Byte]): (StorageLevel, ClassTag[T]) = { + serializer + .newInstance() + .deserialize(ByteBuffer.wrap(metadata)) + .asInstanceOf[(StorageLevel, ClassTag[T])] + } + override def getStreamManager(): StreamManager = streamManager } diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 3ba0a0a750f97..c9103045260f2 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -36,16 +36,26 @@ object SparkTransportConf { * @param numUsableCores if nonzero, this will restrict the server and client threads to only * use the given number of cores, rather than all of the machine's cores. * This restriction will only occur if these properties are not already set. + * @param role optional role, could be driver, executor, worker and master. Default is + * [[None]], means no role specific configurations. */ - def fromSparkConf(_conf: SparkConf, module: String, numUsableCores: Int = 0): TransportConf = { + def fromSparkConf( + _conf: SparkConf, + module: String, + numUsableCores: Int = 0, + role: Option[String] = None): TransportConf = { val conf = _conf.clone - - // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily - // assuming we have all the machine's cores). - // NB: Only set if serverThreads/clientThreads not already set. + // specify default thread configuration based on our JVM's allocation of cores (rather than + // necessarily assuming we have all the machine's cores). 
val numThreads = NettyUtils.defaultNumThreads(numUsableCores) - conf.setIfMissing(s"spark.$module.io.serverThreads", numThreads.toString) - conf.setIfMissing(s"spark.$module.io.clientThreads", numThreads.toString) + // override threads configurations with role specific values if specified + // config order is role > module > default + Seq("serverThreads", "clientThreads").foreach { suffix => + val value = role.flatMap { r => conf.getOption(s"spark.$r.$module.io.$suffix") } + .getOrElse( + conf.get(s"spark.$module.io.$suffix", numThreads.toString)) + conf.set(s"spark.$module.io.$suffix", value) + } new TransportConf(module, new ConfigProvider { override def get(name: String): String = conf.get(name) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 7e76731f5e454..909f58512153b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -20,7 +20,6 @@ package org.apache.spark.rdd import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 44bc40b5babbd..836d3e231269d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -21,7 +21,6 @@ import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala index 117f51c5b8f2a..f6b20593462cd 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcCallContext.scala @@ -24,7 +24,7 @@ package org.apache.spark.rpc private[spark] trait RpcCallContext { /** - * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its `RpcEndpoint.receive` * will be called. 
*/ def reply(response: Any): Unit diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 3dc41f7f12798..770ae2f1dd22f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -52,7 +52,7 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S * * @note This can be used in the recover callback of a Future to add to a TimeoutException * Example: - * val timeout = new RpcTimeout(5 millis, "short timeout") + * val timeout = new RpcTimeout(5.milliseconds, "short timeout") * Future(throw new TimeoutException).recover(timeout.addMessageIfTimeout) */ def addMessageIfTimeout[T]: PartialFunction[Throwable, T] = { diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index ce238a256cfb4..2f923d7902b05 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -24,8 +24,9 @@ import scala.collection.JavaConverters._ import scala.concurrent.Promise import scala.util.control.NonFatal -import org.apache.spark.SparkException +import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_ID import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ @@ -194,12 +195,22 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte endpoints.containsKey(name) } - /** Thread pool used for dispatching messages. */ - private val threadpool: ThreadPoolExecutor = { + private def getNumOfThreads(conf: SparkConf): Int = { val availableCores = if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() - val numThreads = nettyEnv.conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) + + val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) .getOrElse(math.max(2, availableCores)) + + conf.get(EXECUTOR_ID).map { id => + val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) + }.getOrElse(modNumThreads) + } + + /** Thread pool used for dispatching messages. */ + private val threadpool: ThreadPoolExecutor = { + val numThreads = getNumOfThreads(nettyEnv.conf) val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") for (i <- 0 until numThreads) { pool.execute(new MessageLoop) @@ -225,7 +236,15 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte } } } catch { - case ie: InterruptedException => // exit + case _: InterruptedException => // exit + case t: Throwable => + try { + // Re-submit a MessageLoop so that Dispatcher will still work if + // UncaughtExceptionHandler decides to not kill JVM. 
+ threadpool.execute(new MessageLoop) + } finally { + throw t + } } } } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index d32eba64e13e9..44d2622a42f58 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -106,7 +106,7 @@ private[netty] class Inbox( throw new SparkException(s"Unsupported message $message from ${_sender}") }) } catch { - case NonFatal(e) => + case e: Throwable => context.sendFailure(e) // Throw the exception -- this exception will be caught by the safelyCall function. // The endpoint's onError function will be called. diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index 472db45490e95..5dce43b7523d9 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -29,8 +29,9 @@ import scala.reflect.ClassTag import scala.util.{DynamicVariable, Failure, Success, Try} import scala.util.control.NonFatal -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_ID import org.apache.spark.internal.config.Network._ import org.apache.spark.network.TransportContext import org.apache.spark.network.client._ @@ -47,11 +48,15 @@ private[netty] class NettyRpcEnv( host: String, securityManager: SecurityManager, numUsableCores: Int) extends RpcEnv(conf) with Logging { + val role = conf.get(EXECUTOR_ID).map { id => + if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + } private[netty] val transportConf = SparkTransportConf.fromSparkConf( conf.clone.set(RPC_IO_NUM_CONNECTIONS_PER_PEER, 1), "rpc", - conf.get(RPC_IO_THREADS).getOrElse(numUsableCores)) + conf.get(RPC_IO_THREADS).getOrElse(numUsableCores), + role) private val dispatcher: Dispatcher = new Dispatcher(this, numUsableCores) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index d967d38c52631..de57807639f3b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -24,10 +24,8 @@ import java.util.concurrent.atomic.AtomicInteger import scala.annotation.tailrec import scala.collection.Map -import scala.collection.mutable.{ArrayStack, HashMap, HashSet} +import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import scala.concurrent.duration._ -import scala.language.existentials -import scala.language.postfixOps import scala.util.control.NonFatal import org.apache.commons.lang3.SerializationUtils @@ -270,7 +268,7 @@ private[spark] class DAGScheduler( listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, accumUpdates, Some(executorUpdates))) blockManagerMaster.driverEndpoint.askSync[Boolean]( - BlockManagerHeartbeat(blockManagerId), new RpcTimeout(600 seconds, "BlockManagerHeartbeat")) + BlockManagerHeartbeat(blockManagerId), new RpcTimeout(10.minutes, "BlockManagerHeartbeat")) } /** @@ -383,7 +381,8 @@ private[spark] class DAGScheduler( * locations that are still available from the previous shuffle to avoid unnecessarily * regenerating data. 
*/ - def createShuffleMapStage(shuffleDep: ShuffleDependency[_, _, _], jobId: Int): ShuffleMapStage = { + def createShuffleMapStage[K, V, C]( + shuffleDep: ShuffleDependency[K, V, C], jobId: Int): ShuffleMapStage = { val rdd = shuffleDep.rdd checkBarrierStageWithDynamicAllocation(rdd) checkBarrierStageWithNumSlots(rdd) @@ -469,21 +468,21 @@ private[spark] class DAGScheduler( /** Find ancestor shuffle dependencies that are not registered in shuffleToMapStage yet */ private def getMissingAncestorShuffleDependencies( - rdd: RDD[_]): ArrayStack[ShuffleDependency[_, _, _]] = { - val ancestors = new ArrayStack[ShuffleDependency[_, _, _]] + rdd: RDD[_]): ListBuffer[ShuffleDependency[_, _, _]] = { + val ancestors = new ListBuffer[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new ArrayStack[RDD[_]] - waitingForVisit.push(rdd) + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd while (waitingForVisit.nonEmpty) { - val toVisit = waitingForVisit.pop() + val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { visited += toVisit getShuffleDependencies(toVisit).foreach { shuffleDep => if (!shuffleIdToMapStage.contains(shuffleDep.shuffleId)) { - ancestors.push(shuffleDep) - waitingForVisit.push(shuffleDep.rdd) + ancestors.prepend(shuffleDep) + waitingForVisit.prepend(shuffleDep.rdd) } // Otherwise, the dependency and its ancestors have already been registered. } } @@ -507,17 +506,17 @@ private[spark] class DAGScheduler( rdd: RDD[_]): HashSet[ShuffleDependency[_, _, _]] = { val parents = new HashSet[ShuffleDependency[_, _, _]] val visited = new HashSet[RDD[_]] - val waitingForVisit = new ArrayStack[RDD[_]] - waitingForVisit.push(rdd) + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd while (waitingForVisit.nonEmpty) { - val toVisit = waitingForVisit.pop() + val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { visited += toVisit toVisit.dependencies.foreach { case shuffleDep: ShuffleDependency[_, _, _] => parents += shuffleDep case dependency => - waitingForVisit.push(dependency.rdd) + waitingForVisit.prepend(dependency.rdd) } } } @@ -530,10 +529,10 @@ private[spark] class DAGScheduler( */ private def traverseParentRDDsWithinStage(rdd: RDD[_], predicate: RDD[_] => Boolean): Boolean = { val visited = new HashSet[RDD[_]] - val waitingForVisit = new ArrayStack[RDD[_]] - waitingForVisit.push(rdd) + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += rdd while (waitingForVisit.nonEmpty) { - val toVisit = waitingForVisit.pop() + val toVisit = waitingForVisit.remove(0) if (!visited(toVisit)) { if (!predicate(toVisit)) { return false @@ -543,7 +542,7 @@ private[spark] class DAGScheduler( case _: ShuffleDependency[_, _, _] => // Not within the same stage with current rdd, do nothing. 
case dependency => - waitingForVisit.push(dependency.rdd) + waitingForVisit.prepend(dependency.rdd) } } } @@ -555,7 +554,8 @@ private[spark] class DAGScheduler( val visited = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new ArrayStack[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd def visit(rdd: RDD[_]) { if (!visited(rdd)) { visited += rdd @@ -569,15 +569,14 @@ private[spark] class DAGScheduler( missing += mapStage } case narrowDep: NarrowDependency[_] => - waitingForVisit.push(narrowDep.rdd) + waitingForVisit.prepend(narrowDep.rdd) } } } } } - waitingForVisit.push(stage.rdd) while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.pop()) + visit(waitingForVisit.remove(0)) } missing.toList } @@ -1390,6 +1389,14 @@ private[spark] class DAGScheduler( event.reason match { case Success => + // An earlier attempt of a stage (which is zombie) may still have running tasks. If these + // tasks complete, they still count and we can mark the corresponding partitions as + // finished. Here we notify the task scheduler to skip running tasks for the same partition, + // to save resource. + if (task.stageAttemptId < stage.latestInfo.attemptNumber()) { + taskScheduler.notifyPartitionCompletion(stageId, task.partitionId) + } + task match { case rt: ResultTask[_, _] => // Cast to ResultStage here because it's part of the ResultTask @@ -1993,7 +2000,8 @@ private[spark] class DAGScheduler( val visitedRdds = new HashSet[RDD[_]] // We are manually maintaining a stack here to prevent StackOverflowError // caused by recursively visiting - val waitingForVisit = new ArrayStack[RDD[_]] + val waitingForVisit = new ListBuffer[RDD[_]] + waitingForVisit += stage.rdd def visit(rdd: RDD[_]) { if (!visitedRdds(rdd)) { visitedRdds += rdd @@ -2002,17 +2010,16 @@ private[spark] class DAGScheduler( case shufDep: ShuffleDependency[_, _, _] => val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) if (!mapStage.isAvailable) { - waitingForVisit.push(mapStage.rdd) + waitingForVisit.prepend(mapStage.rdd) } // Otherwise there's no need to follow the dependency back case narrowDep: NarrowDependency[_] => - waitingForVisit.push(narrowDep.rdd) + waitingForVisit.prepend(narrowDep.rdd) } } } } - waitingForVisit.push(stage.rdd) while (waitingForVisit.nonEmpty) { - visit(waitingForVisit.pop()) + visit(waitingForVisit.remove(0)) } visitedRdds.contains(target.rdd) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 54ab8f8b3e1d8..b514c2e7056f4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -19,8 +19,6 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.language.existentials - import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.util.{AccumulatorV2, CallSite} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 189e35ee83119..710f5eb211dde 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -21,13 +21,10 @@ import java.lang.management.ManagementFactory import java.nio.ByteBuffer import 
java.util.Properties -import scala.language.existentials - import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.shuffle.ShuffleWriter /** * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner @@ -85,13 +82,15 @@ private[spark] class ShuffleMapTask( threadMXBean.getCurrentThreadCpuTime } else 0L val ser = SparkEnv.get.closureSerializer.newInstance() - val (rdd, dep) = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( + val rddAndDep = ser.deserialize[(RDD[_], ShuffleDependency[_, _, _])]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTimeNs = System.nanoTime() - deserializeStartTimeNs _executorDeserializeCpuTime = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime - deserializeStartCpuTime } else 0L + val rdd = rddAndDep._1 + val dep = rddAndDep._2 dep.shuffleWriterProcessor.write(rdd, dep, partitionId, context, partition) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index c6dedaaa9554a..9b7f901c55e00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -155,6 +155,15 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul } } + // This method calls `TaskSchedulerImpl.handlePartitionCompleted` asynchronously. We do not want + // DAGScheduler to call `TaskSchedulerImpl.handlePartitionCompleted` directly, as it's + // synchronized and may hurt the throughput of the scheduler. + def enqueuePartitionCompletionNotification(stageId: Int, partitionId: Int): Unit = { + getTaskResultExecutor.execute(() => Utils.logUncaughtExceptions { + scheduler.handlePartitionCompleted(stageId, partitionId) + }) + } + def stop() { getTaskResultExecutor.shutdownNow() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 94221eb0d5515..bfdbf0217210a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -68,6 +68,10 @@ private[spark] trait TaskScheduler { // Throw UnsupportedOperationException if the backend doesn't support kill tasks. def killAllTaskAttempts(stageId: Int, interruptThread: Boolean, reason: String): Unit + // Notify the corresponding `TaskSetManager`s of the stage, that a partition has already completed + // and they can skip running tasks for it. + def notifyPartitionCompletion(stageId: Int, partitionId: Int) + // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. 
def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index e401c395a0486..532eb322769aa 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -22,7 +22,7 @@ import java.util.{Locale, Timer, TimerTask} import java.util.concurrent.{ConcurrentHashMap, TimeUnit} import java.util.concurrent.atomic.AtomicLong -import scala.collection.mutable.{ArrayBuffer, BitSet, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.util.Random import org.apache.spark._ @@ -101,9 +101,6 @@ private[spark] class TaskSchedulerImpl( // Protected by `this` val taskIdToExecutorId = new HashMap[Long, String] - // Protected by `this` - private[scheduler] val stageIdToFinishedPartitions = new HashMap[Int, BitSet] - @volatile private var hasReceivedTask = false @volatile private var hasLaunchedTask = false private val starvationTimer = new Timer(true) @@ -252,20 +249,7 @@ private[spark] class TaskSchedulerImpl( private[scheduler] def createTaskSetManager( taskSet: TaskSet, maxTaskFailures: Int): TaskSetManager = { - // only create a BitSet once for a certain stage since we only remove - // that stage when an active TaskSetManager succeed. - stageIdToFinishedPartitions.getOrElseUpdate(taskSet.stageId, new BitSet) - val tsm = new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt) - // TaskSet got submitted by DAGScheduler may have some already completed - // tasks since DAGScheduler does not always know all the tasks that have - // been completed by other tasksets when completing a stage, so we mark - // those tasks as finished here to avoid launching duplicate tasks, while - // holding the TaskSchedulerImpl lock. - // See SPARK-25250 and `markPartitionCompletedInAllTaskSets()` - stageIdToFinishedPartitions.get(taskSet.stageId).foreach { - finishedPartitions => finishedPartitions.foreach(tsm.markPartitionCompleted(_, None)) - } - tsm + new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt) } override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized { @@ -317,6 +301,10 @@ private[spark] class TaskSchedulerImpl( } } + override def notifyPartitionCompletion(stageId: Int, partitionId: Int): Unit = { + taskResultGetter.enqueuePartitionCompletionNotification(stageId, partitionId) + } + /** * Called to indicate that all task attempts (including speculated tasks) associated with the * given TaskSetManager have completed, so state associated with the TaskSetManager should be @@ -653,6 +641,21 @@ private[spark] class TaskSchedulerImpl( } } + /** + * Marks the task has completed in the active TaskSetManager for the given stage. + * + * After stage failure and retry, there may be multiple TaskSetManagers for the stage. + * If an earlier zombie attempt of a stage completes a task, we can ask the later active attempt + * to skip submitting and running the task for the same partition, to save resource. That also + * means that a task completion from an earlier zombie attempt can lead to the entire stage + * getting marked as successful. 
+ */ + private[scheduler] def handlePartitionCompleted(stageId: Int, partitionId: Int) = synchronized { + taskSetsByStageIdAndAttempt.get(stageId).foreach(_.values.filter(!_.isZombie).foreach { tsm => + tsm.markPartitionCompleted(partitionId) + }) + } + def error(message: String) { synchronized { if (taskSetsByStageIdAndAttempt.nonEmpty) { @@ -884,36 +887,6 @@ private[spark] class TaskSchedulerImpl( manager } } - - /** - * Marks the task has completed in all TaskSetManagers(active / zombie) for the given stage. - * - * After stage failure and retry, there may be multiple TaskSetManagers for the stage. - * If an earlier attempt of a stage completes a task, we should ensure that the later attempts - * do not also submit those same tasks. That also means that a task completion from an earlier - * attempt can lead to the entire stage getting marked as successful. - * And there is also the possibility that the DAGScheduler submits another taskset at the same - * time as we're marking a task completed here -- that taskset would have a task for a partition - * that was already completed. We maintain the set of finished partitions in - * stageIdToFinishedPartitions, protected by this, so we can detect those tasks when the taskset - * is submitted. See SPARK-25250 for more details. - * - * note: this method must be called with a lock on this. - */ - private[scheduler] def markPartitionCompletedInAllTaskSets( - stageId: Int, - partitionId: Int, - taskInfo: TaskInfo) = { - // if we do not find a BitSet for this stage, which means an active TaskSetManager - // has already succeeded and removed the stage. - stageIdToFinishedPartitions.get(stageId).foreach{ - finishedPartitions => finishedPartitions += partitionId - } - taskSetsByStageIdAndAttempt.getOrElse(stageId, Map()).values.foreach { tsm => - tsm.markPartitionCompleted(partitionId, Some(taskInfo)) - } - } - } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 144422022c22f..52323b3331d7e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -21,7 +21,7 @@ import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue -import scala.collection.mutable.{ArrayBuffer, BitSet, HashMap, HashSet} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.max import scala.util.control.NonFatal @@ -62,14 +62,8 @@ private[spark] class TaskSetManager( private val addedJars = HashMap[String, Long](sched.sc.addedJars.toSeq: _*) private val addedFiles = HashMap[String, Long](sched.sc.addedFiles.toSeq: _*) - // Quantile of tasks at which to start speculation - val speculationQuantile = conf.get(SPECULATION_QUANTILE) - val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER) - val maxResultSize = conf.get(config.MAX_RESULT_SIZE) - val speculationEnabled = conf.get(SPECULATION_ENABLED) - // Serializer for closures and tasks. 
val env = SparkEnv.get val ser = env.closureSerializer.newInstance() @@ -80,6 +74,12 @@ private[spark] class TaskSetManager( val numTasks = tasks.length val copiesRunning = new Array[Int](numTasks) + val speculationEnabled = conf.get(SPECULATION_ENABLED) + // Quantile of tasks at which to start speculation + val speculationQuantile = conf.get(SPECULATION_QUANTILE) + val speculationMultiplier = conf.get(SPECULATION_MULTIPLIER) + val minFinishedForSpeculation = math.max((speculationQuantile * numTasks).floor.toInt, 1) + // For each task, tracks whether a copy of the task has succeeded. A task will also be // marked as "succeeded" if it failed with a fetch failure, in which case it should not // be re-run because the missing map data needs to be regenerated first. @@ -800,19 +800,12 @@ private[spark] class TaskSetManager( // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { - // clean up finished partitions for the stage when the active TaskSetManager succeed - if (!isZombie) { - sched.stageIdToFinishedPartitions -= stageId - isZombie = true - } + isZombie = true } } else { logInfo("Ignoring task-finished event for " + info.id + " in stage " + taskSet.id + " because task " + index + " has already completed successfully") } - // There may be multiple tasksets for this stage -- we let all of them know that the partition - // was completed. This may result in some of the tasksets getting completed. - sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info) // This method is called by "TaskSchedulerImpl.handleSuccessfulTask" which holds the // "TaskSchedulerImpl" lock until exiting. To avoid the SPARK-7655 issue, we should not // "deserialize" the value when holding a lock to avoid blocking other threads. So we call @@ -823,21 +816,13 @@ private[spark] class TaskSetManager( maybeFinishTaskSet() } - private[scheduler] def markPartitionCompleted( - partitionId: Int, - taskInfo: Option[TaskInfo]): Unit = { + private[scheduler] def markPartitionCompleted(partitionId: Int): Unit = { partitionToIndex.get(partitionId).foreach { index => if (!successful(index)) { - if (speculationEnabled && !isZombie) { - taskInfo.foreach { info => successfulTaskDurations.insert(info.duration) } - } tasksSuccessful += 1 successful(index) = true if (tasksSuccessful == numTasks) { - if (!isZombie) { - sched.stageIdToFinishedPartitions -= stageId - isZombie = true - } + isZombie = true } maybeFinishTaskSet() } @@ -1047,10 +1032,13 @@ private[spark] class TaskSetManager( return false } var foundTasks = false - val minFinishedForSpeculation = (speculationQuantile * numTasks).floor.toInt logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation) - if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) { + // It's possible that a task is marked as completed by the scheduler, then the size of + // `successfulTaskDurations` may not equal to `tasksSuccessful`. Here we should only count the + // tasks that are submitted by this `TaskSetManager` and are completed successfully. 
+ val numSuccessfulTasks = successfulTaskDurations.size() + if (numSuccessfulTasks >= minFinishedForSpeculation) { val time = clock.getTimeMillis() val medianDuration = successfulTaskDurations.median val threshold = max(speculationMultiplier * medianDuration, minTimeToSpeculation) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index afb48a31754f9..89425e702677a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer +import org.apache.spark.ResourceInformation import org.apache.spark.TaskState.TaskState import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler.ExecutorLossReason @@ -64,7 +65,8 @@ private[spark] object CoarseGrainedClusterMessages { hostname: String, cores: Int, logUrls: Map[String, String], - attributes: Map[String, String]) + attributes: Map[String, String], + resources: Map[String, ResourceInformation]) extends CoarseGrainedClusterMessage case class StatusUpdate(executorId: String, taskId: Long, state: TaskState, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4830d0e6f8008..f7cf212d0bfe1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -185,7 +185,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls, attributes) => + case RegisterExecutor(executorId, executorRef, hostname, cores, logUrls, + attributes, resources) => if (executorDataMap.contains(executorId)) { executorRef.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) context.reply(true) diff --git a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala index c65c8fd6e3e1c..e616d239ce8d5 100644 --- a/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala +++ b/core/src/main/scala/org/apache/spark/security/SocketAuthServer.scala @@ -21,7 +21,6 @@ import java.net.{InetAddress, ServerSocket, Socket} import scala.concurrent.Promise import scala.concurrent.duration.Duration -import scala.language.existentials import scala.util.Try import org.apache.spark.SparkEnv diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index eef19973e8d77..39691069bf5f6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -164,8 +164,8 @@ class KryoSerializer(conf: SparkConf) } // Allow the user to register their own classes by setting spark.kryo.registrator. userRegistrators - .map(Utils.classForName(_, noSparkClassLoader = true).getConstructor(). 
- newInstance().asInstanceOf[KryoRegistrator]) + .map(Utils.classForName[KryoRegistrator](_, noSparkClassLoader = true). + getConstructor().newInstance()) .foreach { reg => reg.registerClasses(kryo) } } catch { case e: Exception => @@ -213,6 +213,10 @@ class KryoSerializer(conf: SparkConf) // We can't load those class directly in order to avoid unnecessary jar dependencies. // We load them safely, ignore it if the class not found. Seq( + "org.apache.spark.sql.catalyst.expressions.UnsafeRow", + "org.apache.spark.sql.catalyst.expressions.UnsafeArrayData", + "org.apache.spark.sql.catalyst.expressions.UnsafeMapData", + "org.apache.spark.ml.attribute.Attribute", "org.apache.spark.ml.attribute.AttributeGroup", "org.apache.spark.ml.attribute.BinaryAttribute", diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 84c2ad48f1f27..83f76db7e89da 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -77,7 +77,7 @@ private[spark] trait UIRoot { /** * Runs some code with the current SparkUI instance for the app / attempt. * - * @throws NoSuchElementException If the app / attempt pair does not exist. + * @throws java.util.NoSuchElementException If the app / attempt pair does not exist. */ def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T @@ -85,8 +85,8 @@ private[spark] trait UIRoot { def getApplicationInfo(appId: String): Option[ApplicationInfo] /** - * Write the event logs for the given app to the [[ZipOutputStream]] instance. If attemptId is - * [[None]], event logs for all attempts of this application will be written out. + * Write the event logs for the given app to the `ZipOutputStream` instance. If attemptId is + * `None`, event logs for all attempts of this application will be written out. 
*/ def writeEventLogs(appId: String, attemptId: Option[String], zipStream: ZipOutputStream): Unit = { Response.serverError() diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala index 9d1d66a0e15a4..db53a400ed62f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/StagesResource.scala @@ -24,8 +24,9 @@ import org.apache.spark.SparkException import org.apache.spark.scheduler.StageInfo import org.apache.spark.status.api.v1.StageStatus._ import org.apache.spark.status.api.v1.TaskSorting._ -import org.apache.spark.ui.SparkUI +import org.apache.spark.ui.{SparkUI, UIUtils} import org.apache.spark.ui.jobs.ApiHelper._ +import org.apache.spark.util.Utils @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class StagesResource extends BaseAppResource { @@ -189,32 +190,42 @@ private[v1] class StagesResource extends BaseAppResource { val taskMetricsContainsValue = (task: TaskData) => task.taskMetrics match { case None => false case Some(metrics) => - (containsValue(task.taskMetrics.get.executorDeserializeTime) - || containsValue(task.taskMetrics.get.executorRunTime) - || containsValue(task.taskMetrics.get.jvmGcTime) - || containsValue(task.taskMetrics.get.resultSerializationTime) - || containsValue(task.taskMetrics.get.memoryBytesSpilled) - || containsValue(task.taskMetrics.get.diskBytesSpilled) - || containsValue(task.taskMetrics.get.peakExecutionMemory) - || containsValue(task.taskMetrics.get.inputMetrics.bytesRead) + (containsValue(UIUtils.formatDuration(task.taskMetrics.get.executorDeserializeTime)) + || containsValue(UIUtils.formatDuration(task.taskMetrics.get.executorRunTime)) + || containsValue(UIUtils.formatDuration(task.taskMetrics.get.jvmGcTime)) + || containsValue(UIUtils.formatDuration(task.taskMetrics.get.resultSerializationTime)) + || containsValue(Utils.bytesToString(task.taskMetrics.get.memoryBytesSpilled)) + || containsValue(Utils.bytesToString(task.taskMetrics.get.diskBytesSpilled)) + || containsValue(Utils.bytesToString(task.taskMetrics.get.peakExecutionMemory)) + || containsValue(Utils.bytesToString(task.taskMetrics.get.inputMetrics.bytesRead)) || containsValue(task.taskMetrics.get.inputMetrics.recordsRead) - || containsValue(task.taskMetrics.get.outputMetrics.bytesWritten) + || containsValue(Utils.bytesToString( + task.taskMetrics.get.outputMetrics.bytesWritten)) || containsValue(task.taskMetrics.get.outputMetrics.recordsWritten) - || containsValue(task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime) + || containsValue(UIUtils.formatDuration( + task.taskMetrics.get.shuffleReadMetrics.fetchWaitTime)) + || containsValue(Utils.bytesToString( + task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead)) + || containsValue(Utils.bytesToString( + task.taskMetrics.get.shuffleReadMetrics.localBytesRead + + task.taskMetrics.get.shuffleReadMetrics.remoteBytesRead)) || containsValue(task.taskMetrics.get.shuffleReadMetrics.recordsRead) - || containsValue(task.taskMetrics.get.shuffleWriteMetrics.bytesWritten) + || containsValue(Utils.bytesToString( + task.taskMetrics.get.shuffleWriteMetrics.bytesWritten)) || containsValue(task.taskMetrics.get.shuffleWriteMetrics.recordsWritten) - || containsValue(task.taskMetrics.get.shuffleWriteMetrics.writeTime)) + || containsValue(UIUtils.formatDuration( + task.taskMetrics.get.shuffleWriteMetrics.writeTime / 1000000))) } val filteredTaskDataSequence: 
Seq[TaskData] = taskDataList.filter(f => (containsValue(f.taskId) || containsValue(f.index) || containsValue(f.attempt) - || containsValue(f.launchTime) + || containsValue(UIUtils.formatDate(f.launchTime)) || containsValue(f.resultFetchStart.getOrElse(defaultOptionString)) || containsValue(f.executorId) || containsValue(f.host) || containsValue(f.status) || containsValue(f.taskLocality) || containsValue(f.speculative) || containsValue(f.errorMessage.getOrElse(defaultOptionString)) || taskMetricsContainsValue(f) - || containsValue(f.schedulerDelay) || containsValue(f.gettingResultTime))) + || containsValue(UIUtils.formatDuration(f.schedulerDelay)) + || containsValue(UIUtils.formatDuration(f.gettingResultTime)))) filteredTaskDataSequence } diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala index fc9b50f14a083..1c0dd7dee2228 100644 --- a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala @@ -53,9 +53,24 @@ private class HttpSecurityFilter( val hres = res.asInstanceOf[HttpServletResponse] hres.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") - if (!securityMgr.checkUIViewPermissions(hreq.getRemoteUser())) { + val requestUser = hreq.getRemoteUser() + + // The doAs parameter allows proxy servers (e.g. Knox) to impersonate other users. For + // that to be allowed, the authenticated user needs to be an admin. + val effectiveUser = Option(hreq.getParameter("doAs")) + .map { proxy => + if (requestUser != proxy && !securityMgr.checkAdminPermissions(requestUser)) { + hres.sendError(HttpServletResponse.SC_FORBIDDEN, + s"User $requestUser is not allowed to impersonate others.") + return + } + proxy + } + .getOrElse(requestUser) + + if (!securityMgr.checkUIViewPermissions(effectiveUser)) { hres.sendError(HttpServletResponse.SC_FORBIDDEN, - "User is not authorized to access this page.") + s"User $effectiveUser is not authorized to access this page.") return } @@ -77,12 +92,13 @@ private class HttpSecurityFilter( hres.setHeader("Strict-Transport-Security", _)) } - chain.doFilter(new XssSafeRequest(hreq), res) + chain.doFilter(new XssSafeRequest(hreq, effectiveUser), res) } } -private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequestWrapper(req) { +private class XssSafeRequest(req: HttpServletRequest, effectiveUser: String) + extends HttpServletRequestWrapper(req) { private val NEWLINE_AND_SINGLE_QUOTE_REGEX = raw"(?i)(\r\n|\n|\r|%0D%0A|%0A|%0D|'|%27)".r @@ -92,6 +108,8 @@ private class XssSafeRequest(req: HttpServletRequest) extends HttpServletRequest }.toMap } + override def getRemoteUser(): String = effectiveUser + override def getParameterMap(): JMap[String, Array[String]] = parameterMap.asJava override def getParameterNames(): Enumeration[String] = { diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index d7bda8b80c24f..11647c0c7c623 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -109,12 +109,12 @@ private[spark] object UIUtils extends Logging { } } // if time is more than a year - return s"$yearString $weekString $dayString" + s"$yearString $weekString $dayString" } catch { case e: Exception => logError("Error converting time to string", e) // if there is some error, return blank string - return "" + "" } } @@ -336,7 +336,7 @@ 
private[spark] object UIUtils extends Logging { def getHeaderContent(header: String): Seq[Node] = { if (newlinesInHeader) {
-          { header.split("\n").map { case t =>
-            <li> {t} </li>
-          } }
+          { header.split("\n").map(t =>
+            <li> {t} </li>
+          ) }
} else { Text(header) @@ -446,7 +446,7 @@ private[spark] object UIUtils extends Logging { * the whole string will rendered as a simple escaped text. * * Note: In terms of security, only anchor tags with root relative links are supported. So any - * attempts to embed links outside Spark UI, or other tags like {@code