diff --git a/LICENSE b/LICENSE
index 0a42d389e4c3c..9b364a4d00079 100644
--- a/LICENSE
+++ b/LICENSE
@@ -771,6 +771,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java):
+========================================================================
+Copyright (C) 2015 Stijn de Gouw
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
========================================================================
For LimitedInputStream
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
index 409e1a41c5d49..a90cc0e761f62 100644
--- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) {
private void mergeCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
- if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
+ if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
+ || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
if (runLen[n - 1] < runLen[n + 1])
n--;
- mergeAt(n);
- } else if (runLen[n] <= runLen[n + 1]) {
- mergeAt(n);
- } else {
+ } else if (runLen[n] > runLen[n + 1]) {
break; // Invariant is established
}
+ mergeAt(n);
}
}
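
The corrected `mergeCollapse` above also inspects `runLen[n-2]`, so the run-length stack invariant is re-established for every element rather than only the top of the stack. As a hedged illustration (plain Scala, not part of this patch; the object name and sample stacks are invented), the invariant the method is meant to maintain can be checked like this:

```scala
// Sketch of the TimSort run-stack invariant: for every i,
//   runLen(i) > runLen(i + 1)  and  runLen(i) > runLen(i + 1) + runLen(i + 2).
object TimSortInvariantSketch {
  def holds(runLen: Seq[Int]): Boolean = {
    val pairsOk = runLen.sliding(2).forall {
      case Seq(a, b) => a > b
      case _         => true
    }
    val triplesOk = runLen.sliding(3).forall {
      case Seq(a, b, c) => a > b + c
      case _            => true
    }
    pairsOk && triplesOk
  }

  def main(args: Array[String]): Unit = {
    // The top of the stack looks fine (50 > 25 + 20), but the invariant is already
    // broken one level deeper (60 <= 50 + 25) -- the case the old check missed.
    println(holds(Vector(100, 60, 50, 25, 20))) // false
    println(holds(Vector(100, 50, 30, 18, 10))) // true
  }
}
```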
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 30f0ccd73ccca..bcf832467f00b 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -280,15 +280,24 @@ object AccumulatorParam {
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
-private[spark] object Accumulators {
- // Store a WeakReference instead of a StrongReference because this way accumulators can be
- // appropriately garbage collected during long-running jobs and release memory
- type WeakAcc = WeakReference[Accumulable[_, _]]
- val originals = Map[Long, WeakAcc]()
- val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() {
- override protected def initialValue() = Map[Long, WeakAcc]()
+private[spark] object Accumulators extends Logging {
+ /**
+ * This global map holds the original accumulator objects that are created on the driver.
+ * It keeps weak references to these objects so that accumulators can be garbage-collected
+ * once the RDDs and user-code that reference them are cleaned up.
+ */
+ val originals = Map[Long, WeakReference[Accumulable[_, _]]]()
+
+ /**
+ * This thread-local map holds per-task copies of accumulators; it is used to collect the set
+ * of accumulator updates to send back to the driver when tasks complete. After tasks complete,
+ * this map is cleared by `Accumulators.clear()` (see Executor.scala).
+ */
+ private val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+ override protected def initialValue() = Map[Long, Accumulable[_, _]]()
}
- var lastId: Long = 0
+
+ private var lastId: Long = 0
def newId(): Long = synchronized {
lastId += 1
@@ -297,16 +306,16 @@ private[spark] object Accumulators {
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
- originals(a.id) = new WeakAcc(a)
+ originals(a.id) = new WeakReference[Accumulable[_, _]](a)
} else {
- localAccums.get()(a.id) = new WeakAcc(a)
+ localAccums.get()(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
- localAccums.get.clear
+ localAccums.get.clear()
}
}
@@ -320,12 +329,7 @@ private[spark] object Accumulators {
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
- // Since we are now storing weak references, we must check whether the underlying data
- // is valid.
- ret(id) = accum.get match {
- case Some(values) => values.localValue
- case None => None
- }
+ ret(id) = accum.localValue
}
return ret
}
@@ -341,6 +345,8 @@ private[spark] object Accumulators {
case None =>
throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
}
+ } else {
+ logWarning(s"Ignoring accumulator update for unknown accumulator id $id")
}
}
}
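
With this split, only the driver-side `originals` map holds weak references; the per-task copies in `localAccums` stay strongly referenced until `Accumulators.clear()` runs. A minimal sketch of the weak-reference lookup pattern the driver side relies on (illustrative Scala, not Spark's API; the registry name is invented):

```scala
import scala.collection.mutable
import scala.ref.WeakReference

// A value stored here can be garbage-collected once no strong reference to it
// remains elsewhere, which mirrors how `originals` behaves for accumulators.
object WeakValueRegistry {
  private val entries = mutable.Map[Long, WeakReference[AnyRef]]()

  def put(id: Long, value: AnyRef): Unit = entries(id) = WeakReference(value)

  // None means the referent was collected -- the situation that triggers the
  // "garbage collected Accumulator" error in the code above.
  def get(id: Long): Option[AnyRef] = entries.get(id).flatMap(_.get)
}
```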
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 83ae57b7f1516..69178da1a7773 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -17,33 +17,86 @@
package org.apache.spark
-import akka.actor.Actor
+import scala.concurrent.duration._
+import scala.collection.mutable
+
+import akka.actor.{Actor, Cancellable}
+
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
import org.apache.spark.util.ActorLogReceive
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
- * components to convey liveness or execution information for in-progress tasks.
+ * components to convey liveness or execution information for in-progress tasks. It will also
+ * expire hosts that have not sent a heartbeat for longer than spark.network.timeout.
*/
private[spark] case class Heartbeat(
executorId: String,
taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
blockManagerId: BlockManagerId)
+private[spark] case object ExpireDeadHosts
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
* Lives in the driver to receive heartbeats from executors..
*/
-private[spark] class HeartbeatReceiver(scheduler: TaskScheduler)
+private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler)
extends Actor with ActorLogReceive with Logging {
+ // executor ID -> timestamp of when the last heartbeat from this executor was received
+ private val executorLastSeen = new mutable.HashMap[String, Long]
+
+ private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout",
+ sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000
+
+ private val checkTimeoutIntervalMs = sc.conf.getLong("spark.network.timeoutInterval",
+ sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000
+
+ private var timeoutCheckingTask: Cancellable = null
+
+ override def preStart(): Unit = {
+ import context.dispatcher
+ timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
+ checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
+ super.preStart()
+ }
+
override def receiveWithLogging = {
case Heartbeat(executorId, taskMetrics, blockManagerId) =>
- val response = HeartbeatResponse(
- !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId))
+ val unknownExecutor = !scheduler.executorHeartbeatReceived(
+ executorId, taskMetrics, blockManagerId)
+ val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+ executorLastSeen(executorId) = System.currentTimeMillis()
sender ! response
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+ }
+
+ private def expireDeadHosts(): Unit = {
+ logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
+ val now = System.currentTimeMillis()
+ for ((executorId, lastSeenMs) <- executorLastSeen) {
+ if (now - lastSeenMs > executorTimeoutMs) {
+ logWarning(s"Removing executor $executorId with no recent heartbeats: " +
+ s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms")
+        scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
+          s"timed out after ${now - lastSeenMs} ms"))
+ if (sc.supportDynamicAllocation) {
+ sc.killExecutor(executorId)
+ }
+ executorLastSeen.remove(executorId)
+ }
+ }
+ }
+
+ override def postStop(): Unit = {
+ if (timeoutCheckingTask != null) {
+ timeoutCheckingTask.cancel()
+ }
+ super.postStop()
}
}
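
For reference, a hedged usage sketch of the expiry settings the new `HeartbeatReceiver` reads (values are illustrative; both `spark.network.*` keys are interpreted as seconds and multiplied by 1000, with the older `spark.storage.*` keys kept as fallbacks):

```scala
import org.apache.spark.SparkConf

object HeartbeatTimeoutSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(loadDefaults = false)
      .set("spark.network.timeout", "120")         // expire executors after 120 s of silence
      .set("spark.network.timeoutInterval", "60")  // run ExpireDeadHosts every 60 s
    // Same arithmetic as HeartbeatReceiver: seconds to milliseconds.
    val timeoutMs = conf.getLong("spark.network.timeout", 120) * 1000
    println(s"Executors are marked lost after $timeoutMs ms without a heartbeat")
  }
}
```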
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 0f4922ab4e310..61b34d524a421 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -407,7 +407,7 @@ private[spark] object SparkConf extends Logging {
* @param warn Whether to print a warning if the key is deprecated. Warnings will be printed
* only once for each key.
*/
- def translateConfKey(userKey: String, warn: Boolean = false): String = {
+ private def translateConfKey(userKey: String, warn: Boolean = false): String = {
deprecatedConfigs.get(userKey)
.map { deprecatedKey =>
if (warn) {
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 930d4bea4785b..3cd0c218a36fd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -351,7 +351,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private[spark] var (schedulerBackend, taskScheduler) =
SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
- Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
+ Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
dagScheduler = new DAGScheduler(this)
@@ -398,7 +398,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false)
private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
if (dynamicAllocationEnabled) {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Dynamic allocation of executors is currently only supported in YARN mode")
Some(new ExecutorAllocationManager(this, listenerBus, conf))
} else {
@@ -1122,6 +1122,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
postEnvironmentUpdate()
}
+ /**
+ * Return whether dynamically adjusting the amount of resources allocated to
+ * this application is supported. This is currently only available for YARN.
+ */
+ private[spark] def supportDynamicAllocation =
+ master.contains("yarn") || dynamicAllocationTesting
+
/**
* :: DeveloperApi ::
* Register a listener to receive up-calls from events that happen during execution.
@@ -1155,7 +1162,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Requesting executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
@@ -1173,7 +1180,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
*/
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Killing executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
@@ -1382,17 +1389,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
dagScheduler.stop()
dagScheduler = null
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ env.actorSystem.stop(heartbeatReceiver)
progressBar.foreach(_.stop())
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
SparkContext.clearActiveContext()
} else {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 4c4110812e0a1..4a74641f4e1fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -655,8 +655,7 @@ private[spark] object SparkSubmitUtils {
/**
* Extracts maven coordinates from a comma-delimited string. Coordinates should be provided
- * in the format `groupId:artifactId:version` or `groupId/artifactId:version`. The latter provides
- * simplicity for Spark Package users.
+ * in the format `groupId:artifactId:version` or `groupId/artifactId:version`.
* @param coordinates Comma-delimited string of maven coordinates
* @return Sequence of Maven coordinates
*/
@@ -747,6 +746,35 @@ private[spark] object SparkSubmitUtils {
md.addDependency(dd)
}
}
+
+ /** Add exclusion rules for dependencies already included in the spark-assembly */
+ private[spark] def addExclusionRules(
+ ivySettings: IvySettings,
+ ivyConfName: String,
+ md: DefaultModuleDescriptor): Unit = {
+ // Add scala exclusion rule
+ val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*")
+ val scalaDependencyExcludeRule =
+ new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null)
+ scalaDependencyExcludeRule.addConfiguration(ivyConfName)
+ md.addExcludeRule(scalaDependencyExcludeRule)
+
+    // We need to specify each component explicitly, otherwise we would also exclude
+    // spark-streaming-kafka and other spark-streaming utility components. The trailing
+    // underscore differentiates spark-streaming_2.1x from spark-streaming-kafka-assembly_2.1x.
+ val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_",
+ "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_")
+
+ components.foreach { comp =>
+ val sparkArtifacts =
+ new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*")
+ val sparkDependencyExcludeRule =
+ new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
+ sparkDependencyExcludeRule.addConfiguration(ivyConfName)
+
+ md.addExcludeRule(sparkDependencyExcludeRule)
+ }
+ }
/** A nice function to use in tests as well. Values are dummy strings. */
private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
@@ -768,6 +796,9 @@ private[spark] object SparkSubmitUtils {
if (coordinates == null || coordinates.trim.isEmpty) {
""
} else {
+ val sysOut = System.out
+ // To prevent ivy from logging to system out
+ System.setOut(printStream)
val artifacts = extractMavenCoordinates(coordinates)
// Default configuration name for ivy
val ivyConfName = "default"
@@ -811,19 +842,9 @@ private[spark] object SparkSubmitUtils {
val md = getModuleDescriptor
md.setDefaultConf(ivyConfName)
- // Add an exclusion rule for Spark and Scala Library
- val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
- val sparkDependencyExcludeRule =
- new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
- sparkDependencyExcludeRule.addConfiguration(ivyConfName)
- val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*")
- val scalaDependencyExcludeRule =
- new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null)
- scalaDependencyExcludeRule.addConfiguration(ivyConfName)
-
- // Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies
- md.addExcludeRule(sparkDependencyExcludeRule)
- md.addExcludeRule(scalaDependencyExcludeRule)
+ // Add exclusion rules for Spark and Scala Library
+ addExclusionRules(ivySettings, ivyConfName, md)
+ // add all supplied maven artifacts as dependencies
addDependenciesToIvy(md, artifacts, ivyConfName)
// resolve dependencies
@@ -835,7 +856,7 @@ private[spark] object SparkSubmitUtils {
ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId,
packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]",
retrieveOptions.setConfs(Array(ivyConfName)))
-
+ System.setOut(sysOut)
resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
}
}
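
The new `addExclusionRules` builds one glob-based exclusion rule per assembly component. A hedged, stand-alone sketch of that pattern for a single component (`streaming_`, chosen only for illustration), using the same Ivy classes `SparkSubmitUtils` already imports:

```scala
import org.apache.ivy.core.module.descriptor.DefaultExcludeRule
import org.apache.ivy.core.module.id.{ArtifactId, ModuleId}
import org.apache.ivy.core.settings.IvySettings

object ExclusionRuleSketch {
  // Excludes every org.apache.spark artifact whose name matches "spark-<comp>*",
  // e.g. spark-streaming_2.10, while leaving spark-streaming-kafka-* resolvable.
  def excludeRuleFor(settings: IvySettings, ivyConfName: String, comp: String): DefaultExcludeRule = {
    val artifactId = new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*")
    val rule = new DefaultExcludeRule(artifactId, settings.getMatcher("glob"), null)
    rule.addConfiguration(ivyConfName)
    rule
  }
}
```

In the patch this is applied once per entry in `components`, plus a separate rule for `scala-library`.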
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 1aaa7b72735ab..3e3d6ff29faf0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -49,8 +49,8 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
// Interval between each check for event log updates
private val UPDATE_INTERVAL_MS = conf.getOption("spark.history.fs.update.interval.seconds")
- .orElse(conf.getOption(SparkConf.translateConfKey("spark.history.fs.updateInterval", true)))
- .orElse(conf.getOption(SparkConf.translateConfKey("spark.history.updateInterval", true)))
+ .orElse(conf.getOption("spark.history.fs.updateInterval"))
+ .orElse(conf.getOption("spark.history.updateInterval"))
.map(_.toInt)
.getOrElse(10) * 1000
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index fa9bfe5426b6c..af483d560b33e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -96,6 +96,10 @@ class HistoryServer(
}
}
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
initialize()
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 67e6c5d66af0e..f5b946329ae9b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -21,7 +21,7 @@ private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
- val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
+ val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 3aae2b95d7396..76fc40e17d9a8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -24,6 +24,7 @@ import scala.xml.Node
import akka.pattern.ask
import org.json4s.JValue
+import org.json4s.JsonAST.JNothing
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
@@ -44,7 +45,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
- JsonProtocol.writeApplicationInfo(app)
+ if (app == null) {
+ JNothing
+ } else {
+ JsonProtocol.writeApplicationInfo(app)
+ }
}
/** Executor details for a particular application */
@@ -55,6 +60,10 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
+ if (app == null) {
+      val msg = <div>No running application with ID {appId}</div>
+ return UIUtils.basicSparkPage(msg, "Not Found")
+ }
val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs")
val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index 9dd96493ee48d..c086cadca2c7d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -26,8 +26,8 @@ import akka.pattern.ask
import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
+import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.master._
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
@@ -41,6 +41,31 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
JsonProtocol.writeMasterState(state)
}
+ def handleAppKillRequest(request: HttpServletRequest): Unit = {
+ handleKillRequest(request, id => {
+ parent.master.idToApp.get(id).foreach { app =>
+ parent.master.removeApplication(app, ApplicationState.KILLED)
+ }
+ })
+ }
+
+ def handleDriverKillRequest(request: HttpServletRequest): Unit = {
+ handleKillRequest(request, id => { master ! RequestKillDriver(id) })
+ }
+
+ private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
+ if (parent.killEnabled &&
+ parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
+ val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
+ val id = Option(request.getParameter("id"))
+ if (id.isDefined && killFlag) {
+ action(id.get)
+ }
+
+ Thread.sleep(100)
+ }
+ }
+
/** Index view listing applications and executors */
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
@@ -167,9 +192,20 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = {
+ val killLink = if (parent.killEnabled &&
+ (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) {
+ val killLinkUri = s"app/kill?id=${app.id}&terminate=true"
+ val confirm = "return window.confirm(" +
+ s"'Are you sure you want to kill application ${app.id} ?');"
+      <a href={killLinkUri} onclick={confirm}>(kill)</a>
+ }
+
{app.id}
+ {killLink}
|
{app.desc.name}
@@ -182,7 +218,7 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
}
|
- {app.requestedCores}
+ {if (app.requestedCores == Int.MaxValue) "*" else app.requestedCores}
|
{Utils.megabytesToString(app.desc.memoryPerSlave)}
@@ -203,8 +239,19 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
private def driverRow(driver: DriverInfo): Seq[Node] = {
+ val killLink = if (parent.killEnabled &&
+ (driver.state == DriverState.RUNNING ||
+ driver.state == DriverState.SUBMITTED ||
+ driver.state == DriverState.RELAUNCHING)) {
+ val killLinkUri = s"driver/kill?id=${driver.id}&terminate=true"
+ val confirm = "return window.confirm(" +
+ s"'Are you sure you want to kill driver ${driver.id} ?');"
+      <a href={killLinkUri} onclick={confirm}>(kill)</a>
+ }
|
- {driver.id} |
+ {driver.id} {killLink} |
{driver.submitDate} |
{driver.worker.map(w => {w.id.toString}).getOrElse("None")}
|
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index 73400c5affb5d..170f90a00ad2a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -32,15 +32,21 @@ class MasterWebUI(val master: Master, requestedPort: Int)
val masterActorRef = master.self
val timeout = AkkaUtils.askTimeout(master.conf)
+ val killEnabled = master.conf.getBoolean("spark.ui.killEnabled", true)
initialize()
/** Initialize all components of the server. */
def initialize() {
+ val masterPage = new MasterPage(this)
attachPage(new ApplicationPage(this))
attachPage(new HistoryNotFoundPage(this))
- attachPage(new MasterPage(this))
+ attachPage(masterPage)
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
+ attachHandler(
+ createRedirectHandler("/app/kill", "/", masterPage.handleAppKillRequest))
+ attachHandler(
+ createRedirectHandler("/driver/kill", "/", masterPage.handleDriverKillRequest))
}
/** Attach a reconstructed UI to this Master UI. Only valid after bind(). */
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 955b42c3baaa1..6b4f097ea9ae5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -993,6 +993,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
try {
while (iter.hasNext) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index f095915352b17..ed3418676e077 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -73,5 +73,9 @@ private[spark] trait TaskScheduler {
* @return An application ID
*/
def applicationId(): String = appId
-
+
+ /**
+ * Process a lost executor
+ */
+ def executorLost(executorId: String, reason: ExecutorLossReason): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 54f8fcfc416d1..7a9cf1c2e7f30 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -436,7 +436,7 @@ private[spark] class TaskSchedulerImpl(
}
}
- def executorLost(executorId: String, reason: ExecutorLossReason) {
+ override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {
var failedExecutor: Option[String] = None
synchronized {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 64133464d8daa..787b0f96bec32 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConversions._
import scala.concurrent.Future
import scala.concurrent.duration._
-import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.actor.{Actor, ActorRef}
import akka.pattern.ask
import org.apache.spark.{Logging, SparkConf, SparkException}
@@ -52,19 +52,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000)
-
- val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)
-
- var timeoutCheckingTask: Cancellable = null
-
- override def preStart() {
- import context.dispatcher
- timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
- checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
- super.preStart()
- }
-
override def receiveWithLogging = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
register(blockManagerId, maxMemSize, slaveActor)
@@ -118,14 +105,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case StopBlockManagerMaster =>
sender ! true
- if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel()
- }
context.stop(self)
- case ExpireDeadHosts =>
- expireDeadHosts()
-
case BlockManagerHeartbeat(blockManagerId) =>
sender ! heartbeatReceived(blockManagerId)
@@ -207,21 +188,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
logInfo(s"Removing block manager $blockManagerId")
}
- private def expireDeadHosts() {
- logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
- val now = System.currentTimeMillis()
- val minSeenTime = now - slaveTimeout
- val toRemove = new mutable.HashSet[BlockManagerId]
- for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) {
- logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: "
- + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
- toRemove += info.blockManagerId
- }
- }
- toRemove.foreach(removeBlockManager)
- }
-
private def removeExecutor(execId: String) {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 3f32099d08cc9..48247453edef0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -109,6 +109,4 @@ private[spark] object BlockManagerMessages {
extends ToBlockManagerMaster
case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
-
- case object ExpireDeadHosts extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index bf4b24e98b134..95f254a9ef22a 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -80,6 +80,10 @@ private[spark] object JettyUtils extends Logging {
response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage)
}
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
}
@@ -119,6 +123,10 @@ private[spark] object JettyUtils extends Logging {
val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString
response.sendRedirect(newUrl)
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
createServletHandler(srcPath, servlet, basePath)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index d752434ad58ae..110f8780a9a12 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -626,15 +626,16 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
}
private def getSchedulerDelay(info: TaskInfo, metrics: TaskMetrics): Long = {
- val totalExecutionTime = {
- if (info.gettingResultTime > 0) {
- (info.gettingResultTime - info.launchTime)
+ val totalExecutionTime =
+ if (info.gettingResult) {
+ info.gettingResultTime - info.launchTime
+ } else if (info.finished) {
+ info.finishTime - info.launchTime
} else {
- (info.finishTime - info.launchTime)
+ 0
}
- }
val executorOverhead = (metrics.executorDeserializeTime +
metrics.resultSerializationTime)
- totalExecutionTime - metrics.executorRunTime - executorOverhead
+ math.max(0, totalExecutionTime - metrics.executorRunTime - executorOverhead)
}
}
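
The rewritten `getSchedulerDelay` clamps the result at zero because deserialization and result-serialization overhead can otherwise push the computed delay negative for very short tasks. A hedged arithmetic sketch with made-up millisecond values:

```scala
object SchedulerDelaySketch {
  // delay = max(0, total execution time - executor run time - (deserialize + result serialization))
  def schedulerDelay(totalMs: Long, runMs: Long, deserializeMs: Long, serializeMs: Long): Long =
    math.max(0, totalMs - runMs - (deserializeMs + serializeMs))

  def main(args: Array[String]): Unit = {
    println(schedulerDelay(1000, 800, 120, 50)) // 30 ms spent waiting on the scheduler
    println(schedulerDelay(1000, 900, 120, 50)) // would be -70; reported as 0 instead
  }
}
```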
diff --git a/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java
new file mode 100644
index 0000000000000..45772b6d3c20d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/util/collection/TestTimSort.java
@@ -0,0 +1,134 @@
+/**
+ * Copyright 2015 Stijn de Gouw
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+package org.apache.spark.util.collection;
+
+import java.util.*;
+
+/**
+ * This code generates an int array that fails the standard TimSort.
+ *
+ * The blog post that reported the bug:
+ * http://www.envisage-project.eu/timsort-specification-and-verification/
+ *
+ * This code was originally written by Stijn de Gouw and modified by Evan Yu to adapt to
+ * our test suite.
+ *
+ * https://github.com/abstools/java-timsort-bug
+ * https://github.com/abstools/java-timsort-bug/blob/master/LICENSE
+ */
+public class TestTimSort {
+
+ private static final int MIN_MERGE = 32;
+
+ /**
+ * Returns an array of integers that demonstrate the bug in TimSort
+ */
+ public static int[] getTimSortBugTestSet(int length) {
+ int minRun = minRunLength(length);
+    List<Long> runs = runsJDKWorstCase(minRun, length);
+ return createArray(runs, length);
+ }
+
+ private static int minRunLength(int n) {
+ int r = 0; // Becomes 1 if any 1 bits are shifted off
+ while (n >= MIN_MERGE) {
+ r |= (n & 1);
+ n >>= 1;
+ }
+ return n + r;
+ }
+
+  private static int[] createArray(List<Long> runs, int length) {
+ int[] a = new int[length];
+ Arrays.fill(a, 0);
+ int endRun = -1;
+ for (long len : runs) {
+ a[endRun += len] = 1;
+ }
+ a[length - 1] = 0;
+ return a;
+ }
+
+ /**
+   * Fills <code>runs</code> with a sequence of run lengths of the form
+ * Y_n x_{n,1} x_{n,2} ... x_{n,l_n}
+ * Y_{n-1} x_{n-1,1} x_{n-1,2} ... x_{n-1,l_{n-1}}
+ * ...
+ * Y_1 x_{1,1} x_{1,2} ... x_{1,l_1}
+ * The Y_i's are chosen to satisfy the invariant throughout execution,
+   * but the x_{i,j}'s are merged (by <code>TimSort.mergeCollapse</code>)
+ * into an X_i that violates the invariant.
+ *
+   * @param length The sum of all run lengths that will be added to <code>runs</code>.
+ */
+  private static List<Long> runsJDKWorstCase(int minRun, int length) {
+    List<Long> runs = new ArrayList<Long>();
+
+ long runningTotal = 0, Y = minRun + 4, X = minRun;
+
+ while (runningTotal + Y + X <= length) {
+ runningTotal += X + Y;
+ generateJDKWrongElem(runs, minRun, X);
+ runs.add(0, Y);
+ // X_{i+1} = Y_i + x_{i,1} + 1, since runs.get(1) = x_{i,1}
+ X = Y + runs.get(1) + 1;
+ // Y_{i+1} = X_{i+1} + Y_i + 1
+ Y += X + 1;
+ }
+
+ if (runningTotal + X <= length) {
+ runningTotal += X;
+ generateJDKWrongElem(runs, minRun, X);
+ }
+
+ runs.add(length - runningTotal);
+ return runs;
+ }
+
+ /**
+   * Adds a sequence x_1, ..., x_n of run lengths to <code>runs</code> such that:
+ * 1. X = x_1 + ... + x_n
+ * 2. x_j >= minRun for all j
+ * 3. x_1 + ... + x_{j-2} < x_j < x_1 + ... + x_{j-1} for all j
+ * These conditions guarantee that TimSort merges all x_j's one by one
+ * (resulting in X) using only merges on the second-to-last element.
+ *
+ * @param X The sum of the sequence that should be added to runs.
+ */
+  private static void generateJDKWrongElem(List<Long> runs, int minRun, long X) {
+ for (long newTotal; X >= 2 * minRun + 1; X = newTotal) {
+ //Default strategy
+ newTotal = X / 2 + 1;
+ //Specialized strategies
+ if (3 * minRun + 3 <= X && X <= 4 * minRun + 1) {
+ // add x_1=MIN+1, x_2=MIN, x_3=X-newTotal to runs
+ newTotal = 2 * minRun + 1;
+ } else if (5 * minRun + 5 <= X && X <= 6 * minRun + 5) {
+ // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=X-newTotal to runs
+ newTotal = 3 * minRun + 3;
+ } else if (8 * minRun + 9 <= X && X <= 10 * minRun + 9) {
+ // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=X-newTotal to runs
+ newTotal = 5 * minRun + 5;
+ } else if (13 * minRun + 15 <= X && X <= 16 * minRun + 17) {
+ // add x_1=MIN+1, x_2=MIN, x_3=MIN+2, x_4=2MIN+2, x_5=3MIN+4, x_6=X-newTotal to runs
+ newTotal = 8 * minRun + 9;
+ }
+ runs.add(0, X - newTotal);
+ }
+ runs.add(0, X);
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
index ad62b35f624f6..8bcca926097a1 100644
--- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala
@@ -117,8 +117,20 @@ class SparkSubmitUtilsSuite extends FunSuite with BeforeAndAfterAll {
}
test("neglects Spark and Spark's dependencies") {
- val path = SparkSubmitUtils.resolveMavenCoordinates(
- "org.apache.spark:spark-core_2.10:1.2.0", None, None, true)
+ val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_",
+ "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_")
+
+ val coordinates =
+ components.map(comp => s"org.apache.spark:spark-${comp}2.10:1.2.0").mkString(",") +
+ ",org.apache.spark:spark-core_fake:1.2.0"
+
+ val path = SparkSubmitUtils.resolveMavenCoordinates(coordinates, None, None, true)
assert(path === "", "should return empty path")
+    // The following dependency should not be excluded. Resolution will throw an error because the
+    // artifact doesn't exist, but the fact that Ivy tries to resolve it shows it wasn't excluded.
+ intercept[RuntimeException] {
+ SparkSubmitUtils.resolveMavenCoordinates(coordinates +
+ ",org.apache.spark:spark-streaming-kafka-assembly_2.10:1.2.0", None, None, true)
+ }
}
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
new file mode 100644
index 0000000000000..3a9963a5ce7b7
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.history
+
+import javax.servlet.http.HttpServletRequest
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.mockito.Mockito.{when}
+import org.scalatest.FunSuite
+import org.scalatest.Matchers
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.ui.SparkUI
+
+class HistoryServerSuite extends FunSuite with Matchers with MockitoSugar {
+
+ test("generate history page with relative links") {
+ val historyServer = mock[HistoryServer]
+ val request = mock[HttpServletRequest]
+ val ui = mock[SparkUI]
+ val link = "/history/app1"
+ val info = new ApplicationHistoryInfo("app1", "app1", 0, 2, 1, "xxx", true)
+ when(historyServer.getApplicationList()).thenReturn(Seq(info))
+ when(ui.basePath).thenReturn(link)
+ when(historyServer.getProviderConfig()).thenReturn(Map[String, String]())
+ val page = new HistoryPage(historyServer)
+
+ //when
+ val response = page.render(request)
+
+ //then
+ val links = response \\ "a"
+ val justHrefs = for {
+ l <- links
+ attrs <- l.attribute("href")
+ } yield (attrs.toString)
+ justHrefs should contain(link)
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 4bf7f9e647d55..30119ce5d4eec 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -96,6 +96,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
}
override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
override def defaultParallelism() = 2
+ override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
}
/** Length of time to wait while draining listener events. */
@@ -386,6 +387,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
override def defaultParallelism() = 2
override def executorHeartbeatReceived(execId: String, taskMetrics: Array[(Long, TaskMetrics)],
blockManagerId: BlockManagerId): Boolean = true
+ override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
}
val noKillScheduler = new DAGScheduler(
sc,
diff --git a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
index 0cb1ed7397655..e0d6cc16bde05 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/SorterSuite.scala
@@ -65,6 +65,13 @@ class SorterSuite extends FunSuite {
}
}
+ // http://www.envisage-project.eu/timsort-specification-and-verification/
+ test("SPARK-5984 TimSort bug") {
+ val data = TestTimSort.getTimSortBugTestSet(67108864)
+ new Sorter(new IntArraySortDataFormat).sort(data, 0, data.length, Ordering.Int)
+ (0 to data.length - 2).foreach(i => assert(data(i) <= data(i + 1)))
+ }
+
/** Runs an experiment several times. */
def runExperiment(name: String, skip: Boolean = false)(f: => Unit, prepare: () => Unit): Unit = {
if (skip) {
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index 935cd8dad3b25..76140282a2dd0 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -97,8 +97,9 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) =>
}.mean()
println("Mean Squared Error = " + MSE)
-model.save("myModelPath")
-val sameModel = MatrixFactorizationModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
{% endhighlight %}
If the rating matrix is derived from another source of information (e.g., it is inferred from
@@ -186,8 +187,9 @@ public class CollaborativeFiltering {
).rdd()).mean();
System.out.println("Mean Squared Error = " + MSE);
- model.save("myModelPath");
- MatrixFactorizationModel sameModel = MatrixFactorizationModel.load("myModelPath");
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -198,10 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
recommendation by measuring the Mean Squared Error of rating prediction.
-Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.recommendation import ALS, Rating
+from pyspark.mllib.recommendation import ALS, MatrixFactorizationModel, Rating
# Load and parse the data
data = sc.textFile("data/mllib/als/test.data")
@@ -218,6 +218,10 @@ predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2]))
ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions)
MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count()
print("Mean Squared Error = " + str(MSE))
+
+# Save and load model
+model.save(sc, "myModelPath")
+sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
{% endhighlight %}
If the rating matrix is derived from other source of information (i.e., it is inferred from other
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 4695d1cde4901..8e478ab035582 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -223,8 +223,9 @@ val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = DecisionTreeModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -284,8 +285,9 @@ Double testErr =
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());
-model.save("myModelPath");
-DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -362,8 +364,9 @@ val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = DecisionTreeModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -429,8 +432,9 @@ Double testMSE =
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());
-model.save("myModelPath");
-DecisionTreeModel sameModel = DecisionTreeModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index ddae84165f8a9..ec1ef38b453d3 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -129,8 +129,9 @@ val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.
println("Test Error = " + testErr)
println("Learned classification forest model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = RandomForestModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -193,8 +194,9 @@ Double testErr =
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification forest model:\n" + model.toDebugString());
-model.save("myModelPath");
-RandomForestModel sameModel = RandomForestModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -276,8 +278,9 @@ val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean
println("Test Mean Squared Error = " + testMSE)
println("Learned regression forest model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = RandomForestModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -343,8 +346,9 @@ Double testMSE =
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression forest model:\n" + model.toDebugString());
-model.save("myModelPath");
-RandomForestModel sameModel = RandomForestModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -504,8 +508,9 @@ val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.
println("Test Error = " + testErr)
println("Learned classification GBT model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = GradientBoostedTreesModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -568,8 +573,9 @@ Double testErr =
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification GBT model:\n" + model.toDebugString());
-model.save("myModelPath");
-GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -647,8 +653,9 @@ val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean
println("Test Mean Squared Error = " + testMSE)
println("Learned regression GBT model:\n" + model.toDebugString)
-model.save("myModelPath")
-val sameModel = GradientBoostedTreesModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -717,8 +724,9 @@ Double testMSE =
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression GBT model:\n" + model.toDebugString());
-model.save("myModelPath");
-GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 5b97d8c177975..15eef91e77d0b 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -189,8 +189,9 @@ val auROC = metrics.areaUnderROC()
println("Area under ROC = " + auROC)
-model.save("myModelPath")
-val sameModel = SVMModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = SVMModel.load(sc, "myModelPath")
{% endhighlight %}
The `SVMWithSGD.train()` method by default performs L2 regularization with the
@@ -274,8 +275,9 @@ public class SVMClassifier {
System.out.println("Area under ROC = " + auROC);
- model.save("myModelPath");
- SVMModel sameModel = SVMModel.load("myModelPath");
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -442,8 +444,9 @@ val valuesAndPreds = parsedData.map { point =>
val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean()
println("training Mean Squared Error = " + MSE)
-model.save("myModelPath")
-val sameModel = LinearRegressionModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = LinearRegressionModel.load(sc, "myModelPath")
{% endhighlight %}
[`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD)
@@ -515,8 +518,9 @@ public class LinearRegression {
).rdd()).mean();
System.out.println("training Mean Squared Error = " + MSE);
- model.save("myModelPath");
- LinearRegressionModel sameModel = LinearRegressionModel.load("myModelPath");
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 81173255b590d..55b8f2ce6c364 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -56,8 +56,9 @@ val model = NaiveBayes.train(training, lambda = 1.0)
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
-model.save("myModelPath")
-val sameModel = NaiveBayesModel.load("myModelPath")
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = NaiveBayesModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -97,8 +98,9 @@ double accuracy = predictionAndLabel.filter(new Function,
}
}).count() / (double) test.count();
-model.save("myModelPath");
-NaiveBayesModel sameModel = NaiveBayesModel.load("myModelPath");
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -113,22 +115,28 @@ used for evaluation and prediction.
Note that the Python API does not yet support model save/load but will in the future.
-
{% highlight python %}
-from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import NaiveBayes
+from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.regression import LabeledPoint
+
+def parseLine(line):
+ parts = line.split(',')
+ label = float(parts[0])
+ features = Vectors.dense([float(x) for x in parts[1].split(' ')])
+ return LabeledPoint(label, features)
+
+data = sc.textFile('data/mllib/sample_naive_bayes_data.txt').map(parseLine)
-# an RDD of LabeledPoint
-data = sc.parallelize([
- LabeledPoint(0.0, [0.0, 0.0])
- ... # more labeled points
-])
+# Split data approximately into training (60%) and test (40%)
+training, test = data.randomSplit([0.6, 0.4], seed = 0)
# Train a naive Bayes model.
-model = NaiveBayes.train(data, 1.0)
+model = NaiveBayes.train(training, 1.0)
-# Make prediction.
-prediction = model.predict([0.0, 0.0])
+# Make prediction and test accuracy.
+predictionAndLabel = test.map(lambda p : (model.predict(p.features), p.label))
+accuracy = 1.0 * predictionAndLabel.filter(lambda (x, v): x == v).count() / test.count()
{% endhighlight %}
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 5c6084fb46255..74d8653a8b845 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -222,8 +222,7 @@ SPARK_WORKER_OPTS supports the following system properties:
false |
Enable periodic cleanup of worker / application directories. Note that this only affects standalone
- mode, as YARN works differently. Applications directories are cleaned up regardless of whether
- the application is still running.
+ mode, as YARN works differently. Only the directories of stopped applications are cleaned up.
|
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index 62a659518943d..5a9bd4214cf51 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -512,7 +512,7 @@ object KafkaUtils {
* @param topics Names of the topics to consume
*/
@Experimental
- def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]](
jssc: JavaStreamingContext,
keyClass: Class[K],
valueClass: Class[V],
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b0e991d2f2344..0a3f21ecee0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -130,7 +130,7 @@ abstract class LDAModel private[clustering] {
/* TODO
* Compute the estimated topic distribution for each document.
- * This is often called “theta” in the literature.
+ * This is often called 'theta' in the literature.
*
* @param documents RDD of documents, which are term (word) count vectors paired with IDs.
* The term count vectors are "bags of words" with a fixed-size vocabulary
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 4458340497f0b..526d055c87387 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -48,7 +48,7 @@ trait Saveable {
*
* @param sc Spark context used to save model data.
* @param path Path specifying the directory in which to save this model.
- * This directory and any intermediate directory will be created if needed.
+ * If the directory already exists, this method throws an exception.
*/
def save(sc: SparkContext, path: String): Unit
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index acec8f18f2b5c..39b99f54f6dbc 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -33,6 +33,8 @@
http://spark.apache.org/
network-yarn
+
+ provided
@@ -47,7 +49,6 @@
org.apache.hadoop
hadoop-client
- provided
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0d99e6dedfad9..03d7d011474cb 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -19,7 +19,8 @@
from pyspark import SparkContext
from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.util import Saveable, JavaLoader
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
@@ -39,7 +40,8 @@ def __reduce__(self):
return Rating, (int(self.user), int(self.product), float(self.rating))
-class MatrixFactorizationModel(JavaModelWrapper):
+@inherit_doc
+class MatrixFactorizationModel(JavaModelWrapper, Saveable, JavaLoader):
"""A matrix factorisation model trained by regularized alternating
least-squares.
@@ -81,6 +83,17 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
>>> model.predict(2,2)
0.43...
+
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = MatrixFactorizationModel.load(sc, path)
+ >>> sameModel.predict(2,2)
+ 0.43...
+ >>> try:
+ ... os.removedirs(path)
+ ... except:
+ ... pass
"""
def predict(self, user, product):
return self._java_model.predict(int(user), int(product))
@@ -98,6 +111,9 @@ def userFeatures(self):
def productFeatures(self):
return self.call("getProductFeatures")
+ def save(self, sc, path):
+ self.call("save", sc._jsc.sc(), path)
+
class ALS(object):
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 4ed978b45409c..17d43eadba12b 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -168,6 +168,64 @@ def loadLabeledPoints(sc, path, minPartitions=None):
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
+class Saveable(object):
+ """
+ Mixin for models and transformers which may be saved as files.
+ """
+
+ def save(self, sc, path):
+ """
+ Save this model to the given path.
+
+ This saves:
+ * human-readable (JSON) model metadata to path/metadata/
+ * Parquet formatted data to path/data/
+
+ The model may be loaded using :py:meth:`Loader.load`.
+
+ :param sc: Spark context used to save model data.
+ :param path: Path specifying the directory in which to save
+ this model. If the directory already exists,
+ this method throws an exception.
+ """
+ raise NotImplementedError
+
+
+class Loader(object):
+ """
+ Mixin for classes which can load saved models from files.
+ """
+
+ @classmethod
+ def load(cls, sc, path):
+ """
+ Load a model from the given path. The model should have been
+ saved using :py:meth:`Saveable.save`.
+
+ :param sc: Spark context used for loading model files.
+ :param path: Path specifying the directory to which the model
+ was saved.
+ :return: model instance
+ """
+ raise NotImplementedError
+
+
+class JavaLoader(Loader):
+ """
+ Mixin for classes which can load saved models using their Scala
+ implementation.
+ """
+
+ @classmethod
+ def load(cls, sc, path):
+ java_package = cls.__module__.replace("pyspark", "org.apache.spark")
+ java_class = ".".join([java_package, cls.__name__])
+ java_obj = sc._jvm
+ for name in java_class.split("."):
+ java_obj = getattr(java_obj, name)
+ return cls(java_obj.load(sc._jsc.sc(), path))
+
+
def _test():
import doctest
from pyspark.context import SparkContext
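A note on the `JavaLoader.load` hunk above: it derives the Scala class name purely from the Python class's module path and then resolves that dotted name attribute-by-attribute against `sc._jvm`. The following is a minimal sketch of just the name-mapping step, with no SparkContext or Py4J gateway involved; `FakeModel` is a stand-in, not a real pyspark class.

```python
# Illustrative only: mirrors the name mapping in JavaLoader.load.
class FakeModel(object):
    pass

# Pretend the class lives where MatrixFactorizationModel does.
FakeModel.__module__ = "pyspark.mllib.recommendation"

java_package = FakeModel.__module__.replace("pyspark", "org.apache.spark")
java_class = ".".join([java_package, "MatrixFactorizationModel"])
print(java_class)
# org.apache.spark.mllib.recommendation.MatrixFactorizationModel
#
# The real implementation then walks this dotted name via getattr() on sc._jvm
# and calls .load(sc._jsc.sc(), path) on the resulting Scala companion object.
```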
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 5d7aeb664cadf..795ef0dbc4c47 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,7 +17,6 @@
import warnings
import json
-from array import array
from itertools import imap
from py4j.protocol import Py4JError
@@ -25,7 +24,7 @@
from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
@@ -620,93 +619,6 @@ def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())
-def _create_row(fields, values):
- row = Row(*values)
- row.__FIELDS__ = fields
- return row
-
-
-class Row(tuple):
-
- """
- A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
- Row can be used to create a row object by using named arguments,
- the fields will be sorted by names.
-
- >>> row = Row(name="Alice", age=11)
- >>> row
- Row(age=11, name='Alice')
- >>> row.name, row.age
- ('Alice', 11)
-
- Row also can be used to create another Row like class, then it
- could be used to create Row objects, such as
-
- >>> Person = Row("name", "age")
- >>> Person
- <Row(name, age)>
- >>> Person("Alice", 11)
- Row(name='Alice', age=11)
- """
-
- def __new__(self, *args, **kwargs):
- if args and kwargs:
- raise ValueError("Can not use both args "
- "and kwargs to create Row")
- if args:
- # create row class or objects
- return tuple.__new__(self, args)
-
- elif kwargs:
- # create row objects
- names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
- row.__FIELDS__ = names
- return row
-
- else:
- raise ValueError("No args or kwargs")
-
- def asDict(self):
- """
- Return as an dict
- """
- if not hasattr(self, "__FIELDS__"):
- raise TypeError("Cannot convert a Row class into dict")
- return dict(zip(self.__FIELDS__, self))
-
- # let obect acs like class
- def __call__(self, *args):
- """create new Row object"""
- return _create_row(self, args)
-
- def __getattr__(self, item):
- if item.startswith("__"):
- raise AttributeError(item)
- try:
- # it will be slow when it has many fields,
- # but this will not be used in normal cases
- idx = self.__FIELDS__.index(item)
- return self[idx]
- except IndexError:
- raise AttributeError(item)
-
- def __reduce__(self):
- if hasattr(self, "__FIELDS__"):
- return (_create_row, (self.__FIELDS__, tuple(self)))
- else:
- return tuple.__reduce__(self)
-
- def __repr__(self):
- if hasattr(self, "__FIELDS__"):
- return "Row(%s)" % ", ".join("%s=%r" % (k, v)
- for k, v in zip(self.__FIELDS__, self))
- else:
- return "" % ", ".join(self)
-
-
def _test():
import doctest
from pyspark.context import SparkContext
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index aec99017fbdc1..5c3b7377c33b5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -1025,10 +1025,12 @@ def cast(self, dataType):
ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
+ else:
+ raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)
def __repr__(self):
- return 'Column<%s>' % self._jdf.toString().encode('utf8')
+ return 'Column<%s>' % self._jc.toString().encode('utf8')
def _test():
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 83899ad4b1b12..2720439416682 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -24,6 +24,7 @@
import pydoc
import shutil
import tempfile
+import pickle
import py4j
@@ -88,6 +89,14 @@ def __eq__(self, other):
other.x == self.x and other.y == self.y
+class DataTypeTests(unittest.TestCase):
+ # regression test for SPARK-6055
+ def test_data_type_eq(self):
+ lt = LongType()
+ lt2 = pickle.loads(pickle.dumps(LongType()))
+ self.assertEquals(lt, lt2)
+
+
class SQLTests(ReusedPySparkTestCase):
@classmethod
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0f5dc2be6dab8..31a861e1feb46 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -21,6 +21,7 @@
import warnings
import json
import re
+import weakref
from array import array
from operator import itemgetter
@@ -42,8 +43,7 @@ def __hash__(self):
return hash(str(self))
def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@@ -64,6 +64,8 @@ def json(self):
sort_keys=True)
+# This singleton pattern does not work with pickle: after pickling and
+# unpickling you get a different object, so equality cannot rely on identity.
class PrimitiveTypeSingleton(type):
"""Metaclass for PrimitiveType"""
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
__metaclass__ = PrimitiveTypeSingleton
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
class NullType(PrimitiveType):
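For context on why the identity-based `__eq__` above is removed: `PrimitiveTypeSingleton` caches one instance per class, but pickle reconstructs instances without going through the metaclass, so the unpickled object is a different instance and identity comparison fails (this is what the SPARK-6055 regression test checks). Below is a self-contained sketch of the failure mode and of the `__dict__`-based fix, using stand-in classes rather than the real pyspark types.

```python
import pickle

class Singleton(type):
    """Metaclass caching one instance per class, like PrimitiveTypeSingleton."""
    _instances = {}

    def __call__(cls):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__()
        return cls._instances[cls]

class MyLongType(object):
    __metaclass__ = Singleton  # Python 2 metaclass syntax, matching the code above

    def __eq__(self, other):
        # dict-based equality survives pickling; `self is other` would not
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

lt = MyLongType()
lt2 = pickle.loads(pickle.dumps(lt))
assert lt is not lt2   # pickle bypasses the metaclass and builds a new object...
assert lt == lt2       # ...but value-based equality still holds
```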
@@ -242,11 +240,12 @@ def __init__(self, elementType, containsNull=True):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
- >>> ArrayType(StringType) == ArrayType(StringType, True)
+ >>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
+ >>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
+ assert isinstance(elementType, DataType), "elementType should be DataType"
self.elementType = elementType
self.containsNull = containsNull
@@ -292,13 +291,15 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
:param valueContainsNull: indicates whether values contains
null values.
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
+ >>> (MapType(StringType(), IntegerType())
+ ... == MapType(StringType(), IntegerType(), True))
True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
+ >>> (MapType(StringType(), IntegerType(), False)
+ ... == MapType(StringType(), FloatType()))
False
"""
+ assert isinstance(keyType, DataType), "keyType should be DataType"
+ assert isinstance(valueType, DataType), "valueType should be DataType"
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
@@ -348,13 +349,14 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
to simple type that can be serialized to JSON
automatically
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f1", StringType(), True))
True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f2", StringType(), True))
False
"""
+ assert isinstance(dataType, DataType), "dataType should be DataType"
self.name = name
self.dataType = dataType
self.nullable = nullable
@@ -393,16 +395,17 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True),
+ ... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
+ assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
self.fields = fields
def simpleString(self):
@@ -505,20 +508,24 @@ def __eq__(self, other):
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
+ >>> import pickle
>>> def check_datatype(datatype):
+ ... pickled = pickle.loads(pickle.dumps(datatype))
+ ... assert datatype == pickled
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
- ... return datatype == python_datatype
- >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
- True
+ ... assert datatype == python_datatype
+ >>> for cls in _all_primitive_types.values():
+ ... check_datatype(cls())
+
>>> # Simple ArrayType.
>>> simple_arraytype = ArrayType(StringType(), True)
>>> check_datatype(simple_arraytype)
- True
+
>>> # Simple MapType.
>>> simple_maptype = MapType(StringType(), LongType())
>>> check_datatype(simple_maptype)
- True
+
>>> # Simple StructType.
>>> simple_structtype = StructType([
... StructField("a", DecimalType(), False),
@@ -526,7 +533,7 @@ def _parse_datatype_json_string(json_string):
... StructField("c", LongType(), True),
... StructField("d", BinaryType(), False)])
>>> check_datatype(simple_structtype)
- True
+
>>> # Complex StructType.
>>> complex_structtype = StructType([
... StructField("simpleArray", simple_arraytype, True),
@@ -535,22 +542,20 @@ def _parse_datatype_json_string(json_string):
... StructField("boolean", BooleanType(), False),
... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
- True
+
>>> # Complex ArrayType.
>>> complex_arraytype = ArrayType(complex_structtype, True)
>>> check_datatype(complex_arraytype)
- True
+
>>> # Complex MapType.
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
- True
+
>>> check_datatype(ExamplePointUDT())
- True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
- True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -786,8 +791,24 @@ def _merge_type(a, b):
return a
+def _need_converter(dataType):
+ if isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+ elif isinstance(dataType, NullType):
+ return True
+ else:
+ return False
+
+
def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
+ if not _need_converter(dataType):
+ return lambda x: x
+
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
@@ -806,13 +827,17 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
+ convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
def convert_struct(obj):
if obj is None:
return
if isinstance(obj, (tuple, list)):
- return tuple(conv(v) for v, conv in zip(obj, converters))
+ if convert_fields:
+ return tuple(conv(v) for v, conv in zip(obj, converters))
+ else:
+ return tuple(obj)
if isinstance(obj, dict):
d = obj
@@ -821,7 +846,10 @@ def convert_struct(obj):
else:
raise ValueError("Unexpected obj: %s" % obj)
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ else:
+ return tuple([d.get(name) for name in names])
return convert_struct
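A quick illustration of what the `_need_converter` short-circuit buys: for a flat schema the converter no longer rebuilds every field, while nested structs still get flattened to tuples. These are private helpers of `pyspark.sql.types`, so the snippet below is a sketch against this particular version of the module rather than a stable API.

```python
# Illustrative only: _create_converter is an internal pyspark.sql.types helper.
from pyspark.sql.types import (StructType, StructField, StringType, LongType,
                               _create_converter)

flat = StructType([StructField("a", LongType(), True),
                   StructField("b", StringType(), True)])
nested = StructType([StructField("a", LongType(), True),
                     StructField("s", StructType([StructField("x", LongType(), True)]), True)])

conv_flat = _create_converter(flat)      # no per-field conversion needed for flat schemas
conv_nested = _create_converter(nested)  # nested struct values still get converted

print(conv_flat((1, "one")))                   # (1, 'one')
print(conv_nested({"a": 1, "s": {"x": 2}}))    # (1, (2,))
```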
@@ -871,20 +899,20 @@ def _parse_field_abstract(s):
Parse a field in schema abstract
>>> _parse_field_abstract("a")
- StructField(a,None,true)
+ StructField(a,NullType,true)
>>> _parse_field_abstract("b(c d)")
- StructField(b,StructType(...c,None,true),StructField(d...
+ StructField(b,StructType(...c,NullType,true),StructField(d...
>>> _parse_field_abstract("a[]")
- StructField(a,ArrayType(None,true),true)
+ StructField(a,ArrayType(NullType,true),true)
>>> _parse_field_abstract("a{[]}")
- StructField(a,MapType(None,ArrayType(None,true),true),true)
+ StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
"""
if set(_BRACKETS.keys()) & set(s):
idx = min((s.index(c) for c in _BRACKETS if c in s))
name = s[:idx]
return StructField(name, _parse_schema_abstract(s[idx:]), True)
else:
- return StructField(s, None, True)
+ return StructField(s, NullType(), True)
def _parse_schema_abstract(s):
@@ -898,11 +926,11 @@ def _parse_schema_abstract(s):
>>> _parse_schema_abstract("c{} d{a b}")
StructType...c,MapType...d,MapType...a...b...
>>> _parse_schema_abstract("a b(t)").fields[1]
- StructField(b,StructType(List(StructField(t,None,true))),true)
+ StructField(b,StructType(List(StructField(t,NullType,true))),true)
"""
s = s.strip()
if not s:
- return
+ return NullType()
elif s.startswith('('):
return _parse_schema_abstract(s[1:-1])
@@ -911,7 +939,7 @@ def _parse_schema_abstract(s):
return ArrayType(_parse_schema_abstract(s[1:-1]), True)
elif s.startswith('{'):
- return MapType(None, _parse_schema_abstract(s[1:-1]))
+ return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
parts = _split_schema_abstract(s)
fields = [_parse_field_abstract(p) for p in parts]
@@ -931,7 +959,7 @@ def _infer_schema_type(obj, dataType):
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
- if dataType is None:
+ if dataType is NullType():
return _infer_type(obj)
if not obj:
@@ -1037,8 +1065,7 @@ def _verify_type(obj, dataType):
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
-
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
def _restore_object(dataType, obj):
@@ -1233,8 +1260,7 @@ def __new__(self, *args, **kwargs):
elif kwargs:
# create row objects
names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
+ row = tuple.__new__(self, [kwargs[n] for n in names])
row.__FIELDS__ = names
return row
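On the switch from a plain dict to `weakref.WeakValueDictionary` for `_cached_cls`: the cache of dynamically generated Row classes now drops an entry as soon as nothing else references that class, instead of holding it forever. A small standalone sketch of that behavior follows; the `make_row_class` helper is hypothetical and not part of pyspark.

```python
import gc
import weakref

cache = weakref.WeakValueDictionary()

def make_row_class(name):
    # Dynamically created classes (like the generated Row subclasses in
    # _restore_object) support weak references, so they work as values here.
    cls = type(name, (tuple,), {})
    cache[name] = cls
    return cls

cls = make_row_class("RowABC")
print("RowABC" in cache)   # True while a strong reference exists

del cls
gc.collect()               # classes sit in reference cycles, so force a collection
print("RowABC" in cache)   # False: the entry vanished with the class
```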
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 03a5c9e7c24a0..e28baa512b95c 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -109,5 +109,13 @@
target/scala-${scala.binary.version}/classes
target/scala-${scala.binary.version}/test-classes
+
+
+ ../../python
+
+ pyspark/sql/*.py
+
+
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index a08c0f5ce3ff4..4815620c6fe57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -51,6 +51,11 @@ private[spark] object SQLConf {
// This is used to set the default data source
val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
+ // This is used to control when we split a schema's JSON string into multiple pieces
+ // so that it fits in the metastore's table properties (by default, each value has
+ // a length restriction of 4000 characters). We split the JSON string of a schema
+ // when its length exceeds the threshold.
+ val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold"
// Whether to perform eager analysis when constructing a dataframe.
// Set to false when debugging requires the ability to look at invalid query plans.
@@ -177,6 +182,11 @@ private[sql] class SQLConf extends Serializable {
private[spark] def defaultDataSourceName: String =
getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet")
+ // Do not use a value larger than 4000 as the default value of this property.
+ // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
+ private[spark] def schemaStringLengthThreshold: Int =
+ getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt
+
private[spark] def dataFrameEagerAnalysis: Boolean =
getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 9061d3f5fee4d..225ec6db7d553 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -126,6 +126,13 @@ private[sql] case class ParquetTableScan(
conf)
if (requestedPartitionOrdinals.nonEmpty) {
+ // This check is based on CatalystConverter.createRootConverter.
+ val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))
+
+ // Uses a temporary variable to avoid the whole `ParquetTableScan` object being captured into
+ // the `mapPartitionsWithInputSplit` closure below.
+ val outputSize = output.size
+
baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
@@ -143,19 +150,47 @@ private[sql] case class ParquetTableScan(
relation.partitioningAttributes
.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
- new Iterator[Row] {
- def hasNext = iter.hasNext
- def next() = {
- val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
-
- // Parquet will leave partitioning columns empty, so we fill them in here.
- var i = 0
- while (i < requestedPartitionOrdinals.size) {
- row(requestedPartitionOrdinals(i)._2) =
- partitionRowValues(requestedPartitionOrdinals(i)._1)
- i += 1
+ if (primitiveRow) {
+ new Iterator[Row] {
+ def hasNext = iter.hasNext
+ def next() = {
+ // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
+ val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ var i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ row(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ row
+ }
+ }
+ } else {
+ // Create a mutable row since we need to fill in values from partition columns.
+ val mutableRow = new GenericMutableRow(outputSize)
+ new Iterator[Row] {
+ def hasNext = iter.hasNext
+ def next() = {
+ // We are using CatalystGroupConverter and it returns a GenericRow.
+ // Since GenericRow is not mutable, we just cast it to a Row.
+ val row = iter.next()._2.asInstanceOf[Row]
+
+ var i = 0
+ while (i < row.size) {
+ mutableRow(i) = row(i)
+ i += 1
+ }
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ mutableRow(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ mutableRow
}
- row
}
}
}
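The two iterator branches above share one idea: the Parquet converters leave partition columns unset, so each row is patched with values recovered from the partition path, either in place (mutable `SpecificMutableRow`) or via a copy into a reusable `GenericMutableRow` when the converter yields an immutable `GenericRow`. Below is a hypothetical Python sketch of that backfill pattern, not the actual Scala code.

```python
# Hypothetical sketch: rows arrive with their partition columns unset,
# and we fill them in while iterating.
def backfill_partitions(rows, partition_ordinals, partition_values, mutable):
    """partition_ordinals is a list of (partition value index, row column index) pairs."""
    for row in rows:
        out = row if mutable else list(row)   # copy first when the row is immutable
        for src, dst in partition_ordinals:
            out[dst] = partition_values[src]
        yield out

# Column 2 is the partition column; its value (7) came from a path segment like p=7
rows = iter([[1, "a", None], [2, "b", None]])
for r in backfill_partitions(rows, [(0, 2)], {0: 7}, mutable=False):
    print(r)   # [1, 'a', 7] then [2, 'b', 7]
```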
@@ -434,22 +469,13 @@ private[parquet] class FilteringParquetRowInputFormat
return splits
}
- Option(globalMetaData.getKeyValueMetaData.get(RowReadSupport.SPARK_METADATA_KEY)).foreach {
- schemas =>
- val mergedSchema = schemas
- .map(DataType.fromJson(_).asInstanceOf[StructType])
- .reduce(_ merge _)
- .json
-
- val mergedMetadata = globalMetaData
- .getKeyValueMetaData
- .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(mergedSchema)))
-
- globalMetaData = new GlobalMetaData(
- globalMetaData.getSchema,
- mergedMetadata,
- globalMetaData.getCreatedBy)
- }
+ val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA)
+ val mergedMetadata = globalMetaData
+ .getKeyValueMetaData
+ .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata)))
+
+ globalMetaData = new GlobalMetaData(globalMetaData.getSchema,
+ mergedMetadata, globalMetaData.getCreatedBy)
val readContext = getReadSupport(configuration).init(
new InitContext(configuration,
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index e648618468d5d..6d56be3ab8dd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -482,6 +482,10 @@ private[sql] case class ParquetRelation2(
// When the data does not include the key and the key is requested then we must fill it in
// based on information from the input split.
if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) {
+ // This check is based on CatalystConverter.createRootConverter.
+ val primitiveRow =
+ requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))
+
baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
case p if split.getPath.getParent.toString == p.path => p.values
@@ -489,16 +493,42 @@ private[sql] case class ParquetRelation2(
val requiredPartOrdinal = partitionKeyLocations.keys.toSeq
- iterator.map { pair =>
- val row = pair._2.asInstanceOf[SpecificMutableRow]
- var i = 0
- while (i < requiredPartOrdinal.size) {
- // TODO Avoids boxing cost here!
- val partOrdinal = requiredPartOrdinal(i)
- row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
- i += 1
+ if (primitiveRow) {
+ iterator.map { pair =>
+ // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
+ val row = pair._2.asInstanceOf[SpecificMutableRow]
+ var i = 0
+ while (i < requiredPartOrdinal.size) {
+ // TODO Avoids boxing cost here!
+ val partOrdinal = requiredPartOrdinal(i)
+ row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
+ i += 1
+ }
+ row
+ }
+ } else {
+ // Create a mutable row since we need to fill in values from partition columns.
+ val mutableRow = new GenericMutableRow(requestedSchema.size)
+ iterator.map { pair =>
+ // We are using CatalystGroupConverter and it returns a GenericRow.
+ // Since GenericRow is not mutable, we just cast it to a Row.
+ val row = pair._2.asInstanceOf[Row]
+ var i = 0
+ while (i < row.size) {
+ // TODO Avoids boxing cost here!
+ mutableRow(i) = row(i)
+ i += 1
+ }
+
+ i = 0
+ while (i < requiredPartOrdinal.size) {
+ // TODO Avoids boxing cost here!
+ val partOrdinal = requiredPartOrdinal(i)
+ mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
+ i += 1
+ }
+ mutableRow
}
- row
}
}
} else {
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
index 77ef37253e38f..d783d487b5c60 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -39,6 +39,7 @@ import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.util.Utils
object TestData {
def getTestDataFilePath(name: String) = {
@@ -273,6 +274,7 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit
private var metastorePath: File = _
private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+ private val pidDir: File = Utils.createTempDir("thriftserver-pid")
private var logPath: File = _
private var logTailingProcess: Process = _
private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String]
@@ -315,7 +317,14 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit
logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt")
- logPath = Process(command, None, "SPARK_TESTING" -> "0").lines.collectFirst {
+ val env = Seq(
+ // Disables SPARK_TESTING to exclude log4j.properties in test directories.
+ "SPARK_TESTING" -> "0",
+ // Points SPARK_PID_DIR to a temporary directory, otherwise only one Thrift server instance
+ // can be started at a time, which is not Jenkins friendly.
+ "SPARK_PID_DIR" -> pidDir.getCanonicalPath)
+
+ logPath = Process(command, None, env: _*).lines.collectFirst {
case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
}.getOrElse {
throw new RuntimeException("Failed to find HiveThriftServer2 log file.")
@@ -346,7 +355,7 @@ abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll wit
private def stopThriftServer(): Unit = {
// The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
- Process(stopScript, None).run().exitValue()
+ Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue()
Thread.sleep(3.seconds.toMillis)
warehousePath.delete()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 8af5a4848fd44..d3ad364328265 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -69,13 +69,23 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val table = synchronized {
client.getTable(in.database, in.name)
}
- val schemaString = table.getProperty("spark.sql.sources.schema")
val userSpecifiedSchema =
- if (schemaString == null) {
- None
- } else {
- Some(DataType.fromJson(schemaString).asInstanceOf[StructType])
+ Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts =>
+ val parts = (0 until numParts.toInt).map { index =>
+ val part = table.getProperty(s"spark.sql.sources.schema.part.${index}")
+ if (part == null) {
+ throw new AnalysisException(
+ s"Could not read schema from the metastore because it is corrupted " +
+ s"(missing part ${index} of the schema).")
+ }
+
+ part
+ }
+ // Stitch all parts back together into a single schema string in its JSON representation
+ // and convert it back to a StructType.
+ DataType.fromJson(parts.mkString).asInstanceOf[StructType]
}
+
// It does not appear that the ql client for the metastore has a way to enumerate all the
// SerDe properties directly...
val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap
@@ -119,7 +129,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
tbl.setProperty("spark.sql.sources.provider", provider)
if (userSpecifiedSchema.isDefined) {
- tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json)
+ val threshold = hive.conf.schemaStringLengthThreshold
+ val schemaJsonString = userSpecifiedSchema.get.json
+ // Split the JSON string.
+ val parts = schemaJsonString.grouped(threshold).toSeq
+ tbl.setProperty("spark.sql.sources.schema.numParts", parts.size.toString)
+ parts.zipWithIndex.foreach { case (part, index) =>
+ tbl.setProperty(s"spark.sql.sources.schema.part.${index}", part)
+ }
}
options.foreach { case (key, value) => tbl.setSerdeParam(key, value) }
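The split-on-write / join-on-read pair above exists because each metastore table property value is limited (4000 characters by default, via `schemaStringLengthThreshold`), while a wide schema's JSON can be far longer. The Python sketch below walks the same round-trip, reusing the property names from the Scala above; the helper functions themselves are hypothetical.

```python
import json

def write_schema_props(schema_json, threshold=4000):
    """Split a (possibly very long) schema JSON string into numbered properties."""
    parts = [schema_json[i:i + threshold] for i in range(0, len(schema_json), threshold)]
    props = {"spark.sql.sources.schema.numParts": str(len(parts))}
    for index, part in enumerate(parts):
        props["spark.sql.sources.schema.part.%d" % index] = part
    return props

def read_schema_props(props):
    """Reassemble the schema JSON, failing loudly if any part is missing."""
    num_parts = int(props["spark.sql.sources.schema.numParts"])
    parts = []
    for index in range(num_parts):
        key = "spark.sql.sources.schema.part.%d" % index
        if key not in props:
            raise ValueError("missing part %d of the schema" % index)
        parts.append(props[key])
    return "".join(parts)

# Round-trip a wide fake schema, in the spirit of the SPARK-6024 test below.
schema_json = json.dumps({"fields": [{"name": "c_%d" % i, "type": "string"} for i in range(5000)]})
props = write_schema_props(schema_json, threshold=4000)
assert read_schema_props(props) == schema_json
```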
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 0bd82773f3a55..00306f1cd7f86 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -591,4 +591,25 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource)
}
}
+
+ test("SPARK-6024 wide schema support") {
+ // We will need 80 splits for this schema if the threshold is 4000.
+ val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true)))
+ assert(
+ schema.json.size > conf.schemaStringLengthThreshold,
+ "To correctly test the fix of SPARK-6024, the value of " +
+ s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}")
+ // Manually create a metastore data source table.
+ catalog.createDataSourceTable(
+ tableName = "wide_schema",
+ userSpecifiedSchema = Some(schema),
+ provider = "json",
+ options = Map("path" -> "just a dummy path"),
+ isExternal = false)
+
+ invalidateTable("wide_schema")
+
+ val actualSchema = table("wide_schema").schema
+ assert(schema === actualSchema)
+ }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index 6a9d9daf6750c..c8da8eea4e646 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -36,6 +36,20 @@ case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(p: Int, intField: Int, stringField: String)
+case class StructContainer(intStructField: Int, stringStructField: String)
+
+case class ParquetDataWithComplexTypes(
+ intField: Int,
+ stringField: String,
+ structField: StructContainer,
+ arrayField: Seq[Int])
+
+case class ParquetDataWithKeyAndComplexTypes(
+ p: Int,
+ intField: Int,
+ stringField: String,
+ structField: StructContainer,
+ arrayField: Seq[Int])
/**
* A suite to test the automatic conversion of metastore tables with parquet data to use the
@@ -86,6 +100,38 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
location '${new File(normalTableDir, "normal").getCanonicalPath}'
""")
+ sql(s"""
+ CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes
+ (
+ intField INT,
+ stringField STRING,
+ structField STRUCT<intStructField: INT, stringStructField: STRING>,
+ arrayField ARRAY<int>
+ )
+ PARTITIONED BY (p int)
+ ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ STORED AS
+ INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
+ """)
+
+ sql(s"""
+ CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes
+ (
+ intField INT,
+ stringField STRING,
+ structField STRUCT<intStructField: INT, stringStructField: STRING>,
+ arrayField ARRAY<int>
+ )
+ PARTITIONED BY (p int)
+ ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ STORED AS
+ INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
+ """)
+
(1 to 10).foreach { p =>
sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)")
}
@@ -94,7 +140,15 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)")
}
- val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ (1 to 10).foreach { p =>
+ sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)")
+ }
+
+ (1 to 10).foreach { p =>
+ sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)")
+ }
+
+ val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
jsonRDD(rdd1).registerTempTable("jt")
val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}"""))
jsonRDD(rdd2).registerTempTable("jt_array")
@@ -105,6 +159,8 @@ class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
override def afterAll(): Unit = {
sql("DROP TABLE partitioned_parquet")
sql("DROP TABLE partitioned_parquet_with_key")
+ sql("DROP TABLE partitioned_parquet_with_complextypes")
+ sql("DROP TABLE partitioned_parquet_with_key_and_complextypes")
sql("DROP TABLE normal_parquet")
sql("DROP TABLE IF EXISTS jt")
sql("DROP TABLE IF EXISTS jt_array")
@@ -409,6 +465,22 @@ class ParquetSourceSuiteBase extends ParquetPartitioningTest {
path '${new File(partitionedTableDir, "p=1").getCanonicalPath}'
)
""")
+
+ sql( s"""
+ CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes
+ USING org.apache.spark.sql.parquet
+ OPTIONS (
+ path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
+ )
+ """)
+
+ sql( s"""
+ CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes
+ USING org.apache.spark.sql.parquet
+ OPTIONS (
+ path '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
+ )
+ """)
}
test("SPARK-6016 make sure to use the latest footers") {
@@ -473,7 +545,8 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var partitionedTableDir: File = null
var normalTableDir: File = null
var partitionedTableDirWithKey: File = null
-
+ var partitionedTableDirWithComplexTypes: File = null
+ var partitionedTableDirWithKeyAndComplexTypes: File = null
override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
@@ -509,9 +582,45 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
.toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
+
+ partitionedTableDirWithKeyAndComplexTypes = File.createTempFile("parquettests", "sparksql")
+ partitionedTableDirWithKeyAndComplexTypes.delete()
+ partitionedTableDirWithKeyAndComplexTypes.mkdir()
+
+ (1 to 10).foreach { p =>
+ val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p")
+ sparkContext.makeRDD(1 to 10).map { i =>
+ ParquetDataWithKeyAndComplexTypes(
+ p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
+ }.toDF().saveAsParquetFile(partDir.getCanonicalPath)
+ }
+
+ partitionedTableDirWithComplexTypes = File.createTempFile("parquettests", "sparksql")
+ partitionedTableDirWithComplexTypes.delete()
+ partitionedTableDirWithComplexTypes.mkdir()
+
+ (1 to 10).foreach { p =>
+ val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p")
+ sparkContext.makeRDD(1 to 10).map { i =>
+ ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
+ }.toDF().saveAsParquetFile(partDir.getCanonicalPath)
+ }
+ }
+
+ override protected def afterAll(): Unit = {
+ partitionedTableDir.delete()
+ normalTableDir.delete()
+ partitionedTableDirWithKey.delete()
+ partitionedTableDirWithComplexTypes.delete()
+ partitionedTableDirWithKeyAndComplexTypes.delete()
}
- Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
+ Seq(
+ "partitioned_parquet",
+ "partitioned_parquet_with_key",
+ "partitioned_parquet_with_complextypes",
+ "partitioned_parquet_with_key_and_complextypes").foreach { table =>
+
test(s"ordering of the partitioning columns $table") {
checkAnswer(
sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
@@ -601,6 +710,25 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
}
}
+ Seq(
+ "partitioned_parquet_with_key_and_complextypes",
+ "partitioned_parquet_with_complextypes").foreach { table =>
+
+ test(s"SPARK-5775 read struct from $table") {
+ checkAnswer(
+ sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"),
+ (1 to 10).map(i => Row(1, i, f"${i}_string")))
+ }
+
+ // Re-enable this after SPARK-5508 is fixed
+ ignore(s"SPARK-5775 read array from $table") {
+ checkAnswer(
+ sql(s"SELECT arrayField, p FROM $table WHERE p = 1"),
+ (1 to 10).map(i => Row(1 to i, 1)))
+ }
+ }
+
+
test("non-part select(*)") {
checkAnswer(
sql("SELECT COUNT(*) FROM normal_parquet"),
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 20fc19166ac4e..e966bfba7bb7d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -68,8 +68,8 @@ private[spark] class ApplicationMaster(
@volatile private var finalMsg: String = ""
@volatile private var userClassThread: Thread = _
- private var reporterThread: Thread = _
- private var allocator: YarnAllocator = _
+ @volatile private var reporterThread: Thread = _
+ @volatile private var allocator: YarnAllocator = _
// Fields used in client mode.
private var actorSystem: ActorSystem = null
@@ -486,11 +486,10 @@ private[spark] class ApplicationMaster(
case _: InterruptedException =>
// Reporter thread can interrupt to stop user class
case cause: Throwable =>
+ logError("User class threw exception: " + cause.getMessage, cause)
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
"User class threw exception: " + cause.getMessage)
- // re-throw to get it logged
- throw cause
}
}
}