diff --git a/.rat-excludes b/.rat-excludes
index 769defbac11b7..8c61e67a0c7d1 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -1,4 +1,5 @@
target
+cache
.gitignore
.gitattributes
.project
@@ -18,6 +19,7 @@ fairscheduler.xml.template
spark-defaults.conf.template
log4j.properties
log4j.properties.template
+metrics.properties
metrics.properties.template
slaves
slaves.template
diff --git a/LICENSE b/LICENSE
index 0a42d389e4c3c..9b364a4d00079 100644
--- a/LICENSE
+++ b/LICENSE
@@ -771,6 +771,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For TestTimSort (core/src/test/java/org/apache/spark/util/collection/TestTimSort.java):
+========================================================================
+Copyright (C) 2015 Stijn de Gouw
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
========================================================================
For LimitedInputStream
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 301ff69c2ae3b..3d1ed0dd8a7bd 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -36,10 +36,6 @@
scala-${scala.binary.version}spark-assembly-${project.version}-hadoop${hadoop.version}.jar${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}
- spark
- /usr/share/spark
- root
- 755
@@ -118,6 +114,16 @@
META-INF/*.RSA
+
+
+ org.jblas:jblas
+
+
+ lib/static/Linux/i386/**
+ lib/static/Mac OS X/**
+ lib/static/Windows/**
+
+
@@ -217,113 +223,6 @@
-
- deb
-
-
-
- org.codehaus.mojo
- buildnumber-maven-plugin
- 1.2
-
-
- validate
-
- create
-
-
- 8
-
-
-
-
-
- org.vafer
- jdeb
- 0.11
-
-
- package
-
- jdeb
-
-
- ${project.build.directory}/${deb.pkg.name}_${project.version}-${buildNumber}_all.deb
- false
- gzip
-
-
- ${spark.jar}
- file
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}/jars
-
-
-
- ${basedir}/src/deb/RELEASE
- file
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}
-
-
-
- ${basedir}/../conf
- directory
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}/conf
- ${deb.bin.filemode}
-
-
-
- ${basedir}/../bin
- directory
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}/bin
- ${deb.bin.filemode}
-
-
-
- ${basedir}/../sbin
- directory
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}/sbin
- ${deb.bin.filemode}
-
-
-
- ${basedir}/../python
- directory
-
- perm
- ${deb.user}
- ${deb.user}
- ${deb.install.path}/python
- ${deb.bin.filemode}
-
-
-
-
-
-
-
-
-
- kinesis-asl
diff --git a/assembly/src/deb/RELEASE b/assembly/src/deb/RELEASE
deleted file mode 100644
index aad50ee73aa45..0000000000000
--- a/assembly/src/deb/RELEASE
+++ /dev/null
@@ -1,2 +0,0 @@
-compute-classpath.sh uses the existence of this file to decide whether to put the assembly jar on the
-classpath or instead to use classfiles in the source tree.
\ No newline at end of file
diff --git a/assembly/src/deb/control/control b/assembly/src/deb/control/control
deleted file mode 100644
index a6b4471d485f4..0000000000000
--- a/assembly/src/deb/control/control
+++ /dev/null
@@ -1,8 +0,0 @@
-Package: [[deb.pkg.name]]
-Version: [[version]]-[[buildNumber]]
-Section: misc
-Priority: extra
-Architecture: all
-Maintainer: Matei Zaharia
-Description: [[name]]
-Distribution: development
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index a8c344b1ca594..f4f6b7b909490 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -76,7 +76,7 @@ fi
num_jars=0
-for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do
+for f in "${assembly_folder}"/spark-assembly*hadoop*.jar; do
if [[ ! -e "$f" ]]; then
echo "Failed to find Spark assembly in $assembly_folder" 1>&2
echo "You need to build Spark before running this program." 1>&2
@@ -88,7 +88,7 @@ done
if [ "$num_jars" -gt "1" ]; then
echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2
- ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2
+ ls "${assembly_folder}"/spark-assembly*hadoop*.jar 1>&2
echo "Please remove all but one jar." 1>&2
exit 1
fi
diff --git a/bin/run-example b/bin/run-example
index c567acf9a6b5c..a106411392e06 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -42,7 +42,7 @@ fi
JAR_COUNT=0
-for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do
+for f in "${JAR_PATH}"/spark-examples-*hadoop*.jar; do
if [[ ! -e "$f" ]]; then
echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
echo "You need to build Spark before running this program" 1>&2
@@ -54,7 +54,7 @@ done
if [ "$JAR_COUNT" -gt "1" ]; then
echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2
- ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2
+ ls "${JAR_PATH}"/spark-examples-*hadoop*.jar 1>&2
echo "Please remove all but one jar." 1>&2
exit 1
fi
diff --git a/bin/utils.sh b/bin/utils.sh
index 2241200082018..748dbe345a74c 100755
--- a/bin/utils.sh
+++ b/bin/utils.sh
@@ -35,7 +35,8 @@ function gatherSparkSubmitOpts() {
--master | --deploy-mode | --class | --name | --jars | --packages | --py-files | --files | \
--conf | --repositories | --properties-file | --driver-memory | --driver-java-options | \
--driver-library-path | --driver-class-path | --executor-memory | --driver-cores | \
- --total-executor-cores | --executor-cores | --queue | --num-executors | --archives)
+ --total-executor-cores | --executor-cores | --queue | --num-executors | --archives | \
+ --proxy-user)
if [[ $# -lt 2 ]]; then
"$SUBMIT_USAGE_FUNCTION"
exit 1;
diff --git a/bin/windows-utils.cmd b/bin/windows-utils.cmd
index 567b8733f7f77..0cf9e87ca554b 100644
--- a/bin/windows-utils.cmd
+++ b/bin/windows-utils.cmd
@@ -33,6 +33,7 @@ SET opts="%opts:~1,-1% \<--conf\> \<--properties-file\> \<--driver-memory\> \<--
SET opts="%opts:~1,-1% \<--driver-library-path\> \<--driver-class-path\> \<--executor-memory\>"
SET opts="%opts:~1,-1% \<--driver-cores\> \<--total-executor-cores\> \<--executor-cores\> \<--queue\>"
SET opts="%opts:~1,-1% \<--num-executors\> \<--archives\> \<--packages\> \<--repositories\>"
+SET opts="%opts:~1,-1% \<--proxy-user\>"
echo %1 | findstr %opts% >nul
if %ERRORLEVEL% equ 0 (
diff --git a/build/mvn b/build/mvn
index 53babf54debb6..3561110a4c019 100755
--- a/build/mvn
+++ b/build/mvn
@@ -21,6 +21,8 @@
_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
# Preserve the calling directory
_CALLING_DIR="$(pwd)"
+# Options used during compilation
+_COMPILE_JVM_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"
# Installs any application tarball given a URL, the expected tarball name,
# and, optionally, a checkable binary path to determine if the binary has
@@ -136,6 +138,7 @@ cd "${_CALLING_DIR}"
# Now that zinc is ensured to be installed, check its status and, if its
# not running or just installed, start it
if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then
+ export ZINC_OPTS=${ZINC_OPTS:-"$_COMPILE_JVM_OPTS"}
${ZINC_BIN} -shutdown
${ZINC_BIN} -start -port ${ZINC_PORT} \
-scala-compiler "${SCALA_COMPILER}" \
@@ -143,7 +146,7 @@ if [ -n "${ZINC_INSTALL_FLAG}" -o -z "`${ZINC_BIN} -status`" ]; then
fi
# Set any `mvn` options if not already present
-export MAVEN_OPTS=${MAVEN_OPTS:-"-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m"}
+export MAVEN_OPTS=${MAVEN_OPTS:-"$_COMPILE_JVM_OPTS"}
# Last, call the `mvn` command as usual
${MVN_BIN} "$@"
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 464c14457e53f..2e0cb5db170ac 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -122,6 +122,15 @@
#worker.sink.csv.unit=minutes
+# Enable Slf4jSink for all instances by class name
+#*.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink
+
+# Polling period for Slf4jSink
+#*.sink.slf4j.period=1
+
+#*.sink.slf4j.unit=minutes
+
+
# Enable jvm source for instance master, worker, driver and executor
#master.source.jvm.class=org.apache.spark.metrics.source.JvmSource
diff --git a/core/pom.xml b/core/pom.xml
index 66180035e61f1..c993781c0e0d6 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -329,16 +329,6 @@
scalacheck_${scala.binary.version}test
-
- org.easymock
- easymockclassextension
- test
-
-
- asm
- asm
- test
- junitjunit
diff --git a/core/src/main/java/org/apache/spark/util/collection/TimSort.java b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
index 409e1a41c5d49..a90cc0e761f62 100644
--- a/core/src/main/java/org/apache/spark/util/collection/TimSort.java
+++ b/core/src/main/java/org/apache/spark/util/collection/TimSort.java
@@ -425,15 +425,14 @@ private void pushRun(int runBase, int runLen) {
private void mergeCollapse() {
while (stackSize > 1) {
int n = stackSize - 2;
- if (n > 0 && runLen[n-1] <= runLen[n] + runLen[n+1]) {
+ if ( (n >= 1 && runLen[n-1] <= runLen[n] + runLen[n+1])
+ || (n >= 2 && runLen[n-2] <= runLen[n] + runLen[n-1])) {
if (runLen[n - 1] < runLen[n + 1])
n--;
- mergeAt(n);
- } else if (runLen[n] <= runLen[n + 1]) {
- mergeAt(n);
- } else {
+ } else if (runLen[n] > runLen[n + 1]) {
break; // Invariant is established
}
+ mergeAt(n);
}
}
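
The hunk above is the fix for the TimSort run-stack invariant bug reported by Stijn de Gouw: the old mergeCollapse only examined the top two entries of the run stack, so after a merge the invariant could silently break one level further down. The following is a standalone sketch of the invariant being checked, with illustrative names rather than Spark's Java class:

```scala
// Illustrative check of TimSort's run-stack invariant (not part of this patch).
object TimSortInvariantCheck {
  /** With n = stackSize - 2, the invariant requires runLen(n) > runLen(n+1),
    * runLen(n-1) > runLen(n) + runLen(n+1), and runLen(n-2) > runLen(n-1) + runLen(n). */
  def invariantHolds(runLen: Array[Int], stackSize: Int): Boolean = {
    val n = stackSize - 2
    val top = n < 0 || runLen(n) > runLen(n + 1)
    val one = n < 1 || runLen(n - 1) > runLen(n) + runLen(n + 1)
    val two = n < 2 || runLen(n - 2) > runLen(n - 1) + runLen(n)
    top && one && two
  }

  def main(args: Array[String]): Unit = {
    // A stack the old check would accept even though 120 <= 80 + 45 one level down.
    println(invariantHolds(Array(120, 80, 45, 30), 4)) // false
    println(invariantHolds(Array(220, 130, 80, 45), 4)) // true
  }
}
```
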
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index 68b33b5f0d7c7..6c37cc8b98236 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -196,7 +196,7 @@ span.additional-metric-title {
/* Hide all additional metrics by default. This is done here rather than using JavaScript to
* avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
-.scheduler_delay, .deserialization_time, .fetch_wait_time, .serialization_time,
-.getting_result_time {
+.scheduler_delay, .deserialization_time, .fetch_wait_time, .shuffle_read_remote,
+.serialization_time, .getting_result_time {
display: none;
}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 5f31bfba3f8d6..30f0ccd73ccca 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -23,6 +23,7 @@ import java.lang.ThreadLocal
import scala.collection.generic.Growable
import scala.collection.mutable.Map
+import scala.ref.WeakReference
import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
@@ -280,10 +281,12 @@ object AccumulatorParam {
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
private[spark] object Accumulators {
- // TODO: Use soft references? => need to make readObject work properly then
- val originals = Map[Long, Accumulable[_, _]]()
- val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
- override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+ // Store a WeakReference instead of a StrongReference because this way accumulators can be
+ // appropriately garbage collected during long-running jobs and release memory
+ type WeakAcc = WeakReference[Accumulable[_, _]]
+ val originals = Map[Long, WeakAcc]()
+ val localAccums = new ThreadLocal[Map[Long, WeakAcc]]() {
+ override protected def initialValue() = Map[Long, WeakAcc]()
}
var lastId: Long = 0
@@ -294,9 +297,9 @@ private[spark] object Accumulators {
def register(a: Accumulable[_, _], original: Boolean): Unit = synchronized {
if (original) {
- originals(a.id) = a
+ originals(a.id) = new WeakAcc(a)
} else {
- localAccums.get()(a.id) = a
+ localAccums.get()(a.id) = new WeakAcc(a)
}
}
@@ -307,11 +310,22 @@ private[spark] object Accumulators {
}
}
+ def remove(accId: Long) {
+ synchronized {
+ originals.remove(accId)
+ }
+ }
+
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
for ((id, accum) <- localAccums.get) {
- ret(id) = accum.localValue
+ // Since we are now storing weak references, we must check whether the underlying data
+ // is valid.
+ ret(id) = accum.get match {
+ case Some(values) => values.localValue
+ case None => None
+ }
}
return ret
}
@@ -320,7 +334,13 @@ private[spark] object Accumulators {
def add(values: Map[Long, Any]): Unit = synchronized {
for ((id, value) <- values) {
if (originals.contains(id)) {
- originals(id).asInstanceOf[Accumulable[Any, Any]] ++= value
+ // Since we are now storing weak references, we must check whether the underlying data
+ // is valid.
+ originals(id).get match {
+ case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]] ++= value
+ case None =>
+ throw new IllegalAccessError("Attempted to access garbage collected Accumulator.")
+ }
}
}
}
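
The Accumulators change above swaps strong references for scala.ref.WeakReference so that accumulators created during long-running jobs can be garbage collected. A minimal sketch of the registry pattern it relies on, using illustrative names rather than Spark's classes:

```scala
import scala.collection.mutable
import scala.ref.WeakReference

// Minimal weak-reference registry sketch (illustrative names, not Spark's Accumulators object).
object WeakRegistry {
  private val entries = mutable.Map[Long, WeakReference[AnyRef]]()

  def register(id: Long, value: AnyRef): Unit = synchronized {
    entries(id) = new WeakReference(value)
  }

  def remove(id: Long): Unit = synchronized { entries.remove(id) }

  /** Dereference an entry: None if it was never registered or has since been collected. */
  def lookup(id: Long): Option[AnyRef] = synchronized {
    entries.get(id).flatMap(_.get)
  }
}
```

This is why `add` above has to pattern match on `originals(id).get`: the referent may already be gone by the time an update arrives.
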
diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
index ede1e23f4fcc5..434f1e47cf822 100644
--- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala
@@ -32,6 +32,7 @@ private sealed trait CleanupTask
private case class CleanRDD(rddId: Int) extends CleanupTask
private case class CleanShuffle(shuffleId: Int) extends CleanupTask
private case class CleanBroadcast(broadcastId: Long) extends CleanupTask
+private case class CleanAccum(accId: Long) extends CleanupTask
/**
* A WeakReference associated with a CleanupTask.
@@ -114,6 +115,10 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
registerForCleanup(rdd, CleanRDD(rdd.id))
}
+ def registerAccumulatorForCleanup(a: Accumulable[_, _]): Unit = {
+ registerForCleanup(a, CleanAccum(a.id))
+ }
+
/** Register a ShuffleDependency for cleanup when it is garbage collected. */
def registerShuffleForCleanup(shuffleDependency: ShuffleDependency[_, _, _]) {
registerForCleanup(shuffleDependency, CleanShuffle(shuffleDependency.shuffleId))
@@ -145,6 +150,8 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
doCleanupShuffle(shuffleId, blocking = blockOnShuffleCleanupTasks)
case CleanBroadcast(broadcastId) =>
doCleanupBroadcast(broadcastId, blocking = blockOnCleanupTasks)
+ case CleanAccum(accId) =>
+ doCleanupAccum(accId, blocking = blockOnCleanupTasks)
}
}
} catch {
@@ -190,6 +197,18 @@ private[spark] class ContextCleaner(sc: SparkContext) extends Logging {
}
}
+ /** Perform accumulator cleanup. */
+ def doCleanupAccum(accId: Long, blocking: Boolean) {
+ try {
+ logDebug("Cleaning accumulator " + accId)
+ Accumulators.remove(accId)
+ listeners.foreach(_.accumCleaned(accId))
+ logInfo("Cleaned accumulator " + accId)
+ } catch {
+ case e: Exception => logError("Error cleaning accumulator " + accId, e)
+ }
+ }
+
private def blockManagerMaster = sc.env.blockManager.master
private def broadcastManager = sc.env.broadcastManager
private def mapOutputTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
@@ -206,4 +225,5 @@ private[spark] trait CleanerListener {
def rddCleaned(rddId: Int)
def shuffleCleaned(shuffleId: Int)
def broadcastCleaned(broadcastId: Long)
+ def accumCleaned(accId: Long)
}
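
For orientation, a compressed sketch of how the new CleanAccum task flows through the cleaner's dispatch; this mirrors the structure of the patch rather than reproducing it (the real cleaner runs in a daemon thread fed by a ReferenceQueue and also handles blocking flags and errors):

```scala
// Simplified, standalone dispatch sketch.
sealed trait CleanupTask
case class CleanRDD(rddId: Int) extends CleanupTask
case class CleanShuffle(shuffleId: Int) extends CleanupTask
case class CleanBroadcast(broadcastId: Long) extends CleanupTask
case class CleanAccum(accId: Long) extends CleanupTask

object CleanupDispatch {
  def dispatch(task: CleanupTask): Unit = task match {
    case CleanRDD(id)       => println(s"cleaning RDD $id")
    case CleanShuffle(id)   => println(s"cleaning shuffle $id")
    case CleanBroadcast(id) => println(s"cleaning broadcast $id")
    case CleanAccum(id)     => println(s"cleaning accumulator $id") // new in this patch
  }
}
```
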
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
index a46a81eabd965..443830f8d03b6 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -19,24 +19,32 @@ package org.apache.spark
/**
* A client that communicates with the cluster manager to request or kill executors.
+ * This is currently supported only in YARN mode.
*/
private[spark] trait ExecutorAllocationClient {
+ /**
+ * Express a preference to the cluster manager for a given total number of executors.
+ * This can result in canceling pending requests or filing additional requests.
+ * @return whether the request is acknowledged by the cluster manager.
+ */
+ private[spark] def requestTotalExecutors(numExecutors: Int): Boolean
+
/**
* Request an additional number of executors from the cluster manager.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def requestExecutors(numAdditionalExecutors: Int): Boolean
/**
* Request that the cluster manager kill the specified executors.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def killExecutors(executorIds: Seq[String]): Boolean
/**
* Request that the cluster manager kill the specified executor.
- * Return whether the request is acknowledged by the cluster manager.
+ * @return whether the request is acknowledged by the cluster manager.
*/
def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
}
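
Since the trait is private[spark] it cannot be implemented outside Spark, but a recording stub like the following is a reasonable sketch of how the new requestTotalExecutors contract can be exercised in tests (the names here are hypothetical, not Spark's):

```scala
// Hypothetical test double mirroring the client contract above.
trait AllocationClientSketch {
  def requestTotalExecutors(numExecutors: Int): Boolean
  def requestExecutors(numAdditionalExecutors: Int): Boolean
  def killExecutors(executorIds: Seq[String]): Boolean
  def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
}

class RecordingClient extends AllocationClientSketch {
  var lastRequestedTotal: Option[Int] = None
  override def requestTotalExecutors(numExecutors: Int): Boolean = {
    lastRequestedTotal = Some(numExecutors)
    true
  }
  override def requestExecutors(numAdditionalExecutors: Int): Boolean = true
  override def killExecutors(executorIds: Seq[String]): Boolean = true
}
```
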
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index 02d54bf3b53cc..21c6e6ffa6666 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -20,6 +20,7 @@ package org.apache.spark
import scala.collection.mutable
import org.apache.spark.scheduler._
+import org.apache.spark.util.{SystemClock, Clock}
/**
* An agent that dynamically allocates and removes executors based on the workload.
@@ -123,7 +124,7 @@ private[spark] class ExecutorAllocationManager(
private val intervalMillis: Long = 100
// Clock used to schedule when executors should be added and removed
- private var clock: Clock = new RealClock
+ private var clock: Clock = new SystemClock()
// Listener for Spark events that impact the allocation policy
private val listener = new ExecutorAllocationListener
@@ -201,18 +202,34 @@ private[spark] class ExecutorAllocationManager(
}
/**
- * If the add time has expired, request new executors and refresh the add time.
- * If the remove time for an existing executor has expired, kill the executor.
+ * The number of executors we would have if the cluster manager were to fulfill all our existing
+ * requests.
+ */
+ private def targetNumExecutors(): Int =
+ numExecutorsPending + executorIds.size - executorsPendingToRemove.size
+
+ /**
+ * The maximum number of executors we would need under the current load to satisfy all running
+ * and pending tasks, rounded up.
+ */
+ private def maxNumExecutorsNeeded(): Int = {
+ val numRunningOrPendingTasks = listener.totalPendingTasks + listener.totalRunningTasks
+ (numRunningOrPendingTasks + tasksPerExecutor - 1) / tasksPerExecutor
+ }
+
+ /**
+ * This is called at a fixed interval to regulate the number of pending executor requests
+ * and number of executors running.
+ *
+ * First, adjust our requested executors based on the add time and our current needs.
+ * Then, if the remove time for an existing executor has expired, kill the executor.
+ *
* This is factored out into its own method for testing.
*/
private def schedule(): Unit = synchronized {
val now = clock.getTimeMillis
- if (addTime != NOT_SET && now >= addTime) {
- addExecutors()
- logDebug(s"Starting timer to add more executors (to " +
- s"expire in $sustainedSchedulerBacklogTimeout seconds)")
- addTime += sustainedSchedulerBacklogTimeout * 1000
- }
+
+ addOrCancelExecutorRequests(now)
removeTimes.retain { case (executorId, expireTime) =>
val expired = now >= expireTime
@@ -223,59 +240,89 @@ private[spark] class ExecutorAllocationManager(
}
}
+ /**
+ * Check to see whether our existing allocation and the requests we've made previously exceed our
+ * current needs. If so, let the cluster manager know so that it can cancel pending requests that
+ * are unneeded.
+ *
+ * If not, and the add time has expired, see if we can request new executors and refresh the add
+ * time.
+ *
+ * @return the delta in the target number of executors.
+ */
+ private def addOrCancelExecutorRequests(now: Long): Int = synchronized {
+ val currentTarget = targetNumExecutors
+ val maxNeeded = maxNumExecutorsNeeded
+
+ if (maxNeeded < currentTarget) {
+ // The target number exceeds the number we actually need, so stop adding new
+ // executors and inform the cluster manager to cancel the extra pending requests.
+ val newTotalExecutors = math.max(maxNeeded, minNumExecutors)
+ client.requestTotalExecutors(newTotalExecutors)
+ numExecutorsToAdd = 1
+ updateNumExecutorsPending(newTotalExecutors)
+ } else if (addTime != NOT_SET && now >= addTime) {
+ val delta = addExecutors(maxNeeded)
+ logDebug(s"Starting timer to add more executors (to " +
+ s"expire in $sustainedSchedulerBacklogTimeout seconds)")
+ addTime += sustainedSchedulerBacklogTimeout * 1000
+ delta
+ } else {
+ 0
+ }
+ }
+
/**
* Request a number of executors from the cluster manager.
* If the cap on the number of executors is reached, give up and reset the
* number of executors to add next round instead of continuing to double it.
- * Return the number actually requested.
+ *
+ * @param maxNumExecutorsNeeded the maximum number of executors all currently running or pending
+ * tasks could fill
+ * @return the number of additional executors actually requested.
*/
- private def addExecutors(): Int = synchronized {
- // Do not request more executors if we have already reached the upper bound
- val numExistingExecutors = executorIds.size + numExecutorsPending
- if (numExistingExecutors >= maxNumExecutors) {
+ private def addExecutors(maxNumExecutorsNeeded: Int): Int = {
+ // Do not request more executors if it would put our target over the upper bound
+ val currentTarget = targetNumExecutors
+ if (currentTarget >= maxNumExecutors) {
logDebug(s"Not adding executors because there are already ${executorIds.size} " +
s"registered and $numExecutorsPending pending executor(s) (limit $maxNumExecutors)")
numExecutorsToAdd = 1
return 0
}
- // The number of executors needed to satisfy all pending tasks is the number of tasks pending
- // divided by the number of tasks each executor can fit, rounded up.
- val maxNumExecutorsPending =
- (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
- if (numExecutorsPending >= maxNumExecutorsPending) {
- logDebug(s"Not adding executors because there are already $numExecutorsPending " +
- s"pending and pending tasks could only fill $maxNumExecutorsPending")
- numExecutorsToAdd = 1
- return 0
- }
-
- // It's never useful to request more executors than could satisfy all the pending tasks, so
- // cap request at that amount.
- // Also cap request with respect to the configured upper bound.
- val maxNumExecutorsToAdd = math.min(
- maxNumExecutorsPending - numExecutorsPending,
- maxNumExecutors - numExistingExecutors)
- assert(maxNumExecutorsToAdd > 0)
-
- val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
-
- val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
- val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd)
+ val actualMaxNumExecutors = math.min(maxNumExecutors, maxNumExecutorsNeeded)
+ val newTotalExecutors = math.min(currentTarget + numExecutorsToAdd, actualMaxNumExecutors)
+ val addRequestAcknowledged = testing || client.requestTotalExecutors(newTotalExecutors)
if (addRequestAcknowledged) {
- logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
- s"tasks are backlogged (new desired total will be $newTotalExecutors)")
- numExecutorsToAdd =
- if (actualNumExecutorsToAdd == numExecutorsToAdd) numExecutorsToAdd * 2 else 1
- numExecutorsPending += actualNumExecutorsToAdd
- actualNumExecutorsToAdd
+ val delta = updateNumExecutorsPending(newTotalExecutors)
+ logInfo(s"Requesting $delta new executor(s) because tasks are backlogged" +
+ s" (new desired total will be $newTotalExecutors)")
+ numExecutorsToAdd = if (delta == numExecutorsToAdd) {
+ numExecutorsToAdd * 2
+ } else {
+ 1
+ }
+ delta
} else {
- logWarning(s"Unable to reach the cluster manager " +
- s"to request $actualNumExecutorsToAdd executors!")
+ logWarning(
+ s"Unable to reach the cluster manager to request $newTotalExecutors total executors!")
0
}
}
+ /**
+ * Given the new target number of executors, update the number of pending executor requests,
+ * and return the delta from the old number of pending requests.
+ */
+ private def updateNumExecutorsPending(newTotalExecutors: Int): Int = {
+ val newNumExecutorsPending =
+ newTotalExecutors - executorIds.size + executorsPendingToRemove.size
+ val delta = newNumExecutorsPending - numExecutorsPending
+ numExecutorsPending = newNumExecutorsPending
+ delta
+ }
+
/**
* Request the cluster manager to remove the given executor.
* Return whether the request is received.
@@ -415,6 +462,8 @@ private[spark] class ExecutorAllocationManager(
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
+ // Number of tasks currently running on the cluster. Should be 0 when no stages are active.
+ private var numRunningTasks: Int = _
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
val stageId = stageSubmitted.stageInfo.stageId
@@ -435,6 +484,10 @@ private[spark] class ExecutorAllocationManager(
// This is needed in case the stage is aborted for any reason
if (stageIdToNumTasks.isEmpty) {
allocationManager.onSchedulerQueueEmpty()
+ if (numRunningTasks != 0) {
+ logWarning("No stages are running, but numRunningTasks != 0")
+ numRunningTasks = 0
+ }
}
}
}
@@ -446,6 +499,7 @@ private[spark] class ExecutorAllocationManager(
val executorId = taskStart.taskInfo.executorId
allocationManager.synchronized {
+ numRunningTasks += 1
// This guards against the race condition in which the `SparkListenerTaskStart`
// event is posted before the `SparkListenerBlockManagerAdded` event, which is
// possible because these events are posted in different threads. (see SPARK-4951)
@@ -475,7 +529,8 @@ private[spark] class ExecutorAllocationManager(
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
allocationManager.synchronized {
- // If the executor is no longer running scheduled any tasks, mark it as idle
+ numRunningTasks -= 1
+ // If the executor is no longer running any scheduled tasks, mark it as idle
if (executorIdToTaskIds.contains(executorId)) {
executorIdToTaskIds(executorId) -= taskId
if (executorIdToTaskIds(executorId).isEmpty) {
@@ -514,6 +569,11 @@ private[spark] class ExecutorAllocationManager(
}.sum
}
+ /**
+ * The number of tasks currently running across all stages.
+ */
+ def totalRunningTasks(): Int = numRunningTasks
+
/**
* Return true if an executor is not currently running a task, and false otherwise.
*
@@ -529,28 +589,3 @@ private[spark] class ExecutorAllocationManager(
private object ExecutorAllocationManager {
val NOT_SET = Long.MaxValue
}
-
-/**
- * An abstract clock for measuring elapsed time.
- */
-private trait Clock {
- def getTimeMillis: Long
-}
-
-/**
- * A clock backed by a monotonically increasing time source.
- * The time returned by this clock does not correspond to any notion of wall-clock time.
- */
-private class RealClock extends Clock {
- override def getTimeMillis: Long = System.nanoTime / (1000 * 1000)
-}
-
-/**
- * A clock that allows the caller to customize the time.
- * This is used mainly for testing.
- */
-private class TestClock(startTimeMillis: Long) extends Clock {
- private var time: Long = startTimeMillis
- override def getTimeMillis: Long = time
- def tick(ms: Long): Unit = { time += ms }
-}
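
The core of the rewrite above is arithmetic: compare the current target (pending requests plus registered executors minus pending removals) against the number of executors the outstanding tasks could actually use, and either cancel surplus requests or add more. A standalone sketch of that math, with illustrative names:

```scala
// Standalone sketch of the sizing arithmetic introduced above (not Spark's class).
object AllocationMath {
  /** Executors needed to run all pending + running tasks, rounded up. */
  def maxNumExecutorsNeeded(pendingTasks: Int, runningTasks: Int, tasksPerExecutor: Int): Int =
    (pendingTasks + runningTasks + tasksPerExecutor - 1) / tasksPerExecutor

  /** Executors we would end up with if every outstanding request were fulfilled. */
  def targetNumExecutors(pending: Int, registered: Int, pendingToRemove: Int): Int =
    pending + registered - pendingToRemove

  def main(args: Array[String]): Unit = {
    // 37 outstanding tasks at 4 tasks per executor need ceil(37 / 4) = 10 executors.
    println(maxNumExecutorsNeeded(pendingTasks = 30, runningTasks = 7, tasksPerExecutor = 4)) // 10
    // With 5 pending requests, 8 registered executors and 1 pending removal the target is 12;
    // since 10 < 12, addOrCancelExecutorRequests would ask the cluster manager to cap at 10.
    println(targetNumExecutors(pending = 5, registered = 8, pendingToRemove = 1)) // 12
  }
}
```
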
diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
index 83ae57b7f1516..69178da1a7773 100644
--- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
+++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala
@@ -17,33 +17,86 @@
package org.apache.spark
-import akka.actor.Actor
+import scala.concurrent.duration._
+import scala.collection.mutable
+
+import akka.actor.{Actor, Cancellable}
+
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.{SlaveLost, TaskScheduler}
import org.apache.spark.util.ActorLogReceive
/**
* A heartbeat from executors to the driver. This is a shared message used by several internal
- * components to convey liveness or execution information for in-progress tasks.
+ * components to convey liveness or execution information for in-progress tasks. It will also
+ * expire hosts that have not sent a heartbeat in more than spark.network.timeout.
*/
private[spark] case class Heartbeat(
executorId: String,
taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
blockManagerId: BlockManagerId)
+private[spark] case object ExpireDeadHosts
+
private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean)
/**
 * Lives in the driver to receive heartbeats from executors.
*/
-private[spark] class HeartbeatReceiver(scheduler: TaskScheduler)
+private[spark] class HeartbeatReceiver(sc: SparkContext, scheduler: TaskScheduler)
extends Actor with ActorLogReceive with Logging {
+ // executor ID -> timestamp of when the last heartbeat from this executor was received
+ private val executorLastSeen = new mutable.HashMap[String, Long]
+
+ private val executorTimeoutMs = sc.conf.getLong("spark.network.timeout",
+ sc.conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000
+
+ private val checkTimeoutIntervalMs = sc.conf.getLong("spark.network.timeoutInterval",
+ sc.conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000
+
+ private var timeoutCheckingTask: Cancellable = null
+
+ override def preStart(): Unit = {
+ import context.dispatcher
+ timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
+ checkTimeoutIntervalMs.milliseconds, self, ExpireDeadHosts)
+ super.preStart()
+ }
+
override def receiveWithLogging = {
case Heartbeat(executorId, taskMetrics, blockManagerId) =>
- val response = HeartbeatResponse(
- !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId))
+ val unknownExecutor = !scheduler.executorHeartbeatReceived(
+ executorId, taskMetrics, blockManagerId)
+ val response = HeartbeatResponse(reregisterBlockManager = unknownExecutor)
+ executorLastSeen(executorId) = System.currentTimeMillis()
sender ! response
+ case ExpireDeadHosts =>
+ expireDeadHosts()
+ }
+
+ private def expireDeadHosts(): Unit = {
+ logTrace("Checking for hosts with no recent heartbeats in HeartbeatReceiver.")
+ val now = System.currentTimeMillis()
+ for ((executorId, lastSeenMs) <- executorLastSeen) {
+ if (now - lastSeenMs > executorTimeoutMs) {
+ logWarning(s"Removing executor $executorId with no recent heartbeats: " +
+ s"${now - lastSeenMs} ms exceeds timeout $executorTimeoutMs ms")
+        scheduler.executorLost(executorId, SlaveLost("Executor heartbeat " +
+          s"timed out after ${now - lastSeenMs} ms"))
+ if (sc.supportDynamicAllocation) {
+ sc.killExecutor(executorId)
+ }
+ executorLastSeen.remove(executorId)
+ }
+ }
+ }
+
+ override def postStop(): Unit = {
+ if (timeoutCheckingTask != null) {
+ timeoutCheckingTask.cancel()
+ }
+ super.postStop()
}
}
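
The expiry thresholds above are resolved from configuration, with the new spark.network.timeout taking precedence over the older block-manager keys. A small sketch of that fallback chain as it appears in the hunk (note that the *Ms fallback keys are treated as seconds and multiplied by 1000, mirroring the patch):

```scala
import org.apache.spark.SparkConf

// Sketch of the timeout resolution used by HeartbeatReceiver above.
object HeartbeatTimeouts {
  def executorTimeoutMs(conf: SparkConf): Long =
    conf.getLong("spark.network.timeout",
      conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120)) * 1000

  def checkIntervalMs(conf: SparkConf): Long =
    conf.getLong("spark.network.timeoutInterval",
      conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60)) * 1000

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false).set("spark.network.timeout", "300")
    println(executorTimeoutMs(conf)) // 300000
    println(checkIntervalMs(conf))   // 60000 (default interval)
  }
}
```
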
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index 3f33332a81eaf..7e706bcc42f04 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -50,6 +50,15 @@ private[spark] class HttpFileServer(
def stop() {
httpServer.stop()
+
+    // If we only stop the SparkContext but the driver process keeps running as a service, we
+    // need to delete the tmp dir; otherwise tmp dirs will keep accumulating
+ try {
+ Utils.deleteRecursively(baseDir)
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception while deleting Spark temp dir: ${baseDir.getAbsolutePath}", e)
+ }
}
def addFile(file: File) : String = {
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 88d35a4bacc6e..3653f724ba192 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -27,6 +27,7 @@ import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.network.sasl.SecretKeyHolder
+import org.apache.spark.util.Utils
/**
* Spark class responsible for security.
@@ -203,7 +204,7 @@ private[spark] class SecurityManager(sparkConf: SparkConf)
// always add the current user and SPARK_USER to the viewAcls
private val defaultAclUsers = Set[String](System.getProperty("user.name", ""),
- Option(System.getenv("SPARK_USER")).getOrElse("")).filter(!_.isEmpty)
+ Utils.getCurrentUserName())
setViewAcls(defaultAclUsers, sparkConf.get("spark.ui.view.acls", ""))
setModifyAcls(defaultAclUsers, sparkConf.get("spark.modify.acls", ""))
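
The ACL setup above now delegates to Utils.getCurrentUserName() instead of reading user.name and SPARK_USER directly. The helper's body is not shown in this hunk; as an assumption, it plausibly looks like the following (SPARK_USER override first, then Hadoop's UserGroupInformation):

```scala
import org.apache.hadoop.security.UserGroupInformation

// Hedged sketch of a getCurrentUserName-style helper; an assumption, not copied from the patch.
object CurrentUser {
  def name(): String =
    Option(System.getenv("SPARK_USER"))
      .getOrElse(UserGroupInformation.getCurrentUser().getShortUserName())
}
```
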
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index 13aa9960ac33a..61b34d524a421 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -18,6 +18,7 @@
package org.apache.spark
import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicBoolean
import scala.collection.JavaConverters._
import scala.collection.mutable.LinkedHashSet
@@ -67,7 +68,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value for " + key)
}
- settings.put(key, value)
+ settings.put(translateConfKey(key, warn = true), value)
this
}
@@ -139,7 +140,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- settings.putIfAbsent(key, value)
+ settings.putIfAbsent(translateConfKey(key, warn = true), value)
this
}
@@ -175,7 +176,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- Option(settings.get(key))
+ Option(settings.get(translateConfKey(key)))
}
/** Get all parameters as a list of pairs */
@@ -228,7 +229,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def getAppId: String = get("spark.app.id")
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.containsKey(key)
+ def contains(key: String): Boolean = settings.containsKey(translateConfKey(key))
/** Copy this object */
override def clone: SparkConf = {
@@ -285,7 +286,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
// Validate memory fractions
val memoryKeys = Seq(
"spark.storage.memoryFraction",
- "spark.shuffle.memoryFraction",
+ "spark.shuffle.memoryFraction",
"spark.shuffle.safetyFraction",
"spark.storage.unrollFraction",
"spark.storage.safetyFraction")
@@ -351,9 +352,26 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
def toDebugString: String = {
getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
+
}
-private[spark] object SparkConf {
+private[spark] object SparkConf extends Logging {
+
+ private val deprecatedConfigs: Map[String, DeprecatedConfig] = {
+ val configs = Seq(
+ DeprecatedConfig("spark.files.userClassPathFirst", "spark.executor.userClassPathFirst",
+ "1.3"),
+ DeprecatedConfig("spark.yarn.user.classpath.first", null, "1.3",
+ "Use spark.{driver,executor}.userClassPathFirst instead."),
+ DeprecatedConfig("spark.history.fs.updateInterval",
+ "spark.history.fs.update.interval.seconds",
+ "1.3", "Use spark.history.fs.update.interval.seconds instead"),
+ DeprecatedConfig("spark.history.updateInterval",
+ "spark.history.fs.update.interval.seconds",
+ "1.3", "Use spark.history.fs.update.interval.seconds instead"))
+ configs.map { x => (x.oldName, x) }.toMap
+ }
+
/**
* Return whether the given config is an akka config (e.g. akka.actor.provider).
* Note that this does not include spark-specific akka configs (e.g. spark.akka.timeout).
@@ -380,4 +398,63 @@ private[spark] object SparkConf {
def isSparkPortConf(name: String): Boolean = {
(name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
}
+
+ /**
+ * Translate the configuration key if it is deprecated and has a replacement, otherwise just
+ * returns the provided key.
+ *
+ * @param userKey Configuration key from the user / caller.
+ * @param warn Whether to print a warning if the key is deprecated. Warnings will be printed
+ * only once for each key.
+ */
+ private def translateConfKey(userKey: String, warn: Boolean = false): String = {
+ deprecatedConfigs.get(userKey)
+ .map { deprecatedKey =>
+ if (warn) {
+ deprecatedKey.warn()
+ }
+ deprecatedKey.newName.getOrElse(userKey)
+ }.getOrElse(userKey)
+ }
+
+ /**
+ * Holds information about keys that have been deprecated or renamed.
+ *
+ * @param oldName Old configuration key.
+ * @param newName New configuration key, or `null` if key has no replacement, in which case the
+ * deprecated key will be used (but the warning message will still be printed).
+ * @param version Version of Spark where key was deprecated.
+ * @param deprecationMessage Message to include in the deprecation warning; mandatory when
+ * `newName` is not provided.
+ */
+ private case class DeprecatedConfig(
+ oldName: String,
+ _newName: String,
+ version: String,
+ deprecationMessage: String = null) {
+
+ private val warned = new AtomicBoolean(false)
+ val newName = Option(_newName)
+
+    if (newName.isEmpty && (deprecationMessage == null || deprecationMessage.isEmpty())) {
+ throw new IllegalArgumentException("Need new config name or deprecation message.")
+ }
+
+ def warn(): Unit = {
+ if (warned.compareAndSet(false, true)) {
+        if (newName.isDefined) {
+          val message = Option(deprecationMessage).getOrElse(
+            s"Please use the alternative '${newName.get}' instead.")
+ logWarning(
+ s"The configuration option '$oldName' has been replaced as of Spark $version and " +
+ s"may be removed in the future. $message")
+ } else {
+ logWarning(
+ s"The configuration option '$oldName' has been deprecated as of Spark $version and " +
+ s"may be removed in the future. $deprecationMessage")
+ }
+ }
+ }
+
+ }
}
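
Based on the translation logic in this hunk, setting a deprecated key logs a one-time warning and stores the value under its replacement, so reads through either name agree. A usage sketch:

```scala
import org.apache.spark.SparkConf

// Usage sketch of the deprecated-key translation introduced above.
object DeprecatedKeyDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
      .set("spark.history.updateInterval", "30") // deprecated as of 1.3, logs a warning once
    // The value is stored under (and readable through) the new key:
    println(conf.get("spark.history.fs.update.interval.seconds")) // 30
    // ...and contains() on the old name still answers true, because lookups are translated too.
    println(conf.contains("spark.history.updateInterval")) // true
  }
}
```
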
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8d3c3d000adf3..3cd0c218a36fd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -191,7 +191,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
// log out Spark Version in Spark driver log
logInfo(s"Running Spark version $SPARK_VERSION")
-
+
private[spark] val conf = config.clone()
conf.validateSettings()
@@ -249,7 +249,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
conf.set("spark.executor.id", SparkContext.DRIVER_IDENTIFIER)
// Create the Spark execution environment (cache, map output tracker, etc)
- private[spark] val env = SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+
+ // This function allows components created by SparkEnv to be mocked in unit tests:
+ private[spark] def createSparkEnv(
+ conf: SparkConf,
+ isLocal: Boolean,
+ listenerBus: LiveListenerBus): SparkEnv = {
+ SparkEnv.createDriverEnv(conf, isLocal, listenerBus)
+ }
+
+ private[spark] val env = createSparkEnv(conf, isLocal, listenerBus)
SparkEnv.set(env)
// Used to store a URL for each static file/jar together with the file's local timestamp
@@ -335,18 +344,14 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
executorEnvs ++= conf.getExecutorEnv
// Set SPARK_USER for user who is running SparkContext.
- val sparkUser = Option {
- Option(System.getenv("SPARK_USER")).getOrElse(System.getProperty("user.name"))
- }.getOrElse {
- SparkContext.SPARK_UNKNOWN_USER
- }
+ val sparkUser = Utils.getCurrentUserName()
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
private[spark] var (schedulerBackend, taskScheduler) =
SparkContext.createTaskScheduler(this, master)
private val heartbeatReceiver = env.actorSystem.actorOf(
- Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver")
+ Props(new HeartbeatReceiver(this, taskScheduler)), "HeartbeatReceiver")
@volatile private[spark] var dagScheduler: DAGScheduler = _
try {
dagScheduler = new DAGScheduler(this)
@@ -393,7 +398,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
private val dynamicAllocationTesting = conf.getBoolean("spark.dynamicAllocation.testing", false)
private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
if (dynamicAllocationEnabled) {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Dynamic allocation of executors is currently only supported in YARN mode")
Some(new ExecutorAllocationManager(this, listenerBus, conf))
} else {
@@ -543,6 +548,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
* to parallelize and before the first action on the RDD, the resultant RDD will reflect the
* modified collection. Pass a copy of the argument to avoid this.
+ * @note avoid using `parallelize(Seq())` to create an empty `RDD`. Consider `emptyRDD` for an
+ * RDD with no partitions, or `parallelize(Seq[T]())` for an RDD of `T` with empty partitions.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
assertNotStopped()
@@ -826,7 +833,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
assertNotStopped()
- // The call to new NewHadoopJob automatically adds security credentials to conf,
+ // The call to new NewHadoopJob automatically adds security credentials to conf,
// so we don't need to explicitly add them ourselves
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
@@ -956,11 +963,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/** Build the union of a list of RDDs. */
- def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = {
+ val partitioners = rdds.flatMap(_.partitioner).toSet
+ if (partitioners.size == 1) {
+ new PartitionerAwareUnionRDD(this, rdds)
+ } else {
+ new UnionRDD(this, rdds)
+ }
+ }
/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
- new UnionRDD(this, Seq(first) ++ rest)
+ union(Seq(first) ++ rest)
/** Get an RDD that has no partitions or elements. */
def emptyRDD[T: ClassTag] = new EmptyRDD[T](this)
@@ -972,7 +986,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* values to using the `+=` method. Only the driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T)(implicit param: AccumulatorParam[T]) =
- new Accumulator(initialValue, param)
+ {
+ val acc = new Accumulator(initialValue, param)
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
/**
* Create an [[org.apache.spark.Accumulator]] variable of a given type, with a name for display
@@ -980,7 +998,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* driver can access the accumulator's `value`.
*/
def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) = {
- new Accumulator(initialValue, param, Some(name))
+ val acc = new Accumulator(initialValue, param, Some(name))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
}
/**
@@ -989,8 +1009,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* @tparam R accumulator result type
* @tparam T type that can be added to the accumulator
*/
- def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) =
- new Accumulable(initialValue, param)
+ def accumulable[R, T](initialValue: R)(implicit param: AccumulableParam[R, T]) = {
+ val acc = new Accumulable(initialValue, param)
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
/**
* Create an [[org.apache.spark.Accumulable]] shared variable, with a name for display in the
@@ -999,8 +1022,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
* @tparam R accumulator result type
* @tparam T type that can be added to the accumulator
*/
- def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) =
- new Accumulable(initialValue, param, Some(name))
+ def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) = {
+ val acc = new Accumulable(initialValue, param, Some(name))
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
+ }
/**
* Create an accumulator from a "mutable collection" type.
@@ -1011,7 +1037,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
- new Accumulable(initialValue, param)
+ val acc = new Accumulable(initialValue, param)
+ cleaner.foreach(_.registerAccumulatorForCleanup(acc))
+ acc
}
/**
@@ -1094,6 +1122,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
postEnvironmentUpdate()
}
+ /**
+ * Return whether dynamically adjusting the amount of resources allocated to
+ * this application is supported. This is currently only available for YARN.
+ */
+ private[spark] def supportDynamicAllocation =
+ master.contains("yarn") || dynamicAllocationTesting
+
/**
* :: DeveloperApi ::
* Register a listener to receive up-calls from events that happen during execution.
@@ -1103,14 +1138,31 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
listenerBus.addListener(listener)
}
+ /**
+ * Express a preference to the cluster manager for a given total number of executors.
+ * This can result in canceling pending requests or filing additional requests.
+ * This is currently only supported in YARN mode. Return whether the request is received.
+ */
+ private[spark] override def requestTotalExecutors(numExecutors: Int): Boolean = {
+ assert(master.contains("yarn") || dynamicAllocationTesting,
+ "Requesting executors is currently only supported in YARN mode")
+ schedulerBackend match {
+ case b: CoarseGrainedSchedulerBackend =>
+ b.requestTotalExecutors(numExecutors)
+ case _ =>
+ logWarning("Requesting executors is only supported in coarse-grained mode")
+ false
+ }
+ }
+
/**
* :: DeveloperApi ::
* Request an additional number of executors from the cluster manager.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Requesting executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
@@ -1124,11 +1176,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
/**
* :: DeveloperApi ::
* Request that the cluster manager kill the specified executors.
- * This is currently only supported in Yarn mode. Return whether the request is received.
+ * This is currently only supported in YARN mode. Return whether the request is received.
*/
@DeveloperApi
override def killExecutors(executorIds: Seq[String]): Boolean = {
- assert(master.contains("yarn") || dynamicAllocationTesting,
+ assert(supportDynamicAllocation,
"Killing executors is currently only supported in YARN mode")
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
@@ -1337,16 +1389,17 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
stopped = true
env.metricsSystem.report()
metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
cleaner.foreach(_.stop())
dagScheduler.stop()
dagScheduler = null
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ env.actorSystem.stop(heartbeatReceiver)
+ progressBar.foreach(_.stop())
taskScheduler = null
// TODO: Cache.stop()?
env.stop()
SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
logInfo("Successfully stopped SparkContext")
SparkContext.clearActiveContext()
} else {
@@ -1609,8 +1662,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
@deprecated("use defaultMinPartitions", "1.0.0")
def defaultMinSplits: Int = math.min(defaultParallelism, 2)
- /**
- * Default min number of partitions for Hadoop RDDs when not given by user
+ /**
+ * Default min number of partitions for Hadoop RDDs when not given by user
* Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2.
* The reasons for this are discussed in https://github.com/mesos/spark/pull/718
*/
@@ -1827,8 +1880,6 @@ object SparkContext extends Logging {
private[spark] val SPARK_JOB_INTERRUPT_ON_CANCEL = "spark.job.interruptOnCancel"
-  private[spark] val SPARK_UNKNOWN_USER = "<unknown>"
-
private[spark] val DRIVER_IDENTIFIER = "<driver>"
// The following deprecated objects have already been copied to `object AccumulatorParam` to
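
Among the SparkContext changes above, sc.union now builds a PartitionerAwareUnionRDD whenever all inputs share the same partitioner, preserving that partitioner instead of discarding it. A small usage sketch (local mode, illustrative data):

```scala
import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

// Usage sketch of the partitioner-aware union added above.
object UnionDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("union-demo"))
    val p = new HashPartitioner(4)
    val a = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(p)
    val b = sc.parallelize(Seq(3 -> "c", 4 -> "d")).partitionBy(p)
    // Some(HashPartitioner) rather than None, since both inputs share the same partitioner.
    println(sc.union(a, b).partitioner)
    sc.stop()
  }
}
```
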
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index f25db7f8de565..2a0c7e756dd3a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -34,7 +34,8 @@ import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.network.BlockTransferService
import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.network.nio.NioBlockTransferService
-import org.apache.spark.scheduler.LiveListenerBus
+import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus}
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager}
import org.apache.spark.storage._
@@ -67,6 +68,7 @@ class SparkEnv (
val sparkFilesDir: String,
val metricsSystem: MetricsSystem,
val shuffleMemoryManager: ShuffleMemoryManager,
+ val outputCommitCoordinator: OutputCommitCoordinator,
val conf: SparkConf) extends Logging {
private[spark] var isStopped = false
@@ -76,6 +78,8 @@ class SparkEnv (
// (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
private[spark] val hadoopJobMetadata = new MapMaker().softValues().makeMap[String, Any]()
+ private var driverTmpDirToDelete: Option[String] = None
+
private[spark] def stop() {
isStopped = true
pythonWorkers.foreach { case(key, worker) => worker.stop() }
@@ -86,6 +90,7 @@ class SparkEnv (
blockManager.stop()
blockManager.master.stop()
metricsSystem.stop()
+ outputCommitCoordinator.stop()
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
@@ -93,6 +98,22 @@ class SparkEnv (
// actorSystem.awaitTermination()
// Note that blockTransferService is stopped by BlockManager since it is started by it.
+
+    // If we only stop the SparkContext but the driver process keeps running as a service, we
+    // need to delete the tmp dir; otherwise tmp dirs will keep accumulating.
+    // We only need to delete the tmp dir created by the driver, because sparkFilesDir points to
+    // the current working dir on executors, which we do not need to delete.
+ driverTmpDirToDelete match {
+ case Some(path) => {
+ try {
+ Utils.deleteRecursively(new File(path))
+ } catch {
+ case e: Exception =>
+ logWarning(s"Exception while deleting Spark temp dir: $path", e)
+ }
+ }
+ case None => // We just need to delete tmp dir created by driver, so do nothing on executor
+ }
}
private[spark]
@@ -151,7 +172,8 @@ object SparkEnv extends Logging {
private[spark] def createDriverEnv(
conf: SparkConf,
isLocal: Boolean,
- listenerBus: LiveListenerBus): SparkEnv = {
+ listenerBus: LiveListenerBus,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
assert(conf.contains("spark.driver.host"), "spark.driver.host is not set on the driver!")
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
@@ -163,7 +185,8 @@ object SparkEnv extends Logging {
port,
isDriver = true,
isLocal = isLocal,
- listenerBus = listenerBus
+ listenerBus = listenerBus,
+ mockOutputCommitCoordinator = mockOutputCommitCoordinator
)
}
@@ -202,7 +225,8 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean,
listenerBus: LiveListenerBus = null,
- numUsableCores: Int = 0): SparkEnv = {
+ numUsableCores: Int = 0,
+ mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -350,7 +374,14 @@ object SparkEnv extends Logging {
"levels using the RDD.persist() method instead.")
}
- new SparkEnv(
+ val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+ new OutputCommitCoordinator(conf)
+ }
+ val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator",
+ new OutputCommitCoordinatorActor(outputCommitCoordinator))
+ outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor)
+
+ val envInstance = new SparkEnv(
executorId,
actorSystem,
serializer,
@@ -366,7 +397,17 @@ object SparkEnv extends Logging {
sparkFilesDir,
metricsSystem,
shuffleMemoryManager,
+ outputCommitCoordinator,
conf)
+
+    // Keep a reference to the tmp dir created by the driver so that it can be deleted when
+    // stop() is called. This is only needed on the driver, which may run as a long-lived
+    // service; without it, stopping the SparkContext would leave tmp dirs behind.
+ if (isDriver) {
+ envInstance.driverTmpDirToDelete = Some(sparkFilesDir)
+ }
+
+ envInstance
}
/**
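
The driver-only temp-dir bookkeeping added to SparkEnv boils down to: remember sparkFilesDir only when isDriver, and delete it best-effort in stop(). A standalone sketch of that pattern, with a plain recursive delete standing in for Utils.deleteRecursively and illustrative names:

```scala
import java.io.File

// Standalone sketch of the driver-only temp-dir cleanup added above (not Spark's SparkEnv).
class TempDirCleanup(isDriver: Boolean, sparkFilesDir: String) {
  private val driverTmpDirToDelete: Option[String] =
    if (isDriver) Some(sparkFilesDir) else None

  def stop(): Unit = driverTmpDirToDelete.foreach { path =>
    try {
      def delete(f: File): Unit = {
        if (f.isDirectory) Option(f.listFiles()).getOrElse(Array.empty[File]).foreach(delete)
        f.delete()
      }
      delete(new File(path))
    } catch {
      case e: Exception =>
        System.err.println(s"Exception while deleting Spark temp dir: $path (${e.getMessage})")
    }
  }
}
```
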
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 40237596570de..6eb4537d10477 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.executor.CommitDeniedException
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
@@ -105,24 +106,56 @@ class SparkHadoopWriter(@transient jobConf: JobConf)
def commit() {
val taCtxt = getTaskContext()
val cmtr = getOutputCommitter()
- if (cmtr.needsTaskCommit(taCtxt)) {
+
+ // Called after we have decided to commit
+ def performCommit(): Unit = {
try {
cmtr.commitTask(taCtxt)
- logInfo (taID + ": Committed")
+ logInfo (s"$taID: Committed")
} catch {
- case e: IOException => {
+ case e: IOException =>
logError("Error committing the output of task: " + taID.value, e)
cmtr.abortTask(taCtxt)
throw e
+ }
+ }
+
+ // First, check whether the task's output has already been committed by some other attempt
+ if (cmtr.needsTaskCommit(taCtxt)) {
+ // The task output needs to be committed, but we don't know whether some other task attempt
+ // might be racing to commit the same output partition. Therefore, coordinate with the driver
+ // in order to determine whether this attempt can commit (see SPARK-4879).
+ val shouldCoordinateWithDriver: Boolean = {
+ val sparkConf = SparkEnv.get.conf
+ // We only need to coordinate with the driver if there are multiple concurrent task
+ // attempts, which should only occur if speculation is enabled
+ val speculationEnabled = sparkConf.getBoolean("spark.speculation", false)
+ // This (undocumented) setting is an escape-hatch in case the commit code introduces bugs
+ sparkConf.getBoolean("spark.hadoop.outputCommitCoordination.enabled", speculationEnabled)
+ }
+ if (shouldCoordinateWithDriver) {
+ val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
+ val canCommit = outputCommitCoordinator.canCommit(jobID, splitID, attemptID)
+ if (canCommit) {
+ performCommit()
+ } else {
+ val msg = s"$taID: Not committed because the driver did not authorize commit"
+ logInfo(msg)
+ // We need to abort the task so that the driver can reschedule new attempts, if necessary
+ cmtr.abortTask(taCtxt)
+ throw new CommitDeniedException(msg, jobID, splitID, attemptID)
}
+ } else {
+ // Speculation is disabled or a user has chosen to manually bypass the commit coordination
+ performCommit()
}
} else {
- logInfo ("No need to commit output of task: " + taID.value)
+ // Some other attempt committed the output, so we do nothing and signal success
+ logInfo(s"No need to commit output of task because needsTaskCommit=false: ${taID.value}")
}
}
def commitJob() {
- // always ? Or if cmtr.needsTaskCommit ?
val cmtr = getOutputCommitter()
cmtr.commitJob(getJobContext())
}
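
The coordinated path above is a first-committer-wins protocol: every attempt that sees needsTaskCommit == true asks a single driver-side arbiter for permission, and only the first attempt authorized for a given (job, partition) may call commitTask. A minimal hypothetical arbiter illustrating just that decision (a sketch, not the real OutputCommitCoordinator, which also tracks stage completion and failed attempts):

    import scala.collection.mutable

    // First-committer-wins arbitration keyed by (jobId, partition).
    class CommitArbiter {
      // (jobId, partition) -> attempt that has been authorized to commit
      private val winners = mutable.Map[(Int, Int), Int]()

      def canCommit(jobId: Int, partition: Int, attempt: Int): Boolean = synchronized {
        winners.get((jobId, partition)) match {
          case Some(winner) => winner == attempt // repeat asks by the winner stay authorized
          case None =>
            winners((jobId, partition)) = attempt // first attempt to ask wins
            true
        }
      }
    }

    object CommitArbiterDemo extends App {
      val arbiter = new CommitArbiter
      println(arbiter.canCommit(jobId = 0, partition = 3, attempt = 0)) // true
      println(arbiter.canCommit(jobId = 0, partition = 3, attempt = 1)) // false: denied
    }

A losing attempt then maps onto the new TaskCommitDenied end reason introduced further down in this patch.
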
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index af5fd8e0ac00c..29a5cd5fdac76 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -146,6 +146,20 @@ case object TaskKilled extends TaskFailedReason {
override def toErrorString: String = "TaskKilled (killed intentionally)"
}
+/**
+ * :: DeveloperApi ::
+ * Task requested the driver to commit, but was denied.
+ */
+@DeveloperApi
+case class TaskCommitDenied(
+ jobID: Int,
+ partitionID: Int,
+ attemptID: Int)
+ extends TaskFailedReason {
+ override def toErrorString: String = s"TaskCommitDenied (Driver denied task commit)" +
+ s" for job: $jobID, partition: $partitionID, attempt: $attemptID"
+}
+
/**
* :: DeveloperApi ::
* The task failed because the executor that it was running on was lost. This may happen because
diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala
index be081c3825566..35b324ba6f573 100644
--- a/core/src/main/scala/org/apache/spark/TestUtils.scala
+++ b/core/src/main/scala/org/apache/spark/TestUtils.scala
@@ -17,12 +17,13 @@
package org.apache.spark
-import java.io.{File, FileInputStream, FileOutputStream}
+import java.io.{ByteArrayInputStream, File, FileInputStream, FileOutputStream}
import java.net.{URI, URL}
import java.util.jar.{JarEntry, JarOutputStream}
import scala.collection.JavaConversions._
+import com.google.common.base.Charsets.UTF_8
import com.google.common.io.{ByteStreams, Files}
import javax.tools.{JavaFileObject, SimpleJavaFileObject, ToolProvider}
@@ -59,6 +60,22 @@ private[spark] object TestUtils {
createJar(files1 ++ files2, jarFile)
}
+ /**
+ * Create a jar file containing multiple files. The `files` map contains a mapping of
+ * file names in the jar file to their contents.
+ */
+ def createJarWithFiles(files: Map[String, String], dir: File = null): URL = {
+ val tempDir = Option(dir).getOrElse(Utils.createTempDir())
+ val jarFile = File.createTempFile("testJar", ".jar", tempDir)
+ val jarStream = new JarOutputStream(new FileOutputStream(jarFile))
+ files.foreach { case (k, v) =>
+ val entry = new JarEntry(k)
+ jarStream.putNextEntry(entry)
+ ByteStreams.copy(new ByteArrayInputStream(v.getBytes(UTF_8)), jarStream)
+ }
+ jarStream.close()
+ jarFile.toURI.toURL
+ }
/**
* Create a jar file that contains this set of files. All files will be located at the root
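
A hedged usage sketch of the new helper; the object name and jar entries below are made up, and the demo sits in the org.apache.spark package only because TestUtils is private[spark]:

    package org.apache.spark

    import java.net.URLClassLoader

    object CreateJarDemo extends App {
      // Build a throwaway jar whose entry names and contents are given in memory.
      val jarUrl = TestUtils.createJarWithFiles(Map(
        "test.resource"      -> "hello world",
        "META-INF/extra.txt" -> "ignored"
      ))

      // The returned URL can be placed on a classloader and the entries read back.
      val loader = new URLClassLoader(Array(jarUrl), getClass.getClassLoader)
      assert(loader.getResourceAsStream("test.resource") != null)
    }
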
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
new file mode 100644
index 0000000000000..164e95081583f
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.io.DataOutputStream
+import java.net.Socket
+
+import py4j.GatewayServer
+
+import org.apache.spark.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * Process that starts a Py4J GatewayServer on an ephemeral port and communicates the bound port
+ * back to its caller via a callback port specified by the caller.
+ *
+ * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
+ */
+private[spark] object PythonGatewayServer extends Logging {
+ def main(args: Array[String]): Unit = Utils.tryOrExit {
+ // Start a GatewayServer on an ephemeral port
+ val gatewayServer: GatewayServer = new GatewayServer(null, 0)
+ gatewayServer.start()
+ val boundPort: Int = gatewayServer.getListeningPort
+ if (boundPort == -1) {
+ logError("GatewayServer failed to bind; exiting")
+ System.exit(1)
+ } else {
+ logDebug(s"Started PythonGatewayServer on port $boundPort")
+ }
+
+ // Communicate the bound port back to the caller via the caller-specified callback port
+ val callbackHost = sys.env("_PYSPARK_DRIVER_CALLBACK_HOST")
+ val callbackPort = sys.env("_PYSPARK_DRIVER_CALLBACK_PORT").toInt
+ logDebug(s"Communicating GatewayServer port to Python driver at $callbackHost:$callbackPort")
+ val callbackSocket = new Socket(callbackHost, callbackPort)
+ val dos = new DataOutputStream(callbackSocket.getOutputStream)
+ dos.writeInt(boundPort)
+ dos.close()
+ callbackSocket.close()
+
+ // Exit on EOF or broken pipe to ensure that this process dies when the Python driver dies:
+ while (System.in.read() != -1) {
+ // Do nothing
+ }
+ logDebug("Exiting due to broken pipe from Python driver")
+ System.exit(0)
+ }
+}
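
The handshake is deliberately small: the gateway writes one big-endian int (the bound Py4J port) to the caller's callback socket, then blocks on stdin so it exits when its parent does. The real reader is java_gateway.py on the Python side; the hedged Scala sketch below only exists to make the wire format concrete (names are hypothetical):

    import java.io.DataInputStream
    import java.net.ServerSocket

    object CallbackPortReader extends App {
      // Stand-in for the Python-side listener: bind an ephemeral callback port, pass it to the
      // child via _PYSPARK_DRIVER_CALLBACK_HOST / _PYSPARK_DRIVER_CALLBACK_PORT, then read back
      // the single 4-byte integer that PythonGatewayServer writes (its bound Py4J port).
      val server = new ServerSocket(0)
      println(s"Callback port to hand to the gateway process: ${server.getLocalPort}")

      val socket = server.accept()                        // the gateway connects back here
      val in = new DataInputStream(socket.getInputStream)
      val gatewayPort = in.readInt()                      // mirrors DataOutputStream.writeInt above
      println(s"PythonGatewayServer is listening on port $gatewayPort")

      in.close(); socket.close(); server.close()
    }
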
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index b89effc16d36d..b1cec0f6472b0 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -219,14 +219,13 @@ private[spark] class PythonRDD(
val oldBids = PythonRDD.getWorkerBroadcasts(worker)
val newBids = broadcastVars.map(_.id).toSet
// number of different broadcasts
- val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+ val toRemove = oldBids.diff(newBids)
+ val cnt = toRemove.size + newBids.diff(oldBids).size
dataOut.writeInt(cnt)
- for (bid <- oldBids) {
- if (!newBids.contains(bid)) {
- // remove the broadcast from worker
- dataOut.writeLong(- bid - 1) // bid >= 0
- oldBids.remove(bid)
- }
+ for (bid <- toRemove) {
+ // remove the broadcast from worker
+ dataOut.writeLong(- bid - 1) // bid >= 0
+ oldBids.remove(bid)
}
for (broadcast <- broadcastVars) {
if (!oldBids.contains(broadcast.id)) {
@@ -248,13 +247,13 @@ private[spark] class PythonRDD(
} catch {
case e: Exception if context.isCompleted || context.isInterrupted =>
logDebug("Exception thrown after task completion (likely due to cleanup)", e)
- worker.shutdownOutput()
+ Utils.tryLog(worker.shutdownOutput())
case e: Exception =>
// We must avoid throwing exceptions here, because the thread uncaught exception handler
// will kill the whole executor (see org.apache.spark.executor.Executor).
_exception = e
- worker.shutdownOutput()
+ Utils.tryLog(worker.shutdownOutput())
} finally {
// Release memory used by this thread for shuffles
env.shuffleMemoryManager.releaseMemoryForThisThread()
@@ -303,6 +302,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
+ override val partitioner = prev.partitioner
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
@@ -329,6 +329,15 @@ private[spark] object PythonRDD extends Logging {
}
}
+ /**
+   * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservesPartitioning = true.
+   *
+   * This lets PySpark keep the partitioner after partitionBy().
+ */
+ def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = {
+ pair.rdd.mapPartitions(it => it.map(_._2), true)
+ }
+
/**
* Adapter for calling SparkContext#runJob from Python.
*
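
valueOfPair goes through mapPartitions with preservesPartitioning = true rather than a plain map because only the former keeps the parent RDD's partitioner, which PySpark needs after partitionBy(). A small sketch of that difference, assuming an already-created SparkContext named sc:

    import org.apache.spark.HashPartitioner
    import org.apache.spark.SparkContext._

    // Assumes `sc` is an existing SparkContext.
    val pairs = sc.parallelize(Seq(1L -> "a", 2L -> "b")).partitionBy(new HashPartitioner(4))

    val viaMap           = pairs.map(_._2)                                              // partitioner dropped
    val viaMapPartitions = pairs.mapPartitions(_.map(_._2), preservesPartitioning = true)

    assert(viaMap.partitioner.isEmpty)
    assert(viaMapPartitions.partitioner == Some(new HashPartitioner(4)))
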
diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala
index 38b3da0b13756..237d26fc6bd0e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala
@@ -68,8 +68,9 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf)
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val sparkJavaOpts = Utils.sparkJavaOpts(conf)
val javaOpts = sparkJavaOpts ++ extraJavaOpts
- val command = new Command(mainClass, Seq("{{WORKER_URL}}", driverArgs.mainClass) ++
- driverArgs.driverOptions, sys.env, classPathEntries, libraryPathEntries, javaOpts)
+ val command = new Command(mainClass,
+ Seq("{{WORKER_URL}}", "{{USER_JAR}}", driverArgs.mainClass) ++ driverArgs.driverOptions,
+ sys.env, classPathEntries, libraryPathEntries, javaOpts)
val driverDescription = new DriverDescription(
driverArgs.jarUrl,
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index 03238e9fa0088..e0a32fb65cd51 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -52,18 +52,13 @@ class SparkHadoopUtil extends Logging {
* do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems
*/
def runAsSparkUser(func: () => Unit) {
- val user = Option(System.getenv("SPARK_USER")).getOrElse(SparkContext.SPARK_UNKNOWN_USER)
- if (user != SparkContext.SPARK_UNKNOWN_USER) {
- logDebug("running as user: " + user)
- val ugi = UserGroupInformation.createRemoteUser(user)
- transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
- ugi.doAs(new PrivilegedExceptionAction[Unit] {
- def run: Unit = func()
- })
- } else {
- logDebug("running as SPARK_UNKNOWN_USER")
- func()
- }
+ val user = Utils.getCurrentUserName()
+ logDebug("running as user: " + user)
+ val ugi = UserGroupInformation.createRemoteUser(user)
+ transferCredentials(UserGroupInformation.getCurrentUser(), ugi)
+ ugi.doAs(new PrivilegedExceptionAction[Unit] {
+ def run: Unit = func()
+ })
}
def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) {
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index 6d213926f3d7b..4a74641f4e1fa 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -18,12 +18,14 @@
package org.apache.spark.deploy
import java.io.{File, PrintStream}
-import java.lang.reflect.{InvocationTargetException, Modifier}
+import java.lang.reflect.{InvocationTargetException, Modifier, UndeclaredThrowableException}
import java.net.URL
+import java.security.PrivilegedExceptionAction
import scala.collection.mutable.{ArrayBuffer, HashMap, Map}
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.security.UserGroupInformation
import org.apache.ivy.Ivy
import org.apache.ivy.core.LogOptions
import org.apache.ivy.core.module.descriptor._
@@ -35,9 +37,9 @@ import org.apache.ivy.core.settings.IvySettings
import org.apache.ivy.plugins.matcher.GlobPatternMatcher
import org.apache.ivy.plugins.resolver.{ChainResolver, IBiblioResolver}
+import org.apache.spark.SPARK_VERSION
import org.apache.spark.deploy.rest._
-import org.apache.spark.executor._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}
/**
* Whether to submit, kill, or request the status of an application.
@@ -79,16 +81,27 @@ object SparkSubmit {
private val CLASS_NOT_FOUND_EXIT_STATUS = 101
// Exposed for testing
- private[spark] var exitFn: () => Unit = () => System.exit(-1)
+ private[spark] var exitFn: () => Unit = () => System.exit(1)
private[spark] var printStream: PrintStream = System.err
- private[spark] def printWarning(str: String) = printStream.println("Warning: " + str)
- private[spark] def printErrorAndExit(str: String) = {
+ private[spark] def printWarning(str: String): Unit = printStream.println("Warning: " + str)
+ private[spark] def printErrorAndExit(str: String): Unit = {
printStream.println("Error: " + str)
printStream.println("Run with --help for usage help or --verbose for debug output")
exitFn()
}
+ private[spark] def printVersionAndExit(): Unit = {
+    printStream.println("""Welcome to
+      ____              __
+     / __/__  ___ _____/ /__
+    _\ \/ _ \/ _ `/ __/ '_/
+   /___/ .__/\_,_/_/ /_/\_\   version %s
+      /_/
+                        """.format(SPARK_VERSION))
+ printStream.println("Type --help for more information.")
+ exitFn()
+ }
- def main(args: Array[String]) {
+ def main(args: Array[String]): Unit = {
val appArgs = new SparkSubmitArguments(args)
if (appArgs.verbose) {
printStream.println(appArgs)
@@ -126,6 +139,34 @@ object SparkSubmit {
*/
private[spark] def submit(args: SparkSubmitArguments): Unit = {
val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args)
+
+ def doRunMain(): Unit = {
+ if (args.proxyUser != null) {
+ val proxyUser = UserGroupInformation.createProxyUser(args.proxyUser,
+ UserGroupInformation.getCurrentUser())
+ try {
+ proxyUser.doAs(new PrivilegedExceptionAction[Unit]() {
+ override def run(): Unit = {
+ runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
+ }
+ })
+ } catch {
+ case e: Exception =>
+ // Hadoop's AuthorizationException suppresses the exception's stack trace, which
+ // makes the message printed to the output by the JVM not very helpful. Instead,
+ // detect exceptions with empty stack traces here, and treat them differently.
+ if (e.getStackTrace().length == 0) {
+ printStream.println(s"ERROR: ${e.getClass().getName()}: ${e.getMessage()}")
+ exitFn()
+ } else {
+ throw e
+ }
+ }
+ } else {
+ runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose)
+ }
+ }
+
// In standalone cluster mode, there are two submission gateways:
// (1) The traditional Akka gateway using o.a.s.deploy.Client as a wrapper
// (2) The new REST-based gateway introduced in Spark 1.3
@@ -134,7 +175,7 @@ object SparkSubmit {
if (args.isStandaloneCluster && args.useRest) {
try {
printStream.println("Running Spark using the REST application submission protocol.")
- runMain(childArgs, childClasspath, sysProps, childMainClass)
+ doRunMain()
} catch {
// Fail over to use the legacy submission gateway
case e: SubmitRestConnectionException =>
@@ -145,7 +186,7 @@ object SparkSubmit {
}
// In all other modes, just run the main class as prepared
} else {
- runMain(childArgs, childClasspath, sysProps, childMainClass)
+ doRunMain()
}
}
@@ -211,6 +252,26 @@ object SparkSubmit {
val isYarnCluster = clusterManager == YARN && deployMode == CLUSTER
+    // Resolve Maven dependencies, if any, and add the resolved jars to the classpath. Add them to
+    // py-files as well, for packages that include Python code.
+ val resolvedMavenCoordinates =
+ SparkSubmitUtils.resolveMavenCoordinates(
+ args.packages, Option(args.repositories), Option(args.ivyRepoPath))
+ if (!resolvedMavenCoordinates.trim.isEmpty) {
+ if (args.jars == null || args.jars.trim.isEmpty) {
+ args.jars = resolvedMavenCoordinates
+ } else {
+ args.jars += s",$resolvedMavenCoordinates"
+ }
+ if (args.isPython) {
+ if (args.pyFiles == null || args.pyFiles.trim.isEmpty) {
+ args.pyFiles = resolvedMavenCoordinates
+ } else {
+ args.pyFiles += s",$resolvedMavenCoordinates"
+ }
+ }
+ }
+
// Require all python files to be local, so we can add them to the PYTHONPATH
// In YARN cluster mode, python files are distributed as regular files, which can be non-local
if (args.isPython && !isYarnCluster) {
@@ -242,8 +303,7 @@ object SparkSubmit {
// If we're running a python app, set the main class to our specific python runner
if (args.isPython && deployMode == CLIENT) {
if (args.primaryResource == PYSPARK_SHELL) {
- args.mainClass = "py4j.GatewayServer"
- args.childArgs = ArrayBuffer("--die-on-broken-pipe", "0")
+ args.mainClass = "org.apache.spark.api.python.PythonGatewayServer"
} else {
// If a python file is provided, add it to the child arguments and list of files to deploy.
// Usage: PythonAppRunner [app arguments]
@@ -267,18 +327,6 @@ object SparkSubmit {
// Special flag to avoid deprecation warnings at the client
sysProps("SPARK_SUBMIT") = "true"
- // Resolve maven dependencies if there are any and add classpath to jars
- val resolvedMavenCoordinates =
- SparkSubmitUtils.resolveMavenCoordinates(
- args.packages, Option(args.repositories), Option(args.ivyRepoPath))
- if (!resolvedMavenCoordinates.trim.isEmpty) {
- if (args.jars == null || args.jars.trim.isEmpty) {
- args.jars = resolvedMavenCoordinates
- } else {
- args.jars += s",$resolvedMavenCoordinates"
- }
- }
-
// A list of rules to map each argument to system properties or command-line options in
// each deploy mode; we iterate through these below
val options = List[OptionAssigner](
@@ -457,7 +505,7 @@ object SparkSubmit {
childClasspath: Seq[String],
sysProps: Map[String, String],
childMainClass: String,
- verbose: Boolean = false) {
+ verbose: Boolean): Unit = {
if (verbose) {
printStream.println(s"Main class:\n$childMainClass")
printStream.println(s"Arguments:\n${childArgs.mkString("\n")}")
@@ -467,11 +515,11 @@ object SparkSubmit {
}
val loader =
- if (sysProps.getOrElse("spark.files.userClassPathFirst", "false").toBoolean) {
- new ChildExecutorURLClassLoader(new Array[URL](0),
+ if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) {
+ new ChildFirstURLClassLoader(new Array[URL](0),
Thread.currentThread.getContextClassLoader)
} else {
- new ExecutorURLClassLoader(new Array[URL](0),
+ new MutableURLClassLoader(new Array[URL](0),
Thread.currentThread.getContextClassLoader)
}
Thread.currentThread.setContextClassLoader(loader)
@@ -507,13 +555,21 @@ object SparkSubmit {
if (!Modifier.isStatic(mainMethod.getModifiers)) {
throw new IllegalStateException("The main method in the given main class must be static")
}
+
+ def findCause(t: Throwable): Throwable = t match {
+ case e: UndeclaredThrowableException =>
+ if (e.getCause() != null) findCause(e.getCause()) else e
+ case e: InvocationTargetException =>
+ if (e.getCause() != null) findCause(e.getCause()) else e
+ case e: Throwable =>
+ e
+ }
+
try {
mainMethod.invoke(null, childArgs.toArray)
} catch {
- case e: InvocationTargetException => e.getCause match {
- case cause: Throwable => throw cause
- case null => throw e
- }
+ case t: Throwable =>
+ throw findCause(t)
}
}
@@ -598,13 +654,14 @@ private[spark] object SparkSubmitUtils {
private[spark] case class MavenCoordinate(groupId: String, artifactId: String, version: String)
/**
- * Extracts maven coordinates from a comma-delimited string
+ * Extracts maven coordinates from a comma-delimited string. Coordinates should be provided
+ * in the format `groupId:artifactId:version` or `groupId/artifactId:version`.
* @param coordinates Comma-delimited string of maven coordinates
* @return Sequence of Maven coordinates
*/
private[spark] def extractMavenCoordinates(coordinates: String): Seq[MavenCoordinate] = {
coordinates.split(",").map { p =>
- val splits = p.split(":")
+ val splits = p.replace("/", ":").split(":")
require(splits.length == 3, s"Provided Maven Coordinates must be in the form " +
s"'groupId:artifactId:version'. The coordinate provided is: $p")
require(splits(0) != null && splits(0).trim.nonEmpty, s"The groupId cannot be null or " +
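
Both separators now parse to the same coordinates, since '/' is normalized to ':' before splitting. A hedged example with made-up artifacts (placed under org.apache.spark.deploy only because SparkSubmitUtils is private[spark]):

    package org.apache.spark.deploy

    object CoordinateParsingDemo extends App {
      // Group/artifact/version values below are illustrative only.
      val viaColons = SparkSubmitUtils.extractMavenCoordinates(
        "com.example:fake-lib_2.10:0.1,org.example:other-lib:2.0")
      val viaSlashes = SparkSubmitUtils.extractMavenCoordinates(
        "com.example/fake-lib_2.10:0.1,org.example/other-lib:2.0")

      assert(viaColons == viaSlashes)   // '/' and ':' forms yield identical coordinates
      viaColons.foreach(println)        // MavenCoordinate(com.example,fake-lib_2.10,0.1), ...
    }
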
@@ -634,6 +691,13 @@ private[spark] object SparkSubmitUtils {
br.setName("central")
cr.add(br)
+ val sp: IBiblioResolver = new IBiblioResolver
+ sp.setM2compatible(true)
+ sp.setUsepoms(true)
+ sp.setRoot("http://dl.bintray.com/spark-packages/maven")
+ sp.setName("spark-packages")
+ cr.add(sp)
+
val repositoryList = remoteRepos.getOrElse("")
// add any other remote repositories other than maven central
if (repositoryList.trim.nonEmpty) {
@@ -682,6 +746,35 @@ private[spark] object SparkSubmitUtils {
md.addDependency(dd)
}
}
+
+ /** Add exclusion rules for dependencies already included in the spark-assembly */
+ private[spark] def addExclusionRules(
+ ivySettings: IvySettings,
+ ivyConfName: String,
+ md: DefaultModuleDescriptor): Unit = {
+ // Add scala exclusion rule
+ val scalaArtifacts = new ArtifactId(new ModuleId("*", "scala-library"), "*", "*", "*")
+ val scalaDependencyExcludeRule =
+ new DefaultExcludeRule(scalaArtifacts, ivySettings.getMatcher("glob"), null)
+ scalaDependencyExcludeRule.addConfiguration(ivyConfName)
+ md.addExcludeRule(scalaDependencyExcludeRule)
+
+    // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and
+    // the other spark-streaming utility components. The trailing underscore differentiates
+    // spark-streaming_2.1x from spark-streaming-kafka-assembly_2.1x.
+ val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_",
+ "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_")
+
+ components.foreach { comp =>
+ val sparkArtifacts =
+ new ArtifactId(new ModuleId("org.apache.spark", s"spark-$comp*"), "*", "*", "*")
+ val sparkDependencyExcludeRule =
+ new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
+ sparkDependencyExcludeRule.addConfiguration(ivyConfName)
+
+ md.addExcludeRule(sparkDependencyExcludeRule)
+ }
+ }
/** A nice function to use in tests as well. Values are dummy strings. */
private[spark] def getModuleDescriptor = DefaultModuleDescriptor.newDefaultInstance(
@@ -703,6 +796,9 @@ private[spark] object SparkSubmitUtils {
if (coordinates == null || coordinates.trim.isEmpty) {
""
} else {
+ val sysOut = System.out
+ // To prevent ivy from logging to system out
+ System.setOut(printStream)
val artifacts = extractMavenCoordinates(coordinates)
// Default configuration name for ivy
val ivyConfName = "default"
@@ -746,14 +842,9 @@ private[spark] object SparkSubmitUtils {
val md = getModuleDescriptor
md.setDefaultConf(ivyConfName)
- // Add an exclusion rule for Spark
- val sparkArtifacts = new ArtifactId(new ModuleId("org.apache.spark", "*"), "*", "*", "*")
- val sparkDependencyExcludeRule =
- new DefaultExcludeRule(sparkArtifacts, ivySettings.getMatcher("glob"), null)
- sparkDependencyExcludeRule.addConfiguration(ivyConfName)
-
- // Exclude any Spark dependencies, and add all supplied maven artifacts as dependencies
- md.addExcludeRule(sparkDependencyExcludeRule)
+ // Add exclusion rules for Spark and Scala Library
+ addExclusionRules(ivySettings, ivyConfName, md)
+ // add all supplied maven artifacts as dependencies
addDependenciesToIvy(md, artifacts, ivyConfName)
// resolve dependencies
@@ -765,7 +856,7 @@ private[spark] object SparkSubmitUtils {
ivy.retrieve(rr.getModuleDescriptor.getModuleRevisionId,
packagesDirectory.getAbsolutePath + File.separator + "[artifact](-[classifier]).[ext]",
retrieveOptions.setConfs(Array(ivyConfName)))
-
+ System.setOut(sysOut)
resolveDependencyPaths(rr.getArtifacts.toArray, packagesDirectory)
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index bd0ae26fd8210..82e66a374249c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -57,6 +57,7 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
var pyFiles: String = null
var action: SparkSubmitAction = null
val sparkProperties: HashMap[String, String] = new HashMap[String, String]()
+ var proxyUser: String = null
// Standalone cluster mode only
var supervise: Boolean = false
@@ -405,6 +406,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
}
parse(tail)
+ case ("--proxy-user") :: value :: tail =>
+ proxyUser = value
+ parse(tail)
+
case ("--help" | "-h") :: tail =>
printUsageAndExit(0)
@@ -412,6 +417,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
verbose = true
parse(tail)
+ case ("--version") :: tail =>
+ SparkSubmit.printVersionAndExit()
+
case EQ_SEPARATED_OPT(opt, value) :: tail =>
parse(opt :: value :: tail)
@@ -476,8 +484,11 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
|
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
|
+ | --proxy-user NAME User to impersonate when submitting the application.
+ |
| --help, -h Show this help message and exit
| --verbose, -v Print additional debug output
+        | --version Print the version of Spark and exit
|
| Spark standalone with cluster deploy mode only:
| --driver-cores NUM Cores for driver (Default: 1).
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 2eab9981845e8..311048cdaa324 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -17,8 +17,6 @@
package org.apache.spark.deploy
-import java.io.File
-
import scala.collection.JavaConversions._
import org.apache.spark.util.{RedirectThread, Utils}
@@ -164,6 +162,8 @@ private[spark] object SparkSubmitDriverBootstrapper {
}
}
val returnCode = process.waitFor()
+ stdoutThread.join()
+ stderrThread.join()
sys.exit(returnCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 868c63d30a202..3e3d6ff29faf0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -17,9 +17,13 @@
package org.apache.spark.deploy.history
-import java.io.{BufferedInputStream, FileNotFoundException, InputStream}
+import java.io.{IOException, BufferedInputStream, FileNotFoundException, InputStream}
+import java.util.concurrent.{Executors, TimeUnit}
import scala.collection.mutable
+import scala.concurrent.duration.Duration
+
+import com.google.common.util.concurrent.ThreadFactoryBuilder
import org.apache.hadoop.fs.{FileStatus, Path}
import org.apache.hadoop.fs.permission.AccessControlException
@@ -44,8 +48,15 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private val NOT_STARTED = ""
// Interval between each check for event log updates
- private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval",
- conf.getInt("spark.history.updateInterval", 10)) * 1000
+ private val UPDATE_INTERVAL_MS = conf.getOption("spark.history.fs.update.interval.seconds")
+ .orElse(conf.getOption("spark.history.fs.updateInterval"))
+ .orElse(conf.getOption("spark.history.updateInterval"))
+ .map(_.toInt)
+ .getOrElse(10) * 1000
+
+  // Interval at which the cleaner checks for event logs to delete
+ private val CLEAN_INTERVAL_MS = conf.getLong("spark.history.fs.cleaner.interval.seconds",
+ DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S) * 1000
private val logDir = conf.getOption("spark.history.fs.logDirectory")
.map { d => Utils.resolveURI(d).toString }
@@ -53,8 +64,11 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
- // A timestamp of when the disk was last accessed to check for log updates
- private var lastLogCheckTimeMs = -1L
+  // Used by the task that checks for event logs and the task that cleans them.
+  // The scheduled thread pool size must be one; otherwise the check task and the clean task
+  // could access `fs` and `applications` concurrently.
+ private val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder()
+ .setNameFormat("spark-history-task-%d").setDaemon(true).build())
// The modification time of the newest log detected during the last scan. This is used
// to ignore logs that are older during subsequent scans, to avoid processing data that
@@ -73,25 +87,13 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private[history] val APPLICATION_COMPLETE = "APPLICATION_COMPLETE"
/**
- * A background thread that periodically checks for event log updates on disk.
- *
- * If a log check is invoked manually in the middle of a period, this thread re-adjusts the
- * time at which it performs the next log check to maintain the same period as before.
- *
- * TODO: Add a mechanism to update manually.
+ * Return a runnable that performs the given operation on the event logs.
+ * This operation is expected to be executed periodically.
*/
- private val logCheckingThread = new Thread("LogCheckingThread") {
- override def run() = Utils.logUncaughtExceptions {
- while (true) {
- val now = getMonotonicTimeMs()
- if (now - lastLogCheckTimeMs > UPDATE_INTERVAL_MS) {
- Thread.sleep(UPDATE_INTERVAL_MS)
- } else {
- // If the user has manually checked for logs recently, wait until
- // UPDATE_INTERVAL_MS after the last check time
- Thread.sleep(lastLogCheckTimeMs + UPDATE_INTERVAL_MS - now)
- }
- checkForLogs()
+ private def getRunner(operateFun: () => Unit): Runnable = {
+ new Runnable() {
+ override def run() = Utils.logUncaughtExceptions {
+ operateFun()
}
}
}
@@ -113,12 +115,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
"Logging directory specified is not a directory: %s".format(logDir))
}
- checkForLogs()
-
// Disable the background thread during tests.
if (!conf.contains("spark.testing")) {
- logCheckingThread.setDaemon(true)
- logCheckingThread.start()
+ // A task that periodically checks for event log updates on disk.
+ pool.scheduleAtFixedRate(getRunner(checkForLogs), 0, UPDATE_INTERVAL_MS,
+ TimeUnit.MILLISECONDS)
+
+ if (conf.getBoolean("spark.history.fs.cleaner.enabled", false)) {
+ // A task that periodically cleans event logs on disk.
+ pool.scheduleAtFixedRate(getRunner(cleanLogs), 0, CLEAN_INTERVAL_MS,
+ TimeUnit.MILLISECONDS)
+ }
}
}
@@ -163,9 +170,6 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
* applications that haven't been updated since last time the logs were checked.
*/
private[history] def checkForLogs(): Unit = {
- lastLogCheckTimeMs = getMonotonicTimeMs()
- logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
-
try {
var newLastModifiedTime = lastModifiedTime
val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq)
@@ -230,6 +234,45 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
}
+ /**
+ * Delete event logs from the log directory according to the clean policy defined by the user.
+ */
+ private def cleanLogs(): Unit = {
+ try {
+ val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq)
+ .getOrElse(Seq[FileStatus]())
+ val maxAge = conf.getLong("spark.history.fs.cleaner.maxAge.seconds",
+ DEFAULT_SPARK_HISTORY_FS_MAXAGE_S) * 1000
+
+ val now = System.currentTimeMillis()
+ val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]()
+
+ applications.values.foreach { info =>
+ if (now - info.lastUpdated <= maxAge) {
+ appsToRetain += (info.id -> info)
+ }
+ }
+
+ applications = appsToRetain
+
+ // Scan all logs from the log directory.
+ // Only directories older than the specified max age will be deleted
+ statusList.foreach { dir =>
+ try {
+ if (now - dir.getModificationTime() > maxAge) {
+            // Recursive delete: if the path is a directory, the second argument must be true,
+            // otherwise delete() throws an exception.
+ fs.delete(dir.getPath, true)
+ }
+ } catch {
+ case t: IOException => logError(s"IOException in cleaning logs of $dir", t)
+ }
+ }
+ } catch {
+ case t: Exception => logError("Exception in cleaning logs", t)
+ }
+ }
+
/**
* Comparison function that defines the sort order for the application listing.
*
@@ -247,6 +290,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
*/
private def replay(eventLog: FileStatus, bus: ReplayListenerBus): FsApplicationHistoryInfo = {
val logPath = eventLog.getPath()
+ logInfo(s"Replaying log path: $logPath")
val (logInput, sparkVersion) =
if (isLegacyLogDirectory(eventLog)) {
openLegacyEventLog(logPath)
@@ -256,7 +300,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
try {
val appListener = new ApplicationEventListener
bus.addListener(appListener)
- bus.replay(logInput, sparkVersion)
+ bus.replay(logInput, sparkVersion, logPath.toString)
new FsApplicationHistoryInfo(
logPath.getName(),
appListener.appId.getOrElse(logPath.getName()),
@@ -335,9 +379,6 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
}
- /** Returns the system's mononotically increasing time. */
- private def getMonotonicTimeMs(): Long = System.nanoTime() / (1000 * 1000)
-
/**
* Return true when the application has completed.
*/
@@ -353,6 +394,12 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private object FsHistoryProvider {
val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
+
+ // One day
+ val DEFAULT_SPARK_HISTORY_FS_CLEANER_INTERVAL_S = Duration(1, TimeUnit.DAYS).toSeconds
+
+ // One week
+ val DEFAULT_SPARK_HISTORY_FS_MAXAGE_S = Duration(7, TimeUnit.DAYS).toSeconds
}
private class FsApplicationHistoryInfo(
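
The scheduling change is what makes cleanLogs safe to add: because the pool has exactly one thread, the check task and the clean task are serialized and can share `fs` and `applications` without extra locking. A self-contained sketch of that pattern with placeholder task bodies and intervals (it assumes Guava for the thread factory, which Spark already depends on):

    import java.util.concurrent.{Executors, TimeUnit}

    import com.google.common.util.concurrent.ThreadFactoryBuilder

    object SingleThreadScheduleDemo extends App {
      // One daemon thread: the periodic tasks below can never overlap.
      val pool = Executors.newScheduledThreadPool(1, new ThreadFactoryBuilder()
        .setNameFormat("demo-task-%d").setDaemon(true).build())

      def runner(name: String)(body: => Unit): Runnable = new Runnable {
        override def run(): Unit =
          try body catch { case t: Throwable => println(s"$name failed: $t") }
      }

      // Placeholder intervals; FsHistoryProvider uses UPDATE_INTERVAL_MS and CLEAN_INTERVAL_MS.
      pool.scheduleAtFixedRate(runner("check") { println("checking for new event logs") },
        0, 2, TimeUnit.SECONDS)
      pool.scheduleAtFixedRate(runner("clean") { println("cleaning expired event logs") },
        0, 5, TimeUnit.SECONDS)

      Thread.sleep(11000) // let a few iterations run; daemon threads die with the JVM
    }
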
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index e4e7bc2216014..26ebc75971c66 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -61,9 +61,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
// page, `...` will be displayed.
if (allApps.size > 0) {
val leftSideIndices =
- rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _)
+ rangeIndices(actualPage - plusOrMinus until actualPage, 1 < _, requestedIncomplete)
val rightSideIndices =
- rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount)
+ rangeIndices(actualPage + 1 to actualPage + plusOrMinus, _ < pageCount,
+ requestedIncomplete)
Showing {actualFirst + 1}-{last + 1} of {allApps.size}
@@ -122,8 +123,10 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
"Spark User",
"Last Updated")
- private def rangeIndices(range: Seq[Int], condition: Int => Boolean): Seq[Node] = {
- range.filter(condition).map(nextPage => {nextPage} )
+ private def rangeIndices(range: Seq[Int], condition: Int => Boolean, showIncomplete: Boolean):
+ Seq[Node] = {
+ range.filter(condition).map(nextPage =>
+ {nextPage} )
}
private def appRow(info: ApplicationHistoryInfo): Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index fa9bfe5426b6c..af483d560b33e 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -96,6 +96,10 @@ class HistoryServer(
}
}
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
initialize()
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
index ede0a9dbefb8d..a962dc4af2f6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala
@@ -90,9 +90,9 @@ private[spark] class ApplicationInfo(
}
}
- private val myMaxCores = desc.maxCores.getOrElse(defaultCores)
+ val requestedCores = desc.maxCores.getOrElse(defaultCores)
- def coresLeft: Int = myMaxCores - coresGranted
+ def coresLeft: Int = requestedCores - coresGranted
private var _retryCount = 0
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index 67e6c5d66af0e..f5b946329ae9b 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -21,7 +21,7 @@ private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
- val WAITING, RUNNING, FINISHED, FAILED, UNKNOWN = Value
+ val WAITING, RUNNING, FINISHED, FAILED, KILLED, UNKNOWN = Value
val MAX_NUM_RETRY = 10
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 53e453990f8c7..8cc6ec1e8192c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -761,7 +761,7 @@ private[spark] class Master(
val ui = SparkUI.createHistoryUI(new SparkConf, replayBus, new SecurityManager(conf),
appName + " (completed)", HistoryServer.UI_PATH_PREFIX + s"/${app.id}")
try {
- replayBus.replay(logInput, sparkVersion)
+ replayBus.replay(logInput, sparkVersion, eventLogFile)
} finally {
logInput.close()
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index 3aae2b95d7396..76fc40e17d9a8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -24,6 +24,7 @@ import scala.xml.Node
import akka.pattern.ask
import org.json4s.JValue
+import org.json4s.JsonAST.JNothing
import org.apache.spark.deploy.{ExecutorState, JsonProtocol}
import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
@@ -44,7 +45,11 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
- JsonProtocol.writeApplicationInfo(app)
+ if (app == null) {
+ JNothing
+ } else {
+ JsonProtocol.writeApplicationInfo(app)
+ }
}
/** Executor details for a particular application */
@@ -55,6 +60,10 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app
val app = state.activeApps.find(_.id == appId).getOrElse({
state.completedApps.find(_.id == appId).getOrElse(null)
})
+ if (app == null) {
+ val msg =
No running application with ID {appId}
+ return UIUtils.basicSparkPage(msg, "Not Found")
+ }
val executorHeaders = Seq("ExecutorID", "Worker", "Cores", "Memory", "State", "Logs")
val allExecutors = (app.executors.values ++ app.removedExecutors).toSet.toSeq
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
index b47a081053e77..c086cadca2c7d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala
@@ -26,8 +26,8 @@ import akka.pattern.ask
import org.json4s.JValue
import org.apache.spark.deploy.JsonProtocol
-import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
-import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo}
+import org.apache.spark.deploy.DeployMessages.{RequestKillDriver, MasterStateResponse, RequestMasterState}
+import org.apache.spark.deploy.master._
import org.apache.spark.ui.{WebUIPage, UIUtils}
import org.apache.spark.util.Utils
@@ -41,6 +41,31 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
JsonProtocol.writeMasterState(state)
}
+ def handleAppKillRequest(request: HttpServletRequest): Unit = {
+ handleKillRequest(request, id => {
+ parent.master.idToApp.get(id).foreach { app =>
+ parent.master.removeApplication(app, ApplicationState.KILLED)
+ }
+ })
+ }
+
+ def handleDriverKillRequest(request: HttpServletRequest): Unit = {
+ handleKillRequest(request, id => { master ! RequestKillDriver(id) })
+ }
+
+ private def handleKillRequest(request: HttpServletRequest, action: String => Unit): Unit = {
+ if (parent.killEnabled &&
+ parent.master.securityMgr.checkModifyPermissions(request.getRemoteUser)) {
+ val killFlag = Option(request.getParameter("terminate")).getOrElse("false").toBoolean
+ val id = Option(request.getParameter("id"))
+ if (id.isDefined && killFlag) {
+ action(id.get)
+ }
+
+ Thread.sleep(100)
+ }
+ }
+
/** Index view listing applications and executors */
def render(request: HttpServletRequest): Seq[Node] = {
val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
@@ -50,12 +75,16 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
val workers = state.workers.sortBy(_.id)
val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers)
- val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time",
- "User", "State", "Duration")
+ val activeAppHeaders = Seq("Application ID", "Name", "Cores in Use",
+ "Cores Requested", "Memory per Node", "Submitted Time", "User", "State", "Duration")
val activeApps = state.activeApps.sortBy(_.startTime).reverse
- val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps)
+ val activeAppsTable = UIUtils.listingTable(activeAppHeaders, activeAppRow, activeApps)
+
+ val completedAppHeaders = Seq("Application ID", "Name", "Cores Requested", "Memory per Node",
+ "Submitted Time", "User", "State", "Duration")
val completedApps = state.completedApps.sortBy(_.endTime).reverse
- val completedAppsTable = UIUtils.listingTable(appHeaders, appRow, completedApps)
+ val completedAppsTable = UIUtils.listingTable(completedAppHeaders, completeAppRow,
+ completedApps)
val driverHeaders = Seq("Submission ID", "Submitted Time", "Worker", "State", "Cores",
"Memory", "Main Class")
@@ -162,16 +191,34 @@ private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") {
}
- private def appRow(app: ApplicationInfo): Seq[Node] = {
+ private def appRow(app: ApplicationInfo, active: Boolean): Seq[Node] = {
+ val killLink = if (parent.killEnabled &&
+ (app.state == ApplicationState.RUNNING || app.state == ApplicationState.WAITING)) {
+ val killLinkUri = s"app/kill?id=${app.id}&terminate=true"
+ val confirm = "return window.confirm(" +
+ s"'Are you sure you want to kill application ${app.id} ?');"
+
+ (kill)
+
+ }
+
{driver.driverDesc.cores.toString}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 3a42f8b157977..dd19e4947db1e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -17,8 +17,10 @@
package org.apache.spark.executor
+import java.net.URL
import java.nio.ByteBuffer
+import scala.collection.mutable
import scala.concurrent.Await
import akka.actor.{Actor, ActorSelection, Props}
@@ -38,6 +40,7 @@ private[spark] class CoarseGrainedExecutorBackend(
executorId: String,
hostPort: String,
cores: Int,
+ userClassPath: Seq[URL],
env: SparkEnv)
extends Actor with ActorLogReceive with ExecutorBackend with Logging {
@@ -63,7 +66,7 @@ private[spark] class CoarseGrainedExecutorBackend(
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
val (hostname, _) = Utils.parseHostPort(hostPort)
- executor = new Executor(executorId, hostname, env, isLocal = false)
+ executor = new Executor(executorId, hostname, env, userClassPath, isLocal = false)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -117,7 +120,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
hostname: String,
cores: Int,
appId: String,
- workerUrl: Option[String]) {
+ workerUrl: Option[String],
+ userClassPath: Seq[URL]) {
SignalLogger.register(log)
@@ -162,7 +166,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
val sparkHostPort = hostname + ":" + boundPort
env.actorSystem.actorOf(
Props(classOf[CoarseGrainedExecutorBackend],
- driverUrl, executorId, sparkHostPort, cores, env),
+ driverUrl, executorId, sparkHostPort, cores, userClassPath, env),
name = "Executor")
workerUrl.foreach { url =>
env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher")
@@ -172,20 +176,69 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
}
def main(args: Array[String]) {
- args.length match {
- case x if x < 5 =>
- System.err.println(
+ var driverUrl: String = null
+ var executorId: String = null
+ var hostname: String = null
+ var cores: Int = 0
+ var appId: String = null
+ var workerUrl: Option[String] = None
+ val userClassPath = new mutable.ListBuffer[URL]()
+
+ var argv = args.toList
+ while (!argv.isEmpty) {
+ argv match {
+ case ("--driver-url") :: value :: tail =>
+ driverUrl = value
+ argv = tail
+ case ("--executor-id") :: value :: tail =>
+ executorId = value
+ argv = tail
+ case ("--hostname") :: value :: tail =>
+ hostname = value
+ argv = tail
+ case ("--cores") :: value :: tail =>
+ cores = value.toInt
+ argv = tail
+ case ("--app-id") :: value :: tail =>
+ appId = value
+ argv = tail
+ case ("--worker-url") :: value :: tail =>
// Worker url is used in spark standalone mode to enforce fate-sharing with worker
- "Usage: CoarseGrainedExecutorBackend " +
- " [] ")
- System.exit(1)
+ workerUrl = Some(value)
+ argv = tail
+ case ("--user-class-path") :: value :: tail =>
+ userClassPath += new URL(value)
+ argv = tail
+ case Nil =>
+ case tail =>
+ System.err.println(s"Unrecognized options: ${tail.mkString(" ")}")
+ printUsageAndExit()
+ }
+ }
- // NB: These arguments are provided by SparkDeploySchedulerBackend (for standalone mode)
- // and CoarseMesosSchedulerBackend (for mesos mode).
- case 5 =>
- run(args(0), args(1), args(2), args(3).toInt, args(4), None)
- case x if x > 5 =>
- run(args(0), args(1), args(2), args(3).toInt, args(4), Some(args(5)))
+ if (driverUrl == null || executorId == null || hostname == null || cores <= 0 ||
+ appId == null) {
+ printUsageAndExit()
}
+
+ run(driverUrl, executorId, hostname, cores, appId, workerUrl, userClassPath)
}
+
+ private def printUsageAndExit() = {
+ System.err.println(
+ """
+      | Usage: CoarseGrainedExecutorBackend [options]
+ |
+ | Options are:
+ | --driver-url
+ | --executor-id
+ | --hostname
+ | --cores
+ | --app-id
+ | --worker-url
+ | --user-class-path
+ |""".stripMargin)
+ System.exit(1)
+ }
+
}
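
The rewritten main() swaps positional arguments for the usual Scala idiom of pattern matching over a List of arguments: peel off "--flag value" pairs until the list is exhausted, then validate. A stripped-down hypothetical parser showing only that loop:

    object FlagParserDemo extends App {
      var host: String = null
      var cores: Int = 0

      var argv = args.toList
      while (argv.nonEmpty) {
        argv match {
          case "--hostname" :: value :: tail => host = value; argv = tail
          case "--cores" :: value :: tail => cores = value.toInt; argv = tail
          case unknown =>
            System.err.println(s"Unrecognized options: ${unknown.mkString(" ")}")
            sys.exit(1)
        }
      }

      require(host != null && cores > 0, "both --hostname and --cores are required")
      println(s"host=$host cores=$cores")
    }
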
diff --git a/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
new file mode 100644
index 0000000000000..f7604a321f007
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.executor
+
+import org.apache.spark.{TaskCommitDenied, TaskEndReason}
+
+/**
+ * Exception thrown when a task attempts to commit output to HDFS but is denied by the driver.
+ */
+class CommitDeniedException(
+ msg: String,
+ jobID: Int,
+ splitID: Int,
+ attemptID: Int)
+ extends Exception(msg) {
+
+ def toTaskEndReason: TaskEndReason = new TaskCommitDenied(jobID, splitID, attemptID)
+
+}
+
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5141483d1e745..b684fb704956b 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -19,6 +19,7 @@ package org.apache.spark.executor
import java.io.File
import java.lang.management.ManagementFactory
+import java.net.URL
import java.nio.ByteBuffer
import java.util.concurrent._
@@ -33,7 +34,8 @@ import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.scheduler._
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{StorageLevel, TaskResultBlockId}
-import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
+import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader,
+ SparkUncaughtExceptionHandler, AkkaUtils, Utils}
/**
* Spark executor used with Mesos, YARN, and the standalone scheduler.
@@ -43,6 +45,7 @@ private[spark] class Executor(
executorId: String,
executorHostname: String,
env: SparkEnv,
+ userClassPath: Seq[URL] = Nil,
isLocal: Boolean = false)
extends Logging
{
@@ -250,6 +253,11 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
}
+ case cDE: CommitDeniedException => {
+ val reason = cDE.toTaskEndReason
+ execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
+ }
+
case t: Throwable => {
// Attempt to exit cleanly by informing the driver of our failure.
// If anything goes wrong (or this was a fatal exception), we will delegate to
@@ -288,17 +296,23 @@ private[spark] class Executor(
* created by the interpreter to the search path
*/
private def createClassLoader(): MutableURLClassLoader = {
+ // Bootstrap the list of jars with the user class path.
+ val now = System.currentTimeMillis()
+ userClassPath.foreach { url =>
+ currentJars(url.getPath().split("/").last) = now
+ }
+
val currentLoader = Utils.getContextOrSparkClassLoader
// For each of the jars in the jarSet, add them to the class loader.
// We assume each of the files has already been fetched.
- val urls = currentJars.keySet.map { uri =>
+ val urls = userClassPath.toArray ++ currentJars.keySet.map { uri =>
new File(uri.split("/").last).toURI.toURL
- }.toArray
- val userClassPathFirst = conf.getBoolean("spark.files.userClassPathFirst", false)
- userClassPathFirst match {
- case true => new ChildExecutorURLClassLoader(urls, currentLoader)
- case false => new ExecutorURLClassLoader(urls, currentLoader)
+ }
+ if (conf.getBoolean("spark.executor.userClassPathFirst", false)) {
+ new ChildFirstURLClassLoader(urls, currentLoader)
+ } else {
+ new MutableURLClassLoader(urls, currentLoader)
}
}
@@ -311,7 +325,7 @@ private[spark] class Executor(
if (classUri != null) {
logInfo("Using REPL class URI: " + classUri)
val userClassPathFirst: java.lang.Boolean =
- conf.getBoolean("spark.files.userClassPathFirst", false)
+ conf.getBoolean("spark.executor.userClassPathFirst", false)
try {
val klass = Class.forName("org.apache.spark.repl.ExecutorClassLoader")
.asInstanceOf[Class[_ <: ClassLoader]]
@@ -344,18 +358,23 @@ private[spark] class Executor(
env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
currentFiles(name) = timestamp
}
- for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) {
- logInfo("Fetching " + name + " with timestamp " + timestamp)
- // Fetch file with useCache mode, close cache for local mode.
- Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
- env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
- currentJars(name) = timestamp
- // Add it to our class loader
+ for ((name, timestamp) <- newJars) {
val localName = name.split("/").last
- val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
- if (!urlClassLoader.getURLs.contains(url)) {
- logInfo("Adding " + url + " to class loader")
- urlClassLoader.addURL(url)
+ val currentTimeStamp = currentJars.get(name)
+ .orElse(currentJars.get(localName))
+ .getOrElse(-1L)
+ if (currentTimeStamp < timestamp) {
+ logInfo("Fetching " + name + " with timestamp " + timestamp)
+ // Fetch file with useCache mode, close cache for local mode.
+ Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf,
+ env.securityManager, hadoopConf, timestamp, useCache = !isLocal)
+ currentJars(name) = timestamp
+ // Add it to our class loader
+ val url = new File(SparkFiles.getRootDirectory, localName).toURI.toURL
+ if (!urlClassLoader.getURLs.contains(url)) {
+ logInfo("Adding " + url + " to class loader")
+ urlClassLoader.addURL(url)
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
deleted file mode 100644
index 8011e75944aac..0000000000000
--- a/core/src/main/scala/org/apache/spark/executor/ExecutorURLClassLoader.scala
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.executor
-
-import java.net.{URLClassLoader, URL}
-
-import org.apache.spark.util.ParentClassLoader
-
-/**
- * The addURL method in URLClassLoader is protected. We subclass it to make this accessible.
- * We also make changes so user classes can come before the default classes.
- */
-
-private[spark] trait MutableURLClassLoader extends ClassLoader {
- def addURL(url: URL)
- def getURLs: Array[URL]
-}
-
-private[spark] class ChildExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
- extends MutableURLClassLoader {
-
- private object userClassLoader extends URLClassLoader(urls, null){
- override def addURL(url: URL) {
- super.addURL(url)
- }
- override def findClass(name: String): Class[_] = {
- val loaded = super.findLoadedClass(name)
- if (loaded != null) {
- return loaded
- }
- try {
- super.findClass(name)
- } catch {
- case e: ClassNotFoundException => {
- parentClassLoader.loadClass(name)
- }
- }
- }
- }
-
- private val parentClassLoader = new ParentClassLoader(parent)
-
- override def findClass(name: String): Class[_] = {
- try {
- userClassLoader.findClass(name)
- } catch {
- case e: ClassNotFoundException => {
- parentClassLoader.loadClass(name)
- }
- }
- }
-
- def addURL(url: URL) {
- userClassLoader.addURL(url)
- }
-
- def getURLs() = {
- userClassLoader.getURLs()
- }
-}
-
-private[spark] class ExecutorURLClassLoader(urls: Array[URL], parent: ClassLoader)
- extends URLClassLoader(urls, parent) with MutableURLClassLoader {
-
- override def addURL(url: URL) {
- super.addURL(url)
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index d05659193b334..07b152651dedf 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -177,8 +177,8 @@ class TaskMetrics extends Serializable {
* Once https://issues.apache.org/jira/browse/SPARK-5225 is addressed,
* we can store all the different inputMetrics (one per readMethod).
*/
- private[spark] def getInputMetricsForReadMethod(readMethod: DataReadMethod):
- InputMetrics =synchronized {
+ private[spark] def getInputMetricsForReadMethod(
+ readMethod: DataReadMethod): InputMetrics = synchronized {
_inputMetrics match {
case None =>
val metrics = new InputMetrics(readMethod)
@@ -195,15 +195,18 @@ class TaskMetrics extends Serializable {
* Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
*/
private[spark] def updateShuffleReadMetrics(): Unit = synchronized {
- val merged = new ShuffleReadMetrics()
- for (depMetrics <- depsShuffleReadMetrics) {
- merged.incFetchWaitTime(depMetrics.fetchWaitTime)
- merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
- merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
- merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
- merged.incRecordsRead(depMetrics.recordsRead)
+ if (!depsShuffleReadMetrics.isEmpty) {
+ val merged = new ShuffleReadMetrics()
+ for (depMetrics <- depsShuffleReadMetrics) {
+ merged.incFetchWaitTime(depMetrics.fetchWaitTime)
+ merged.incLocalBlocksFetched(depMetrics.localBlocksFetched)
+ merged.incRemoteBlocksFetched(depMetrics.remoteBlocksFetched)
+ merged.incRemoteBytesRead(depMetrics.remoteBytesRead)
+ merged.incLocalBytesRead(depMetrics.localBytesRead)
+ merged.incRecordsRead(depMetrics.recordsRead)
+ }
+ _shuffleReadMetrics = Some(merged)
}
- _shuffleReadMetrics = Some(merged)
}
private[spark] def updateInputMetrics(): Unit = synchronized {
@@ -341,6 +344,18 @@ class ShuffleReadMetrics extends Serializable {
private[spark] def incRemoteBytesRead(value: Long) = _remoteBytesRead += value
private[spark] def decRemoteBytesRead(value: Long) = _remoteBytesRead -= value
+ /**
+ * Shuffle data that was read from the local disk (as opposed to from a remote executor).
+ */
+ private var _localBytesRead: Long = _
+ def localBytesRead = _localBytesRead
+ private[spark] def incLocalBytesRead(value: Long) = _localBytesRead += value
+
+ /**
+ * Total bytes fetched in the shuffle by this task (both remote and local).
+ */
+ def totalBytesRead = _remoteBytesRead + _localBytesRead
+
/**
* Number of blocks fetched in this shuffle by this task (remote or local)
*/
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
index 1b7a5d1f1980a..8edf493780687 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala
@@ -28,12 +28,12 @@ import org.apache.spark.util.Utils
private[spark] class MetricsConfig(val configFile: Option[String]) extends Logging {
- val DEFAULT_PREFIX = "*"
- val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
- val METRICS_CONF = "metrics.properties"
+ private val DEFAULT_PREFIX = "*"
+ private val INSTANCE_REGEX = "^(\\*|[a-zA-Z]+)\\.(.+)".r
+ private val DEFAULT_METRICS_CONF_FILENAME = "metrics.properties"
- val properties = new Properties()
- var propertyCategories: mutable.HashMap[String, Properties] = null
+ private[metrics] val properties = new Properties()
+ private[metrics] var propertyCategories: mutable.HashMap[String, Properties] = null
private def setDefaultProperties(prop: Properties) {
prop.setProperty("*.sink.servlet.class", "org.apache.spark.metrics.sink.MetricsServlet")
@@ -47,20 +47,22 @@ private[spark] class MetricsConfig(val configFile: Option[String]) extends Loggi
setDefaultProperties(properties)
// If spark.metrics.conf is not set, try to get file in class path
- var is: InputStream = null
- try {
- is = configFile match {
- case Some(f) => new FileInputStream(f)
- case None => Utils.getSparkClassLoader.getResourceAsStream(METRICS_CONF)
+ val isOpt: Option[InputStream] = configFile.map(new FileInputStream(_)).orElse {
+ try {
+ Option(Utils.getSparkClassLoader.getResourceAsStream(DEFAULT_METRICS_CONF_FILENAME))
+ } catch {
+ case e: Exception =>
+ logError("Error loading default configuration file", e)
+ None
}
+ }
- if (is != null) {
+ isOpt.foreach { is =>
+ try {
properties.load(is)
+ } finally {
+ is.close()
}
- } catch {
- case e: Exception => logError("Error loading configure file", e)
- } finally {
- if (is != null) is.close()
}
propertyCategories = subProperties(properties, INSTANCE_REGEX)
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 83e8eb71260eb..345db36630fd5 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -191,7 +191,10 @@ private[spark] class MetricsSystem private (
sinks += sink.asInstanceOf[Sink]
}
} catch {
- case e: Exception => logError("Sink class " + classPath + " cannot be instantialized", e)
+ case e: Exception => {
+        logError("Sink class " + classPath + " cannot be instantiated")
+ throw e
+ }
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
new file mode 100644
index 0000000000000..e8b3074e8f1a6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+
+import com.codahale.metrics.{Slf4jReporter, MetricRegistry}
+
+import org.apache.spark.SecurityManager
+import org.apache.spark.metrics.MetricsSystem
+
+private[spark] class Slf4jSink(
+ val property: Properties,
+ val registry: MetricRegistry,
+ securityMgr: SecurityManager)
+ extends Sink {
+ val SLF4J_DEFAULT_PERIOD = 10
+ val SLF4J_DEFAULT_UNIT = "SECONDS"
+
+ val SLF4J_KEY_PERIOD = "period"
+ val SLF4J_KEY_UNIT = "unit"
+
+ val pollPeriod = Option(property.getProperty(SLF4J_KEY_PERIOD)) match {
+ case Some(s) => s.toInt
+ case None => SLF4J_DEFAULT_PERIOD
+ }
+
+ val pollUnit: TimeUnit = Option(property.getProperty(SLF4J_KEY_UNIT)) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(SLF4J_DEFAULT_UNIT)
+ }
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val reporter: Slf4jReporter = Slf4jReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .build()
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+
+ override def report() {
+ reporter.report()
+ }
+}
+
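
The new sink is intended to be enabled through the metrics configuration rather than constructed directly. A hedged sketch of metrics.properties entries that would exercise the period and unit keys read above (the sink name "slf4j" is an arbitrary illustration; the instance.sink.<name>.<key> layout matches the MetricsConfig parsing touched earlier in this patch):

    # Illustrative metrics.properties entries
    *.sink.slf4j.class=org.apache.spark.metrics.sink.Slf4jSink
    *.sink.slf4j.period=10
    *.sink.slf4j.unit=seconds
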
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index e66f83bb34e30..03afc289736bb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -213,7 +213,14 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
} else {
basicBucketFunction _
}
- self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ if (self.partitions.length == 0) {
+ new Array[Long](buckets.length - 1)
+ } else {
+      // reduce() requires a non-empty RDD. That is fine here because mapPartitions turns
+      // empty partitions into non-empty iterators of counts. It cannot cope with an RDD that
+      // has no partitions at all, though, which is why that case is handled separately above.
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
}
}
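
The zero-partition guard matters because reduce() fails on an empty RDD even when the caller supplies buckets. A short sketch of the case being fixed (assumes an active SparkContext named sc):

    // Assumes an active SparkContext `sc`.
    val empty = sc.emptyRDD[Double]                     // an RDD with zero partitions
    val counts = empty.histogram(Array(0.0, 5.0, 10.0))
    // With this change the call returns one zero count per bucket instead of
    // failing inside reduce() with an empty-collection error.
    assert(counts.sameElements(Array(0L, 0L)))
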
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 642a12c1edf6c..e2267861e79df 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -62,11 +62,11 @@ class JdbcRDD[T: ClassTag](
override def getPartitions: Array[Partition] = {
// bounds are inclusive, hence the + 1 here and - 1 on end
- val length = 1 + upperBound - lowerBound
+ val length = BigInt(1) + upperBound - lowerBound
(0 until numPartitions).map(i => {
- val start = lowerBound + ((i * length) / numPartitions).toLong
- val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
- new JdbcPartition(i, start, end)
+ val start = lowerBound + ((i * length) / numPartitions)
+ val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
+ new JdbcPartition(i, start.toLong, end.toLong)
}).toArray
}
@@ -99,21 +99,21 @@ class JdbcRDD[T: ClassTag](
override def close() {
try {
- if (null != rs && ! rs.isClosed()) {
+ if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
- if (null != stmt && ! stmt.isClosed()) {
+ if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
- if (null != conn && ! conn.isClosed()) {
+ if (null != conn) {
conn.close()
}
logInfo("closed connection")
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 144f679a59460..6fdfdb734d1b8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -75,4 +75,27 @@ class OrderedRDDFunctions[K : Ordering : ClassTag,
new ShuffledRDD[K, V, V](self, partitioner).setKeyOrdering(ordering)
}
+ /**
+   * Returns an RDD containing only the elements in the inclusive range `lower` to `upper`.
+ * If the RDD has been partitioned using a `RangePartitioner`, then this operation can be
+ * performed efficiently by only scanning the partitions that might contain matching elements.
+ * Otherwise, a standard `filter` is applied to all partitions.
+ */
+ def filterByRange(lower: K, upper: K): RDD[P] = {
+
+ def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper)
+
+ val rddToFilter: RDD[P] = self.partitioner match {
+ case Some(rp: RangePartitioner[K, V]) => {
+        val partitionIndices = (rp.getPartition(lower), rp.getPartition(upper)) match {
+          case (l, u) => Math.min(l, u) to Math.max(l, u)
+        }
+        PartitionPruningRDD.create(self, partitionIndices.contains)
+ }
+ case _ =>
+ self
+ }
+ rddToFilter.filter { case (k, v) => inRange(k) }
+ }
+
}
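
A brief usage sketch of the new filterByRange (assumes an active SparkContext named sc). Because sortByKey() installs a RangePartitioner, the call below only scans the partitions that can hold keys in [100, 200] rather than filtering every partition:

    // Assumes an active SparkContext `sc`.
    val sorted = sc.parallelize((1 to 1000).map(i => (i, i.toString))).sortByKey()
    val slice = sorted.filterByRange(100, 200).collect()
    assert(slice.length == 101)   // keys 100 through 200, inclusive
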
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 955b42c3baaa1..6b4f097ea9ae5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -993,6 +993,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context)
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
+ require(writer != null, "Unable to obtain RecordWriter")
var recordsWritten = 0L
try {
while (iter.hasNext) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index fe55a5124f3b6..cf0433010aa03 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -462,7 +462,13 @@ abstract class RDD[T: ClassTag](
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
- def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+ def union(other: RDD[T]): RDD[T] = {
+ if (partitioner.isDefined && other.partitioner == partitioner) {
+ new PartitionerAwareUnionRDD(sc, Array(this, other))
+ } else {
+ new UnionRDD(sc, Array(this, other))
+ }
+ }
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
@@ -1140,6 +1146,9 @@ abstract class RDD[T: ClassTag](
* Take the first num elements of the RDD. It works by first scanning one partition, and use the
* results from that partition to estimate the number of additional partitions needed to satisfy
* the limit.
+ *
+ * @note due to complications in the internal implementation, this method will raise
+ * an exception if called on an RDD of `Nothing` or `Null`.
*/
def take(num: Int): Array[T] = {
if (num == 0) {
@@ -1252,6 +1261,10 @@ abstract class RDD[T: ClassTag](
def min()(implicit ord: Ordering[T]): T = this.reduce(ord.min)
/**
+ * @note due to complications in the internal implementation, this method will raise an
+   * exception if called on an RDD of `Nothing` or `Null`. This may come up in practice
+ * because, for example, the type of `parallelize(Seq())` is `RDD[Nothing]`.
+ * (`parallelize(Seq())` should be avoided anyway in favor of `parallelize(Seq[T]())`.)
* @return true if and only if the RDD contains no elements at all. Note that an RDD
* may be empty even when it has at least 1 partition.
*/
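
Tying the union change earlier in this file to user code: when both inputs share the same partitioner, the result now keeps it, so a following key-based operation can avoid an extra shuffle. A hedged sketch (assumes an active SparkContext named sc):

    import org.apache.spark.HashPartitioner

    // Assumes an active SparkContext `sc`.
    val part = new HashPartitioner(4)
    val a = sc.parallelize(Seq((1, "a"), (2, "b"))).partitionBy(part)
    val b = sc.parallelize(Seq((1, "c"), (3, "d"))).partitionBy(part)

    val u = a.union(b)
    // The union is now a PartitionerAwareUnionRDD and keeps the partitioner, so
    // e.g. u.reduceByKey(_ + _) does not need another shuffle.
    assert(u.partitioner == Some(part))
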
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 1cfe98673773a..bc84e2351ad74 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -38,7 +38,7 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
-import org.apache.spark.util.{CallSite, EventLoop, SystemClock, Clock, Utils}
+import org.apache.spark.util._
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
/**
@@ -63,7 +63,7 @@ class DAGScheduler(
mapOutputTracker: MapOutputTrackerMaster,
blockManagerMaster: BlockManagerMaster,
env: SparkEnv,
- clock: Clock = SystemClock)
+ clock: Clock = new SystemClock())
extends Logging {
def this(sc: SparkContext, taskScheduler: TaskScheduler) = {
@@ -98,7 +98,13 @@ class DAGScheduler(
private[scheduler] val activeJobs = new HashSet[ActiveJob]
- // Contains the locations that each RDD's partitions are cached on
+ /**
+ * Contains the locations that each RDD's partitions are cached on. This map's keys are RDD ids
+ * and its values are arrays indexed by partition numbers. Each array value is the set of
+ * locations where that RDD partition is cached.
+ *
+ * All accesses to this map should be guarded by synchronizing on it (see SPARK-4454).
+ */
private val cacheLocs = new HashMap[Int, Array[Seq[TaskLocation]]]
// For tracking failed nodes, we use the MapOutputTracker's epoch number, which is sent with
@@ -126,6 +132,8 @@ class DAGScheduler(
private[scheduler] val eventProcessLoop = new DAGSchedulerEventProcessLoop(this)
taskScheduler.setDAGScheduler(this)
+ private val outputCommitCoordinator = env.outputCommitCoordinator
+
// Called by TaskScheduler to report task's starting.
def taskStarted(task: Task[_], taskInfo: TaskInfo) {
eventProcessLoop.post(BeginEvent(task, taskInfo))
@@ -181,7 +189,8 @@ class DAGScheduler(
eventProcessLoop.post(TaskSetFailed(taskSet, reason))
}
- private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = {
+ private def getCacheLocs(rdd: RDD[_]): Array[Seq[TaskLocation]] = cacheLocs.synchronized {
+ // Note: this doesn't use `getOrElse()` because this method is called O(num tasks) times
if (!cacheLocs.contains(rdd.id)) {
val blockIds = rdd.partitions.indices.map(index => RDDBlockId(rdd.id, index)).toArray[BlockId]
val locs = BlockManager.blockIdsToBlockManagers(blockIds, env, blockManagerMaster)
@@ -192,7 +201,7 @@ class DAGScheduler(
cacheLocs(rdd.id)
}
- private def clearCacheLocs() {
+ private def clearCacheLocs(): Unit = cacheLocs.synchronized {
cacheLocs.clear()
}
@@ -648,7 +657,7 @@ class DAGScheduler(
// completion events or stage abort
stageIdToStage -= s.id
jobIdToStageIds -= job.jobId
- listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), jobResult))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), jobResult))
}
}
@@ -697,7 +706,7 @@ class DAGScheduler(
stage.latestInfo.stageFailed(stageFailedMessage)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
}
- listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
}
@@ -736,7 +745,7 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
val shouldRunLocally =
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
- val jobSubmissionTime = clock.getTime()
+ val jobSubmissionTime = clock.getTimeMillis()
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
listenerBus.post(
@@ -808,6 +817,7 @@ class DAGScheduler(
// will be posted, which should always come after a corresponding SparkListenerStageSubmitted
// event.
stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+ outputCommitCoordinator.stageStart(stage.id)
listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
// TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
@@ -861,10 +871,11 @@ class DAGScheduler(
logDebug("New pending tasks: " + stage.pendingTasks)
taskScheduler.submitTasks(
new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
- stage.latestInfo.submissionTime = Some(clock.getTime())
+ stage.latestInfo.submissionTime = Some(clock.getTimeMillis())
} else {
// Because we posted SparkListenerStageSubmitted earlier, we should post
// SparkListenerStageCompleted here in case there are no tasks to run.
+ outputCommitCoordinator.stageEnd(stage.id)
listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
logDebug("Stage " + stage + " is actually done; %b %d %d".format(
stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
@@ -879,8 +890,16 @@ class DAGScheduler(
if (event.accumUpdates != null) {
try {
Accumulators.add(event.accumUpdates)
+
event.accumUpdates.foreach { case (id, partialValue) =>
- val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
+ // In this instance, although the reference in Accumulators.originals is a WeakRef,
+ // it's guaranteed to exist since the event.accumUpdates Map exists
+
+ val acc = Accumulators.originals(id).get match {
+ case Some(accum) => accum.asInstanceOf[Accumulable[Any, Any]]
+ case None => throw new NullPointerException("Non-existent reference to Accumulator")
+ }
+
// To avoid UI cruft, ignore cases where value wasn't updated
if (acc.name.isDefined && partialValue != acc.zero) {
val name = acc.name.get
@@ -909,6 +928,9 @@ class DAGScheduler(
val stageId = task.stageId
val taskType = Utils.getFormattedClassName(task)
+ outputCommitCoordinator.taskCompleted(stageId, task.partitionId,
+ event.taskInfo.attempt, event.reason)
+
// The success case is dealt with separately below, since we need to compute accumulator
// updates before posting.
if (event.reason != Success) {
@@ -921,16 +943,17 @@ class DAGScheduler(
// Skip all the actions if the stage has been cancelled.
return
}
+
val stage = stageIdToStage(task.stageId)
def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None) = {
val serviceTime = stage.latestInfo.submissionTime match {
- case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
+ case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
case _ => "Unknown"
}
if (errorMessage.isEmpty) {
logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
- stage.latestInfo.completionTime = Some(clock.getTime())
+ stage.latestInfo.completionTime = Some(clock.getTimeMillis())
} else {
stage.latestInfo.stageFailed(errorMessage.get)
logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
@@ -956,7 +979,7 @@ class DAGScheduler(
markStageAsFinished(stage)
cleanupStateForJobAndIndependentStages(job)
listenerBus.post(
- SparkListenerJobEnd(job.jobId, clock.getTime(), JobSucceeded))
+ SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobSucceeded))
}
// taskSucceeded runs some user code that might throw an exception. Make sure
@@ -1073,6 +1096,9 @@ class DAGScheduler(
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
+ case commitDenied: TaskCommitDenied =>
+ // Do nothing here, left up to the TaskScheduler to decide how to handle denied commits
+
case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
@@ -1169,7 +1195,7 @@ class DAGScheduler(
}
val dependentJobs: Seq[ActiveJob] =
activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
- failedStage.latestInfo.completionTime = Some(clock.getTime())
+ failedStage.latestInfo.completionTime = Some(clock.getTimeMillis())
for (job <- dependentJobs) {
failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
}
@@ -1224,7 +1250,7 @@ class DAGScheduler(
if (ableToCancelStages) {
job.listener.jobFailed(error)
cleanupStateForJobAndIndependentStages(job)
- listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTime(), JobFailed(error)))
+ listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
}
@@ -1265,17 +1291,26 @@ class DAGScheduler(
}
/**
- * Synchronized method that might be called from other threads.
+ * Gets the locality information associated with a partition of a particular RDD.
+ *
+ * This method is thread-safe and is called from both DAGScheduler and SparkContext.
+ *
* @param rdd whose partitions are to be looked at
* @param partition to lookup locality information for
* @return list of machines that are preferred by the partition
*/
private[spark]
- def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = synchronized {
+ def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
getPreferredLocsInternal(rdd, partition, new HashSet)
}
- /** Recursive implementation for getPreferredLocs. */
+ /**
+ * Recursive implementation for getPreferredLocs.
+ *
+ * This method is thread-safe because it only accesses DAGScheduler state through thread-safe
+ * methods (getCacheLocs()); please be careful when modifying this method, because any new
+ * DAGScheduler state accessed by it may require additional synchronization.
+ */
private def getPreferredLocsInternal(
rdd: RDD[_],
partition: Int,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 3bb54855bae44..8aa528ac573d0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -169,7 +169,8 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched +
" BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched +
" REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime +
- " REMOTE_BYTES_READ=" + metrics.remoteBytesRead
+ " REMOTE_BYTES_READ=" + metrics.remoteBytesRead +
+ " LOCAL_BYTES_READ=" + metrics.localBytesRead
case None => ""
}
val writeMetrics = taskMetrics.shuffleWriteMetrics match {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
new file mode 100644
index 0000000000000..759df023a6dcf
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import scala.collection.mutable
+
+import akka.actor.{ActorRef, Actor}
+
+import org.apache.spark._
+import org.apache.spark.util.{AkkaUtils, ActorLogReceive}
+
+private sealed trait OutputCommitCoordinationMessage extends Serializable
+
+private case object StopCoordinator extends OutputCommitCoordinationMessage
+private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttempt: Long)
+
+/**
+ * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
+ * policy.
+ *
+ * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is
+ * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit
+ * output will be forwarded to the driver's OutputCommitCoordinator.
+ *
+ * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests)
+ * for an extensive design discussion.
+ */
+private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging {
+
+ // Initialized by SparkEnv
+ var coordinatorActor: Option[ActorRef] = None
+ private val timeout = AkkaUtils.askTimeout(conf)
+ private val maxAttempts = AkkaUtils.numRetries(conf)
+ private val retryInterval = AkkaUtils.retryWaitMs(conf)
+
+ private type StageId = Int
+ private type PartitionId = Long
+ private type TaskAttemptId = Long
+
+ /**
+   * Map from active stage's id => partition id => task attempt with exclusive lock on committing
+ * output for that partition.
+ *
+   * Entries are added to the top-level map when stages start and are removed when they finish
+ * (either successfully or unsuccessfully).
+ *
+ * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
+ */
+ private val authorizedCommittersByStage: CommittersByStageMap = mutable.Map()
+ private type CommittersByStageMap = mutable.Map[StageId, mutable.Map[PartitionId, TaskAttemptId]]
+
+ /**
+ * Called by tasks to ask whether they can commit their output to HDFS.
+ *
+ * If a task attempt has been authorized to commit, then all other attempts to commit the same
+ * task will be denied. If the authorized task attempt fails (e.g. due to its executor being
+ * lost), then a subsequent task attempt may be authorized to commit its output.
+ *
+ * @param stage the stage number
+ * @param partition the partition number
+ * @param attempt a unique identifier for this task attempt
+ * @return true if this task is authorized to commit, false otherwise
+ */
+ def canCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = {
+ val msg = AskPermissionToCommitOutput(stage, partition, attempt)
+ coordinatorActor match {
+ case Some(actor) =>
+ AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout)
+ case None =>
+ logError(
+        "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress?)")
+ false
+ }
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageStart(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage(stage) = mutable.HashMap[PartitionId, TaskAttemptId]()
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+ authorizedCommittersByStage.remove(stage)
+ }
+
+ // Called by DAGScheduler
+ private[scheduler] def taskCompleted(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId,
+ reason: TaskEndReason): Unit = synchronized {
+ val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, {
+ logDebug(s"Ignoring task completion for completed stage")
+ return
+ })
+ reason match {
+ case Success =>
+ // The task output has been committed successfully
+ case denied: TaskCommitDenied =>
+ logInfo(
+ s"Task was denied committing, stage: $stage, partition: $partition, attempt: $attempt")
+ case otherReason =>
+ logDebug(s"Authorized committer $attempt (stage=$stage, partition=$partition) failed;" +
+ s" clearing lock")
+ authorizedCommitters.remove(partition)
+ }
+ }
+
+ def stop(): Unit = synchronized {
+ coordinatorActor.foreach(_ ! StopCoordinator)
+ coordinatorActor = None
+ authorizedCommittersByStage.clear()
+ }
+
+ // Marked private[scheduler] instead of private so this can be mocked in tests
+ private[scheduler] def handleAskPermissionToCommit(
+ stage: StageId,
+ partition: PartitionId,
+ attempt: TaskAttemptId): Boolean = synchronized {
+ authorizedCommittersByStage.get(stage) match {
+ case Some(authorizedCommitters) =>
+ authorizedCommitters.get(partition) match {
+ case Some(existingCommitter) =>
+ logDebug(s"Denying $attempt to commit for stage=$stage, partition=$partition; " +
+ s"existingCommitter = $existingCommitter")
+ false
+ case None =>
+ logDebug(s"Authorizing $attempt to commit for stage=$stage, partition=$partition")
+ authorizedCommitters(partition) = attempt
+ true
+ }
+ case None =>
+ logDebug(s"Stage $stage has completed, so not allowing task attempt $attempt to commit")
+ false
+ }
+ }
+}
+
+private[spark] object OutputCommitCoordinator {
+
+ // This actor is used only for RPC
+ class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator)
+ extends Actor with ActorLogReceive with Logging {
+
+ override def receiveWithLogging = {
+ case AskPermissionToCommitOutput(stage, partition, taskAttempt) =>
+ sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)
+ case StopCoordinator =>
+ logInfo("OutputCommitCoordinator stopped!")
+ context.stop(self)
+ sender ! true
+ }
+ }
+}
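
To make the first-committer-wins policy concrete, here is a small standalone sketch that mimics the decision logic above with a plain map; it illustrates the semantics only and is not the coordinator class itself:

    import scala.collection.mutable

    // (stage, partition) -> task attempt that holds the commit lock
    val authorized = mutable.Map[(Int, Long), Long]()

    def ask(stage: Int, partition: Long, attempt: Long): Boolean =
      authorized.get((stage, partition)) match {
        case Some(_) =>
          false                                  // another attempt already holds the lock
        case None =>
          authorized((stage, partition)) = attempt
          true
      }

    assert(ask(0, 1L, 0L))    // the first attempt is authorized to commit
    assert(!ask(0, 1L, 1L))   // a speculative duplicate is denied
    // If the authorized attempt later fails, taskCompleted() clears the lock so a
    // subsequent attempt can be authorized; the map removal below stands in for that.
    authorized.remove((0, 1L))
    assert(ask(0, 1L, 1L))
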
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
index 584f4e7789d1a..d9c3a10dc5413 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ReplayListenerBus.scala
@@ -40,21 +40,24 @@ private[spark] class ReplayListenerBus extends SparkListenerBus with Logging {
*
* @param logData Stream containing event log data.
* @param version Spark version that generated the events.
+   * @param sourceName Filename (or other source identifier) from which logData is being read
*/
- def replay(logData: InputStream, version: String) {
+ def replay(logData: InputStream, version: String, sourceName: String) {
var currentLine: String = null
+ var lineNumber: Int = 1
try {
val lines = Source.fromInputStream(logData).getLines()
lines.foreach { line =>
currentLine = line
postToAll(JsonProtocol.sparkEventFromJson(parse(line)))
+ lineNumber += 1
}
} catch {
case ioe: IOException =>
throw ioe
case e: Exception =>
- logError("Exception in parsing Spark event log.", e)
- logError("Malformed line: %s\n".format(currentLine))
+ logError(s"Exception parsing Spark event log: $sourceName", e)
+ logError(s"Malformed line #$lineNumber: $currentLine\n")
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
index 774f3d8cdb275..3938580aeea59 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala
@@ -18,6 +18,7 @@
package org.apache.spark.scheduler
import java.nio.ByteBuffer
+import java.util.concurrent.RejectedExecutionException
import scala.language.existentials
import scala.util.control.NonFatal
@@ -95,25 +96,30 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul
def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState,
serializedData: ByteBuffer) {
var reason : TaskEndReason = UnknownReason
- getTaskResultExecutor.execute(new Runnable {
- override def run(): Unit = Utils.logUncaughtExceptions {
- try {
- if (serializedData != null && serializedData.limit() > 0) {
- reason = serializer.get().deserialize[TaskEndReason](
- serializedData, Utils.getSparkClassLoader)
+ try {
+ getTaskResultExecutor.execute(new Runnable {
+ override def run(): Unit = Utils.logUncaughtExceptions {
+ try {
+ if (serializedData != null && serializedData.limit() > 0) {
+ reason = serializer.get().deserialize[TaskEndReason](
+ serializedData, Utils.getSparkClassLoader)
+ }
+ } catch {
+ case cnd: ClassNotFoundException =>
+ // Log an error but keep going here -- the task failed, so not catastrophic
+ // if we can't deserialize the reason.
+ val loader = Utils.getContextOrSparkClassLoader
+ logError(
+ "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+ case ex: Exception => {}
}
- } catch {
- case cnd: ClassNotFoundException =>
- // Log an error but keep going here -- the task failed, so not catastrophic if we can't
- // deserialize the reason.
- val loader = Utils.getContextOrSparkClassLoader
- logError(
- "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex: Exception => {}
+ scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
- scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
- }
- })
+ })
+ } catch {
+ case e: RejectedExecutionException if sparkEnv.isStopped =>
+ // ignore it
+ }
}
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index f095915352b17..ed3418676e077 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -73,5 +73,9 @@ private[spark] trait TaskScheduler {
* @return An application ID
*/
def applicationId(): String = appId
-
+
+ /**
+ * Process a lost executor
+ */
+ def executorLost(executorId: String, reason: ExecutorLossReason): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index 79f84e70df9d5..7a9cf1c2e7f30 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -158,7 +158,7 @@ private[spark] class TaskSchedulerImpl(
val tasks = taskSet.tasks
logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
this.synchronized {
- val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
+ val manager = createTaskSetManager(taskSet, maxTaskFailures)
activeTaskSets(taskSet.id) = manager
schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
@@ -180,6 +180,13 @@ private[spark] class TaskSchedulerImpl(
backend.reviveOffers()
}
+  // Labeled private[scheduler] so that tests can swap in different task set managers if necessary
+ private[scheduler] def createTaskSetManager(
+ taskSet: TaskSet,
+ maxTaskFailures: Int): TaskSetManager = {
+ new TaskSetManager(this, taskSet, maxTaskFailures)
+ }
+
override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
logInfo("Cancelling stage " + stageId)
activeTaskSets.find(_._2.stageId == stageId).foreach { case (_, tsm) =>
@@ -429,7 +436,7 @@ private[spark] class TaskSchedulerImpl(
}
}
- def executorLost(executorId: String, reason: ExecutorLossReason) {
+ override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {
var failedExecutor: Option[String] = None
synchronized {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 55024ecd55e61..529237f0d35dc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -51,7 +51,7 @@ private[spark] class TaskSetManager(
sched: TaskSchedulerImpl,
val taskSet: TaskSet,
val maxTaskFailures: Int,
- clock: Clock = SystemClock)
+ clock: Clock = new SystemClock())
extends Schedulable with Logging {
val conf = sched.sc.conf
@@ -166,7 +166,7 @@ private[spark] class TaskSetManager(
// last launched a task at that level, and move up a level when localityWaits[curLevel] expires.
// We then move down if we manage to launch a "more local" task.
var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels
- var lastLaunchTime = clock.getTime() // Time we last launched a task at this level
+ var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level
override def schedulableQueue = null
@@ -281,7 +281,7 @@ private[spark] class TaskSetManager(
val failed = failedExecutors.get(taskId).get
return failed.contains(execId) &&
- clock.getTime() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT
+ clock.getTimeMillis() - failed.get(execId).get < EXECUTOR_TASK_BLACKLIST_TIMEOUT
}
false
@@ -292,7 +292,8 @@ private[spark] class TaskSetManager(
* an attempt running on this host, in case the host is slow. In addition, the task should meet
* the given locality constraint.
*/
- private def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
+  // Labeled protected so that tests can override how speculative tasks are provided if necessary
+ protected def dequeueSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
: Option[(Int, TaskLocality.Value)] =
{
speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
@@ -427,7 +428,7 @@ private[spark] class TaskSetManager(
: Option[TaskDescription] =
{
if (!isZombie) {
- val curTime = clock.getTime()
+ val curTime = clock.getTimeMillis()
var allowedLocality = maxLocality
@@ -458,7 +459,7 @@ private[spark] class TaskSetManager(
lastLaunchTime = curTime
}
// Serialize and return the task
- val startTime = clock.getTime()
+ val startTime = clock.getTimeMillis()
val serializedTask: ByteBuffer = try {
Task.serializeWithDependencies(task, sched.sc.addedFiles, sched.sc.addedJars, ser)
} catch {
@@ -673,7 +674,7 @@ private[spark] class TaskSetManager(
return
}
val key = ef.description
- val now = clock.getTime()
+ val now = clock.getTimeMillis()
val (printFull, dupCount) = {
if (recentExceptions.contains(key)) {
val (dupCount, printTime) = recentExceptions(key)
@@ -705,10 +706,13 @@ private[spark] class TaskSetManager(
}
// always add to failed executors
failedExecutors.getOrElseUpdate(index, new HashMap[String, Long]()).
- put(info.executorId, clock.getTime())
+ put(info.executorId, clock.getTimeMillis())
sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics)
addPendingTask(index)
- if (!isZombie && state != TaskState.KILLED) {
+ if (!isZombie && state != TaskState.KILLED && !reason.isInstanceOf[TaskCommitDenied]) {
+ // If a task failed because its attempt to commit was denied, do not count this failure
+ // towards failing the stage. This is intended to prevent spurious stage failures in cases
+ // where many speculative tasks are launched and denied to commit.
assert (null != failureReason)
numFailures(index) += 1
if (numFailures(index) >= maxTaskFailures) {
@@ -817,7 +821,7 @@ private[spark] class TaskSetManager(
val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
if (tasksSuccessful >= minFinishedForSpeculation && tasksSuccessful > 0) {
- val time = clock.getTime()
+ val time = clock.getTimeMillis()
val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
Arrays.sort(durations)
val medianDuration = durations(min((0.5 * tasksSuccessful).round.toInt, durations.size - 1))
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index f9ca93432bf41..6f77fa32ce37b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -311,7 +311,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
/**
* Request an additional number of executors from the cluster manager.
- * Return whether the request is acknowledged.
+ * @return whether the request is acknowledged.
*/
final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
if (numAdditionalExecutors < 0) {
@@ -327,6 +327,22 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
doRequestTotalExecutors(newTotal)
}
+ /**
+ * Express a preference to the cluster manager for a given total number of executors. This can
+ * result in canceling pending requests or filing additional requests.
+ * @return whether the request is acknowledged.
+ */
+ final override def requestTotalExecutors(numExecutors: Int): Boolean = synchronized {
+ if (numExecutors < 0) {
+ throw new IllegalArgumentException(
+ "Attempted to request a negative number of executor(s) " +
+ s"$numExecutors from the cluster manager. Please specify a positive number!")
+ }
+ numPendingExecutors =
+ math.max(numExecutors - numExistingExecutors + executorsPendingToRemove.size, 0)
+ doRequestTotalExecutors(numExecutors)
+ }
+
/**
* Request executors from the cluster manager by specifying the total number desired,
* including existing pending and running executors.
@@ -337,7 +353,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* insufficient resources to satisfy the first request. We make the assumption here that the
* cluster manager will eventually fulfill all requests when resources free up.
*
- * Return whether the request is acknowledged.
+ * @return whether the request is acknowledged.
*/
protected def doRequestTotalExecutors(requestedTotal: Int): Boolean = false
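
A quick sanity check of the pending-executor bookkeeping in requestTotalExecutors above, with hypothetical numbers:

    // Hypothetical numbers mirroring the formula in requestTotalExecutors.
    val numExistingExecutors = 10       // executors currently registered
    val executorsPendingToRemove = 2    // executors already asked to be removed
    val numExecutors = 12               // newly requested total
    val numPendingExecutors =
      math.max(numExecutors - numExistingExecutors + executorsPendingToRemove, 0)
    assert(numPendingExecutors == 4)    // four more executors still need to arrive
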
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index d2e1680a5fd1b..a0aa555f6244f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.cluster
+import java.util.concurrent.Semaphore
+
import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv}
import org.apache.spark.deploy.{ApplicationDescription, Command}
import org.apache.spark.deploy.client.{AppClient, AppClientListener}
@@ -31,16 +33,16 @@ private[spark] class SparkDeploySchedulerBackend(
with AppClientListener
with Logging {
- var client: AppClient = null
- var stopping = false
- var shutdownCallback : (SparkDeploySchedulerBackend) => Unit = _
- @volatile var appId: String = _
+ private var client: AppClient = null
+ private var stopping = false
+
+ @volatile var shutdownCallback: SparkDeploySchedulerBackend => Unit = _
+ @volatile private var appId: String = _
- val registrationLock = new Object()
- var registrationDone = false
+ private val registrationBarrier = new Semaphore(0)
- val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
- val totalExpectedCores = maxCores.getOrElse(0)
+ private val maxCores = conf.getOption("spark.cores.max").map(_.toInt)
+ private val totalExpectedCores = maxCores.getOrElse(0)
override def start() {
super.start()
@@ -52,8 +54,13 @@ private[spark] class SparkDeploySchedulerBackend(
conf.get("spark.driver.host"),
conf.get("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
- val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}", "{{APP_ID}}",
- "{{WORKER_URL}}")
+ val args = Seq(
+ "--driver-url", driverUrl,
+ "--executor-id", "{{EXECUTOR_ID}}",
+ "--hostname", "{{HOSTNAME}}",
+ "--cores", "{{CORES}}",
+ "--app-id", "{{APP_ID}}",
+ "--worker-url", "{{WORKER_URL}}")
val extraJavaOpts = sc.conf.getOption("spark.executor.extraJavaOptions")
.map(Utils.splitCommandString).getOrElse(Seq.empty)
val classPathEntries = sc.conf.getOption("spark.executor.extraClassPath")
@@ -90,8 +97,10 @@ private[spark] class SparkDeploySchedulerBackend(
stopping = true
super.stop()
client.stop()
- if (shutdownCallback != null) {
- shutdownCallback(this)
+
+ val callback = shutdownCallback
+ if (callback != null) {
+ callback(this)
}
}
@@ -144,18 +153,11 @@ private[spark] class SparkDeploySchedulerBackend(
}
private def waitForRegistration() = {
- registrationLock.synchronized {
- while (!registrationDone) {
- registrationLock.wait()
- }
- }
+ registrationBarrier.acquire()
}
private def notifyContext() = {
- registrationLock.synchronized {
- registrationDone = true
- registrationLock.notifyAll()
- }
+ registrationBarrier.release()
}
}
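
The registration lock/flag pair is replaced by a Semaphore(0) used as a one-shot latch: acquire() blocks until the registration callback calls release(). A minimal sketch of the pattern (the worker thread body is illustrative):

    import java.util.concurrent.Semaphore

    val registrationBarrier = new Semaphore(0)

    // Stand-in for the AppClient callback thread that signals registration.
    new Thread(new Runnable {
      override def run(): Unit = {
        Thread.sleep(50)               // pretend registration takes a moment
        registrationBarrier.release()  // what notifyContext() now does
      }
    }).start()

    registrationBarrier.acquire()      // waitForRegistration(): returns once released
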
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index 0d1c2a916ca7f..90dfe14352a8e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -154,18 +154,25 @@ private[spark] class CoarseMesosSchedulerBackend(
if (uri == null) {
val runScript = new File(executorSparkHome, "./bin/spark-class").getCanonicalPath
command.setValue(
- "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s".format(
- prefixEnv, runScript, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores, appId))
+ "%s \"%s\" org.apache.spark.executor.CoarseGrainedExecutorBackend"
+ .format(prefixEnv, runScript) +
+ s" --driver-url $driverUrl" +
+ s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --hostname ${offer.getHostname}" +
+ s" --cores $numCores" +
+ s" --app-id $appId")
} else {
// Grab everything to the first '.'. We'll use that and '*' to
// glob the directory "correctly".
val basename = uri.split('/').last.split('.').head
command.setValue(
- ("cd %s*; %s " +
- "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend %s %s %s %d %s")
- .format(basename, prefixEnv, driverUrl, offer.getSlaveId.getValue,
- offer.getHostname, numCores, appId))
+ s"cd $basename*; $prefixEnv " +
+ "./bin/spark-class org.apache.spark.executor.CoarseGrainedExecutorBackend" +
+ s" --driver-url $driverUrl" +
+ s" --executor-id ${offer.getSlaveId.getValue}" +
+ s" --hostname ${offer.getHostname}" +
+ s" --cores $numCores" +
+ s" --app-id $appId")
command.addUris(CommandInfo.URI.newBuilder().setValue(uri))
}
command.build()
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 64133464d8daa..787b0f96bec32 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -24,7 +24,7 @@ import scala.collection.JavaConversions._
import scala.concurrent.Future
import scala.concurrent.duration._
-import akka.actor.{Actor, ActorRef, Cancellable}
+import akka.actor.{Actor, ActorRef}
import akka.pattern.ask
import org.apache.spark.{Logging, SparkConf, SparkException}
@@ -52,19 +52,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
private val akkaTimeout = AkkaUtils.askTimeout(conf)
- val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000)
-
- val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000)
-
- var timeoutCheckingTask: Cancellable = null
-
- override def preStart() {
- import context.dispatcher
- timeoutCheckingTask = context.system.scheduler.schedule(0.seconds,
- checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
- super.preStart()
- }
-
override def receiveWithLogging = {
case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) =>
register(blockManagerId, maxMemSize, slaveActor)
@@ -118,14 +105,8 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case StopBlockManagerMaster =>
sender ! true
- if (timeoutCheckingTask != null) {
- timeoutCheckingTask.cancel()
- }
context.stop(self)
- case ExpireDeadHosts =>
- expireDeadHosts()
-
case BlockManagerHeartbeat(blockManagerId) =>
sender ! heartbeatReceived(blockManagerId)
@@ -207,21 +188,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
logInfo(s"Removing block manager $blockManagerId")
}
- private def expireDeadHosts() {
- logTrace("Checking for hosts with no recent heart beats in BlockManagerMaster.")
- val now = System.currentTimeMillis()
- val minSeenTime = now - slaveTimeout
- val toRemove = new mutable.HashSet[BlockManagerId]
- for (info <- blockManagerInfo.values) {
- if (info.lastSeenMs < minSeenTime && !info.blockManagerId.isDriver) {
- logWarning("Removing BlockManager " + info.blockManagerId + " with no recent heart beats: "
- + (now - info.lastSeenMs) + "ms exceeds " + slaveTimeout + "ms")
- toRemove += info.blockManagerId
- }
- }
- toRemove.foreach(removeBlockManager)
- }
-
private def removeExecutor(execId: String) {
logInfo("Trying to remove executor " + execId + " from BlockManagerMaster.")
blockManagerIdByExecutor.get(execId).foreach(removeBlockManager)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 3f32099d08cc9..48247453edef0 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -109,6 +109,4 @@ private[spark] object BlockManagerMessages {
extends ToBlockManagerMaster
case class BlockManagerHeartbeat(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
-
- case object ExpireDeadHosts extends ToBlockManagerMaster
}
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 53eaedacbf291..12cd8ea3bdf1f 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -49,7 +49,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private val subDirs = Array.fill(localDirs.length)(new Array[File](subDirsPerLocalDir))
- addShutdownHook()
+ private val shutdownHook = addShutdownHook()
/** Looks up a file by hashing it into one of our local subdirectories. */
// This method should be kept in sync with
@@ -134,17 +134,29 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
}
- private def addShutdownHook() {
- Runtime.getRuntime.addShutdownHook(new Thread("delete Spark local dirs") {
+ private def addShutdownHook(): Thread = {
+ val shutdownHook = new Thread("delete Spark local dirs") {
override def run(): Unit = Utils.logUncaughtExceptions {
logDebug("Shutdown hook called")
- DiskBlockManager.this.stop()
+ DiskBlockManager.this.doStop()
}
- })
+ }
+ Runtime.getRuntime.addShutdownHook(shutdownHook)
+ shutdownHook
}
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
+ // Remove the shutdown hook. It causes memory leaks if we leave it around.
+ try {
+ Runtime.getRuntime.removeShutdownHook(shutdownHook)
+ } catch {
+ case e: IllegalStateException => None
+ }
+ doStop()
+ }
+
+ private def doStop(): Unit = {
// Only perform cleanup if an external service is not serving our shuffle files.
if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) {
localDirs.foreach { localDir =>
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index ab9ee4f0096bf..8f28ef49a8a6f 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -234,6 +234,7 @@ final class ShuffleBlockFetcherIterator(
try {
val buf = blockManager.getBlockData(blockId)
shuffleMetrics.incLocalBlocksFetched(1)
+ shuffleMetrics.incLocalBytesRead(buf.size)
buf.retain()
results.put(new SuccessFetchResult(blockId, 0, buf))
} catch {
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
index 27ba9e18237b5..67f572e79314d 100644
--- a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -28,7 +28,6 @@ import org.apache.spark._
* of them will be combined together, showed in one line.
*/
private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
-
// Carrige return
val CR = '\r'
// Update period of progress bar, in milliseconds
@@ -121,4 +120,10 @@ private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
clear()
lastFinishTime = System.currentTimeMillis()
}
+
+ /**
+ * Tear down the timer thread. The timer thread is a GC root, and it retains the entire
+ * SparkContext if it's not terminated.
+ */
+ def stop(): Unit = timer.cancel()
}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index bf4b24e98b134..95f254a9ef22a 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -80,6 +80,10 @@ private[spark] object JettyUtils extends Logging {
response.sendError(HttpServletResponse.SC_BAD_REQUEST, e.getMessage)
}
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
}
@@ -119,6 +123,10 @@ private[spark] object JettyUtils extends Logging {
val newUrl = new URL(new URL(request.getRequestURL.toString), prefixedDestPath).toString
response.sendRedirect(newUrl)
}
+ // SPARK-5983 ensure TRACE is not supported
+ protected override def doTrace(req: HttpServletRequest, res: HttpServletResponse): Unit = {
+ res.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED)
+ }
}
createServletHandler(srcPath, servlet, basePath)
}
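
The effect of the doTrace overrides can be checked against a running UI; a hedged sketch, assuming a local application web UI on the default port 4040:

    import java.net.{HttpURLConnection, URL}

    // Assumes a local Spark UI at the default port 4040.
    val conn = new URL("http://localhost:4040/jobs/")
      .openConnection().asInstanceOf[HttpURLConnection]
    conn.setRequestMethod("TRACE")
    assert(conn.getResponseCode == 405)   // SC_METHOD_NOT_ALLOWED
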
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index 3a15e603b1969..cae6870c2ab20 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -37,8 +37,12 @@ private[spark] object ToolTips {
"Bytes and records written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
- """Bytes and records read from remote executors. Typically less than shuffle write bytes
- because this does not include shuffle data read locally."""
+ """Total shuffle bytes and records read (includes both data read locally and data read from
+ remote executors). """
+
+ val SHUFFLE_READ_REMOTE_SIZE =
+ """Total shuffle bytes read from remote executors. This is a subset of the shuffle
+ read bytes; the remaining shuffle data is read locally. """
val GETTING_RESULT_TIME =
"""Time that the driver spends fetching task results from workers. If this is large, consider
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
index 045c69da06feb..bd923d78a86ce 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -42,7 +42,9 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
}
def makeRow(job: JobUIData): Seq[Node] = {
- val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageInfo = Option(job.stageIds)
+ .filter(_.nonEmpty)
+ .flatMap { ids => listener.stageIdToInfo.get(ids.max) }
val lastStageData = lastStageInfo.flatMap { s =>
listener.stageIdToData.get((s.stageId, s.attemptId))
}
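
The `AllJobsPage` fix above guards against a job with no stages: calling `.max` on an empty sequence throws, so the lookup is wrapped in `Option` and filtered first. A standalone sketch of the same guard, using a hypothetical `stageNames` map in place of `listener.stageIdToInfo`:

```scala
// Sketch of the empty-sequence guard: .max on an empty Seq throws
// UnsupportedOperationException, so wrap the lookup in Option and filter first.
// `stageNames` is a hypothetical stand-in for listener.stageIdToInfo.
def lastStageName(stageIds: Seq[Int], stageNames: Map[Int, String]): Option[String] =
  Option(stageIds)
    .filter(_.nonEmpty)
    .flatMap(ids => stageNames.get(ids.max))

// lastStageName(Seq(1, 2, 5), Map(5 -> "count at App.scala:42"))  => Some("count at App.scala:42")
// lastStageName(Seq.empty, Map.empty)                             => None, no exception
```
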
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index f463f8d7c7215..937d95a934b59 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -203,6 +203,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
for (stageId <- jobData.stageIds) {
stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage =>
jobsUsingStage.remove(jobEnd.jobId)
+ if (jobsUsingStage.isEmpty) {
+ stageIdToActiveJobIds.remove(stageId)
+ }
stageIdToInfo.get(stageId).foreach { stageInfo =>
if (stageInfo.submissionTime.isEmpty) {
// if this stage is pending, it won't complete, so mark it as "skipped":
@@ -401,9 +404,9 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta
val shuffleReadDelta =
- (taskMetrics.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L)
- - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).getOrElse(0L))
- stageData.shuffleReadBytes += shuffleReadDelta
+ (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L)
+ - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L))
+ stageData.shuffleReadTotalBytes += shuffleReadDelta
execSummary.shuffleRead += shuffleReadDelta
val shuffleReadRecordsDelta =
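
The listener now accumulates total shuffle read bytes rather than remote-only bytes, which lines up with the new `incLocalBytesRead` call and the revised tooltips earlier in this diff. A hedged sketch of how such a total can be derived from the two counters; `ShuffleReadCounters` is an illustrative stand-in, not Spark's `ShuffleReadMetrics`:

```scala
// Hedged sketch, not Spark's ShuffleReadMetrics: the total shuffle read is assumed to be
// the sum of bytes fetched from remote executors and bytes read from local block files.
case class ShuffleReadCounters(remoteBytesRead: Long, localBytesRead: Long) {
  def totalBytesRead: Long = remoteBytesRead + localBytesRead
}

// Delta between two optional snapshots, mirroring the getOrElse(0L) pattern in the diff.
def shuffleReadDelta(current: Option[ShuffleReadCounters],
                     previous: Option[ShuffleReadCounters]): Long = {
  current.map(_.totalBytesRead).getOrElse(0L) -
    previous.map(_.totalBytesRead).getOrElse(0L)
}
```
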
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 05ffd5bc58fbb..110f8780a9a12 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -85,7 +85,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
{if (stageData.hasShuffleRead) {
diff --git a/docs/configuration.md b/docs/configuration.md
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -109,23 +101,6 @@ of the most common options to set are:
Number of cores to use for the driver process, only in cluster mode.
-
-
spark.driver.memory
-
512m
-
- Amount of memory to use for the driver process, i.e. where SparkContext is initialized.
- (e.g. 512m, 2g).
-
-
-
-
spark.executor.memory
-
512m
-
- Amount of memory to use per executor process, in the same format as JVM memory strings
- (e.g. 512m, 2g).
-
-
-
spark.driver.maxResultSize
1g
@@ -138,38 +113,35 @@ of the most common options to set are:
-
spark.serializer
-
org.apache.spark.serializer. JavaSerializer
+
spark.driver.memory
+
512m
- Class to use for serializing objects that will be sent over the network or need to be cached
- in serialized form. The default of Java serialization works with any Serializable Java object
- but is quite slow, so we recommend using
- org.apache.spark.serializer.KryoSerializer and configuring Kryo serialization
- when speed is necessary. Can be any subclass of
-
- org.apache.spark.Serializer.
+ Amount of memory to use for the driver process, i.e. where SparkContext is initialized.
+ (e.g. 512m, 2g).
+
+ Note: In client mode, this config must not be set through the SparkConf
+ directly in your application, because the driver JVM has already started at that point.
+ Instead, please set this through the --driver-memory command line option
+ or in your default properties file.
-
spark.kryo.classesToRegister
-
(none)
+
spark.executor.memory
+
512m
- If you use Kryo serialization, give a comma-separated list of custom class names to register
- with Kryo.
- See the tuning guide for more details.
+ Amount of memory to use per executor process, in the same format as JVM memory strings
+ (e.g. 512m, 2g).
-
spark.kryo.registrator
+
spark.extraListeners
(none)
- If you use Kryo serialization, set this class to register your custom classes with Kryo. This
- property is useful if you need to register your classes in a custom way, e.g. to specify a custom
- field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be
- set to a class that extends
-
- KryoRegistrator.
- See the tuning guide for more details.
+ A comma-separated list of classes that implement SparkListener; when initializing
+ SparkContext, instances of these classes will be created and registered with Spark's listener
+ bus. If a class has a single-argument constructor that accepts a SparkConf, that constructor
+ will be called; otherwise, a zero-argument constructor will be called. If no valid constructor
+ can be found, the SparkContext creation will fail with an exception.
@@ -192,14 +164,11 @@ of the most common options to set are:
-
spark.extraListeners
+
spark.master
(none)
- A comma-separated list of classes that implement SparkListener; when initializing
- SparkContext, instances of these classes will be created and registered with Spark's listener
- bus. If a class has a single-argument constructor that accepts a SparkConf, that constructor
- will be called; otherwise, a zero-argument constructor will be called. If no valid constructor
- can be found, the SparkContext creation will fail with an exception.
+ The cluster manager to connect to. See the list of
+ allowed master URL's.
@@ -210,17 +179,27 @@ Apart from these, the following properties are also available, and may be useful
Property Name
Default
Meaning
-
spark.driver.extraJavaOptions
+
spark.driver.extraClassPath
(none)
- A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+ Extra classpath entries to append to the classpath of the driver.
+
+ Note: In client mode, this config must not be set through the SparkConf
+ directly in your application, because the driver JVM has already started at that point.
+ Instead, please set this through the --driver-class-path command line option or in
+ your default properties file.
-
spark.driver.extraClassPath
+
spark.driver.extraJavaOptions
(none)
- Extra classpath entries to append to the classpath of the driver.
+ A string of extra JVM options to pass to the driver. For instance, GC settings or other logging.
+
+ Note: In client mode, this config must not be set through the SparkConf
+ directly in your application, because the driver JVM has already started at that point.
+ Instead, please set this through the --driver-java-options command line option or in
+ your default properties file.
@@ -228,54 +207,56 @@ Apart from these, the following properties are also available, and may be useful
(none)
Set a special library path to use when launching the driver JVM.
+
+ Note: In client mode, this config must not be set through the SparkConf
+ directly in your application, because the driver JVM has already started at that point.
+ Instead, please set this through the --driver-library-path command line option or in
+ your default properties file.
-
spark.executor.extraJavaOptions
-
(none)
+
spark.driver.userClassPathFirst
+
false
- A string of extra JVM options to pass to executors. For instance, GC settings or other
- logging. Note that it is illegal to set Spark properties or heap size settings with this
- option. Spark properties should be set using a SparkConf object or the
- spark-defaults.conf file used with the spark-submit script. Heap size settings can be set
- with spark.executor.memory.
+ (Experimental) Whether to give user-added jars precedence over Spark's own jars when loading
+ classes in the driver. This feature can be used to mitigate conflicts between Spark's
+ dependencies and user dependencies. It is currently an experimental feature.
+
+ This is used in cluster mode only.
spark.executor.extraClassPath
(none)
- Extra classpath entries to append to the classpath of executors. This exists primarily
- for backwards-compatibility with older versions of Spark. Users typically should not need
- to set this option.
+ Extra classpath entries to append to the classpath of executors. This exists primarily for
+ backwards-compatibility with older versions of Spark. Users typically should not need to set
+ this option.
-
spark.executor.extraLibraryPath
+
spark.executor.extraJavaOptions
(none)
- Set a special library path to use when launching executor JVM's.
+ A string of extra JVM options to pass to executors. For instance, GC settings or other logging.
+ Note that it is illegal to set Spark properties or heap size settings with this option. Spark
+ properties should be set using a SparkConf object or the spark-defaults.conf file used with the
+ spark-submit script. Heap size settings can be set with spark.executor.memory.
-
spark.executor.logs.rolling.strategy
+
spark.executor.extraLibraryPath
(none)
- Set the strategy of rolling of executor logs. By default it is disabled. It can
- be set to "time" (time-based rolling) or "size" (size-based rolling). For "time",
- use spark.executor.logs.rolling.time.interval to set the rolling interval.
- For "size", use spark.executor.logs.rolling.size.maxBytes to set
- the maximum file size for rolling.
+ Set a special library path to use when launching executor JVM's.
-
spark.executor.logs.rolling.time.interval
-
daily
+
spark.executor.logs.rolling.maxRetainedFiles
+
(none)
- Set the time interval by which the executor logs will be rolled over.
- Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or
- any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles
- for automatic cleaning of old logs.
+ Sets the number of latest rolling log files that are going to be retained by the system.
+ Older log files will be deleted. Disabled by default.
@@ -289,30 +270,40 @@ Apart from these, the following properties are also available, and may be useful
-
spark.executor.logs.rolling.maxRetainedFiles
+
spark.executor.logs.rolling.strategy
(none)
- Sets the number of latest rolling log files that are going to be retained by the system.
- Older log files will be deleted. Disabled by default.
+ Set the strategy of rolling of executor logs. By default it is disabled. It can
+ be set to "time" (time-based rolling) or "size" (size-based rolling). For "time",
+ use spark.executor.logs.rolling.time.interval to set the rolling interval.
+ For "size", use spark.executor.logs.rolling.size.maxBytes to set
+ the maximum file size for rolling.
-
spark.files.userClassPathFirst
+
spark.executor.logs.rolling.time.interval
+
daily
+
+ Set the time interval by which the executor logs will be rolled over.
+ Rolling is disabled by default. Valid values are `daily`, `hourly`, `minutely` or
+ any interval in seconds. See spark.executor.logs.rolling.maxRetainedFiles
+ for automatic cleaning of old logs.
+
+
+
+
spark.executor.userClassPathFirst
false
- (Experimental) Whether to give user-added jars precedence over Spark's own jars when
- loading classes in Executors. This feature can be used to mitigate conflicts between
- Spark's dependencies and user dependencies. It is currently an experimental feature.
- (Currently, this setting does not work for YARN, see SPARK-2996 for more details).
+ (Experimental) Same functionality as spark.driver.userClassPathFirst, but
+ applied to executor instances.
-
spark.python.worker.memory
-
512m
+
spark.executorEnv.[EnvironmentVariableName]
+
(none)
- Amount of memory to use per python worker process during aggregation, in the same
- format as JVM memory strings (e.g. 512m, 2g). If the memory
- used during aggregation goes above this amount, it will spill the data into disks.
+ Add the environment variable specified by EnvironmentVariableName to the Executor
+ process. The user can specify multiple of these to set multiple environment variables.
@@ -338,6 +329,15 @@ Apart from these, the following properties are also available, and may be useful
automatically.
+
+
spark.python.worker.memory
+
512m
+
+ Amount of memory to use per python worker process during aggregation, in the same
+ format as JVM memory strings (e.g. 512m, 2g). If the memory
+ used during aggregation goes above this amount, it will spill the data into disks.
+
+
spark.python.worker.reuse
true
@@ -348,40 +348,38 @@ Apart from these, the following properties are also available, and may be useful
from JVM to Python worker for every task.
+
+
+#### Shuffle Behavior
+
+
Property Name
Default
Meaning
-
spark.executorEnv.[EnvironmentVariableName]
-
(none)
+
spark.reducer.maxMbInFlight
+
48
- Add the environment variable specified by EnvironmentVariableName to the Executor
- process. The user can specify multiple of these to set multiple environment variables.
+ Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
+ each output requires us to create a buffer to receive it, this represents a fixed memory
+ overhead per reduce task, so keep it small unless you have a large amount of memory.
-
spark.mesos.executor.home
-
driver side SPARK_HOME
+
spark.shuffle.blockTransferService
+
netty
- Set the directory in which Spark is installed on the executors in Mesos. By default, the
- executors will simply use the driver's Spark home directory, which may not be visible to
- them. Note that this is only relevant if a Spark binary package is not specified through
- spark.executor.uri.
+ Implementation to use for transferring shuffle and cached blocks between executors. There
+ are two implementations available: netty and nio. Netty-based
+ block transfer is intended to be simpler but equally efficient and is the default option
+ starting in 1.2.
-
spark.mesos.executor.memoryOverhead
-
executor memory * 0.07, with minimum of 384
+
spark.shuffle.compress
+
true
- This value is an additive for spark.executor.memory, specified in MiB,
- which is used to calculate the total Mesos task memory. A value of 384
- implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum
- overhead. The final overhead will be the larger of either
- `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`.
+ Whether to compress map output files. Generally a good idea. Compression will use
+ spark.io.compression.codec.
-
-
-#### Shuffle Behavior
-
-
Property Name
Default
Meaning
spark.shuffle.consolidateFiles
false
@@ -393,55 +391,46 @@ Apart from these, the following properties are also available, and may be useful
-
spark.shuffle.spill
-
true
+
spark.shuffle.file.buffer.kb
+
32
- If set to "true", limits the amount of memory used during reduces by spilling data out to disk.
- This spilling threshold is specified by spark.shuffle.memoryFraction.
+ Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
+ reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
-
spark.shuffle.spill.compress
-
true
+
spark.shuffle.io.maxRetries
+
3
- Whether to compress data spilled during shuffles. Compression will use
- spark.io.compression.codec.
+ (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is
+ set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC
+ pauses or transient network connectivity issues.
-
spark.shuffle.memoryFraction
-
0.2
+
spark.shuffle.io.numConnectionsPerPeer
+
1
- Fraction of Java heap to use for aggregation and cogroups during shuffles, if
- spark.shuffle.spill is true. At any given time, the collective size of
- all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
- begin to spill to disk. If spills are often, consider increasing this value at the expense of
- spark.storage.memoryFraction.
+ (Netty only) Connections between hosts are reused in order to reduce connection buildup for
+ large clusters. For clusters with many hard disks and few hosts, this may result in insufficient
+ concurrency to saturate all disks, and so users may consider increasing this value.
-
spark.shuffle.compress
+
spark.shuffle.io.preferDirectBufs
true
- Whether to compress map output files. Generally a good idea. Compression will use
- spark.io.compression.codec.
-
-
-
-
spark.shuffle.file.buffer.kb
-
32
-
- Size of the in-memory buffer for each shuffle file output stream, in kilobytes. These buffers
- reduce the number of disk seeks and system calls made in creating intermediate shuffle files.
+ (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache
+ block transfer. For environments where off-heap memory is tightly limited, users may wish to
+ turn this off to force all allocations from Netty to be on-heap.
-
spark.reducer.maxMbInFlight
-
48
+
spark.shuffle.io.retryWait
+
5
- Maximum size (in megabytes) of map outputs to fetch simultaneously from each reduce task. Since
- each output requires us to create a buffer to receive it, this represents a fixed memory
- overhead per reduce task, so keep it small unless you have a large amount of memory.
+ (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying
+ is simply maxRetries * retryWait, by default 15 seconds.
@@ -453,6 +442,17 @@ Apart from these, the following properties are also available, and may be useful
the default option starting in 1.2.
+
+
spark.shuffle.memoryFraction
+
0.2
+
+ Fraction of Java heap to use for aggregation and cogroups during shuffles, if
+ spark.shuffle.spill is true. At any given time, the collective size of
+ all in-memory maps used for shuffles is bounded by this limit, beyond which the contents will
+ begin to spill to disk. If spills are often, consider increasing this value at the expense of
+ spark.storage.memoryFraction.
+
+
spark.shuffle.sort.bypassMergeThreshold
200
@@ -462,13 +462,19 @@ Apart from these, the following properties are also available, and may be useful
-
spark.shuffle.blockTransferService
-
netty
+
spark.shuffle.spill
+
true
- Implementation to use for transferring shuffle and cached blocks between executors. There
- are two implementations available: netty and nio. Netty-based
- block transfer is intended to be simpler but equally efficient and is the default option
- starting in 1.2.
+ If set to "true", limits the amount of memory used during reduces by spilling data out to disk.
+ This spilling threshold is specified by spark.shuffle.memoryFraction.
+
+
+
+
spark.shuffle.spill.compress
+
true
+
+ Whether to compress data spilled during shuffles. Compression will use
+ spark.io.compression.codec.
@@ -477,26 +483,28 @@ Apart from these, the following properties are also available, and may be useful
Property Name
Default
Meaning
-
spark.ui.port
-
4040
+
spark.eventLog.compress
+
false
- Port for your application's dashboard, which shows memory and workload data.
+ Whether to compress logged events, if spark.eventLog.enabled is true.
-
spark.ui.retainedStages
-
1000
+
spark.eventLog.dir
+
file:///tmp/spark-events
- How many stages the Spark UI and status APIs remember before garbage
- collecting.
+ Base directory in which Spark events are logged, if spark.eventLog.enabled is true.
+ Within this base directory, Spark creates a sub-directory for each application, and logs the
+ events specific to the application in this directory. Users may want to set this to
+ a unified location like an HDFS directory so history files can be read by the history server.
-
spark.ui.retainedJobs
-
1000
+
spark.eventLog.enabled
+
false
- How many jobs the Spark UI and status APIs remember before garbage
- collecting.
+ Whether to log Spark events, useful for reconstructing the Web UI after the application has
+ finished.
@@ -507,28 +515,26 @@ Apart from these, the following properties are also available, and may be useful
-
spark.eventLog.enabled
-
false
+
spark.ui.port
+
4040
- Whether to log Spark events, useful for reconstructing the Web UI after the application has
- finished.
+ Port for your application's dashboard, which shows memory and workload data.
-
spark.eventLog.compress
-
false
+
spark.ui.retainedJobs
+
1000
- Whether to compress logged events, if spark.eventLog.enabled is true.
+ How many jobs the Spark UI and status APIs remember before garbage
+ collecting.
-
spark.eventLog.dir
-
file:///tmp/spark-events
+
spark.ui.retainedStages
+
1000
- Base directory in which Spark events are logged, if spark.eventLog.enabled is true.
- Within this base directory, Spark creates a sub-directory for each application, and logs the
- events specific to the application in this directory. Users may want to set this to
- a unified location like an HDFS directory so history files can be read by the history server.
+ How many stages the Spark UI and status APIs remember before garbage
+ collecting.
@@ -544,12 +550,10 @@ Apart from these, the following properties are also available, and may be useful
-
spark.rdd.compress
-
false
+
spark.closure.serializer
+
org.apache.spark.serializer. JavaSerializer
- Whether to compress serialized RDD partitions (e.g. for
- StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some
- extra CPU time.
+ Serializer class to use for closures. Currently only the Java serializer is supported.
@@ -565,14 +569,6 @@ Apart from these, the following properties are also available, and may be useful
and org.apache.spark.io.SnappyCompressionCodec.
-
-
spark.io.compression.snappy.block.size
-
32768
-
- Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec
- is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
-
-
spark.io.compression.lz4.block.size
32768
@@ -582,21 +578,20 @@ Apart from these, the following properties are also available, and may be useful
-
spark.closure.serializer
-
org.apache.spark.serializer. JavaSerializer
+
spark.io.compression.snappy.block.size
+
32768
- Serializer class to use for closures. Currently only the Java serializer is supported.
+ Block size (in bytes) used in Snappy compression, in the case when Snappy compression codec
+ is used. Lowering this block size will also lower shuffle memory usage when Snappy is used.
-
spark.serializer.objectStreamReset
-
100
+
spark.kryo.classesToRegister
+
(none)
- When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
- objects to prevent writing redundant data, however that stops garbage collection of those
- objects. By calling 'reset' you flush that info from the serializer, and allow old
- objects to be collected. To turn off this periodic reset set it to -1.
- By default it will reset the serializer every 100 objects.
+ If you use Kryo serialization, give a comma-separated list of custom class names to register
+ with Kryo.
+ See the tuning guide for more details.
@@ -621,12 +616,16 @@ Apart from these, the following properties are also available, and may be useful
-
spark.kryoserializer.buffer.mb
-
0.064
+
spark.kryo.registrator
+
(none)
- Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer
- per core on each worker. This buffer will grow up to
- spark.kryoserializer.buffer.max.mb if needed.
+ If you use Kryo serialization, set this class to register your custom classes with Kryo. This
+ property is useful if you need to register your classes in a custom way, e.g. to specify a custom
+ field serializer. Otherwise spark.kryo.classesToRegister is simpler. It should be
+ set to a class that extends
+
+ KryoRegistrator.
+ See the tuning guide for more details.
@@ -638,11 +637,80 @@ Apart from these, the following properties are also available, and may be useful
inside Kryo.
+
+
spark.kryoserializer.buffer.mb
+
0.064
+
+ Initial size of Kryo's serialization buffer, in megabytes. Note that there will be one buffer
+ per core on each worker. This buffer will grow up to
+ spark.kryoserializer.buffer.max.mb if needed.
+
+
+
+
spark.rdd.compress
+
false
+
+ Whether to compress serialized RDD partitions (e.g. for
+ StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some
+ extra CPU time.
+
+ When serializing using org.apache.spark.serializer.JavaSerializer, the serializer caches
+ objects to prevent writing redundant data, however that stops garbage collection of those
+ objects. By calling 'reset' you flush that info from the serializer, and allow old
+ objects to be collected. To turn off this periodic reset set it to -1.
+ By default it will reset the serializer every 100 objects.
+
+
#### Execution Behavior
Property Name
Default
Meaning
+
+
spark.broadcast.blockSize
+
4096
+
+ Size of each piece of a block in kilobytes for TorrentBroadcastFactory.
+ Too large a value decreases parallelism during broadcast (makes it slower); however, if it is
+ too small, BlockManager might take a performance hit.
+
+ Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks
+ generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be
+ forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in
+ case of Spark Streaming applications). Note that any RDD that persists in memory for more than
+ this duration will be cleared as well.
+
+
spark.default.parallelism
@@ -661,19 +729,18 @@ Apart from these, the following properties are also available, and may be useful
Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let
+ the driver know that the executor is still alive and update it with metrics for in-progress
+ tasks.
-
spark.broadcast.blockSize
-
4096
+
spark.files.fetchTimeout
+
60
- Size of each piece of a block in kilobytes for TorrentBroadcastFactory.
- Too large a value decreases parallelism during broadcast (makes it slower); however, if it is
- too small, BlockManager might take a performance hit.
+ Communication timeout to use when fetching files added through SparkContext.addFile() from
+ the driver, in seconds.
@@ -685,12 +752,23 @@ Apart from these, the following properties are also available, and may be useful
-
spark.files.fetchTimeout
-
60
-
- Communication timeout to use when fetching files added through SparkContext.addFile() from
- the driver, in seconds.
-
+
spark.hadoop.cloneConf
+
false
+
If set to true, clones a new Hadoop Configuration object for each task. This
+ option should be enabled to work around Configuration thread-safety issues (see
+ SPARK-2546 for more details).
+ This is disabled by default in order to avoid unexpected performance regressions for jobs that
+ are not affected by these issues.
+
+
+
spark.hadoop.validateOutputSpecs
+
true
+
If set to true, validates the output specification (e.g. checking if the output directory already exists)
+ used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing
+ output directories. We recommend that users do not disable this except if trying to achieve compatibility with
+ previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
+ This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since
+ data may need to be rewritten to pre-existing output directories during checkpoint recovery.
spark.storage.memoryFraction
@@ -701,6 +779,15 @@ Apart from these, the following properties are also available, and may be useful
increase it if you configure your own old generation size.
+
+
spark.storage.memoryMapThreshold
+
2097152
+
+ Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
+ This prevents Spark from memory mapping very small blocks. In general, memory
+ mapping has high overhead for blocks close to or below the page size of the operating system.
+
+
spark.storage.unrollFraction
0.2
@@ -719,15 +806,6 @@ Apart from these, the following properties are also available, and may be useful
directories on Tachyon file system.
-
-
spark.storage.memoryMapThreshold
-
2097152
-
- Size of a block, in bytes, above which Spark memory maps when reading a block from disk.
- This prevents Spark from memory mapping very small blocks. In general, memory
- mapping has high overhead for blocks close to or below the page size of the operating system.
-
-
spark.tachyonStore.url
tachyon://localhost:19998
@@ -735,106 +813,19 @@ Apart from these, the following properties are also available, and may be useful
The URL of the underlying Tachyon file system in the TachyonStore.
-
-
spark.cleaner.ttl
-
(infinite)
-
- Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks
- generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be
- forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in
- case of Spark Streaming applications). Note that any RDD that persists in memory for more than
- this duration will be cleared as well.
-
-
-
-
spark.hadoop.validateOutputSpecs
-
true
-
If set to true, validates the output specification (e.g. checking if the output directory already exists)
- used in saveAsHadoopFile and other variants. This can be disabled to silence exceptions due to pre-existing
- output directories. We recommend that users do not disable this except if trying to achieve compatibility with
- previous versions of Spark. Simply use Hadoop's FileSystem API to delete output directories by hand.
- This setting is ignored for jobs generated through Spark Streaming's StreamingContext, since
- data may need to be rewritten to pre-existing output directories during checkpoint recovery.
-
-
-
spark.hadoop.cloneConf
-
false
-
If set to true, clones a new Hadoop Configuration object for each task. This
- option should be enabled to work around Configuration thread-safety issues (see
- SPARK-2546 for more details).
- This is disabled by default in order to avoid unexpected performance regressions for jobs that
- are not affected by these issues.
-
-
-
spark.executor.heartbeatInterval
-
10000
-
Interval (milliseconds) between each executor's heartbeats to the driver. Heartbeats let
- the driver know that the executor is still alive and update it with metrics for in-progress
- tasks.
-
#### Networking
Property Name
Default
Meaning
-
spark.driver.host
-
(local hostname)
-
- Hostname or IP address for the driver to listen on.
- This is used for communicating with the executors and the standalone Master.
-
-
-
-
spark.driver.port
-
(random)
-
- Port for the driver to listen on.
- This is used for communicating with the executors and the standalone Master.
-
-
-
-
spark.fileserver.port
-
(random)
-
- Port for the driver's HTTP file server to listen on.
-
-
-
-
spark.broadcast.port
-
(random)
-
- Port for the driver's HTTP broadcast server to listen on.
- This is not relevant for torrent broadcast.
-
-
-
-
spark.replClassServer.port
-
(random)
-
- Port for the driver's HTTP class server to listen on.
- This is only relevant for the Spark shell.
-
-
-
-
spark.blockManager.port
-
(random)
-
- Port for all block managers to listen on. These exist on both the driver and the executors.
-
-
-
-
spark.executor.port
-
(random)
-
- Port for the executor to listen on. This is used for communicating with the driver.
-
-
-
-
spark.port.maxRetries
-
16
+
spark.akka.failure-detector.threshold
+
300.0
- Default maximum number of retries when binding to a port before giving up.
+ This is set to a larger value to disable the failure detector that comes built in to Akka. It can
+ be enabled again, if you plan to use this feature (Not recommended). This maps to Akka's
+ `akka.remote.transport-failure-detector.threshold`. Tune this in combination with
+ `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
@@ -847,181 +838,139 @@ Apart from these, the following properties are also available, and may be useful
-
spark.akka.threads
-
4
-
- Number of actor threads to use for communication. Can be useful to increase on large clusters
- when the driver has a lot of CPU cores.
-
-
-
-
spark.akka.timeout
-
100
-
- Communication timeout between Spark nodes, in seconds.
-
-
-
-
spark.network.timeout
-
120
+
spark.akka.heartbeat.interval
+
1000
- Default timeout for all network interactions, in seconds. This config will be used in
- place of spark.core.connection.ack.wait.timeout, spark.akka.timeout,
- spark.storage.blockManagerSlaveTimeoutMs or
- spark.shuffle.io.connectionTimeout, if they are not configured.
+ This is set to a larger value to disable the transport failure detector that comes built in to
+ Akka. It can be enabled again, if you plan to use this feature (Not recommended). A larger
+ interval value in seconds reduces network overhead and a smaller value ( ~ 1 s) might be more
+ informative for Akka's failure detector. Tune this in combination with `spark.akka.heartbeat.pauses`
+ if you need to. A likely positive use case for using the failure detector would be: a sensitive
+ failure detector can help evict rogue executors quickly. However, this is usually not the case,
+ as GC pauses and network lags are expected in a real Spark cluster. Apart from that, enabling
+ this leads to a lot of heartbeat exchanges between nodes, flooding the network.
spark.akka.heartbeat.pauses
6000
- This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
- enabled again, if you plan to use this feature (Not recommended). Acceptable heart beat pause
- in seconds for akka. This can be used to control sensitivity to gc pauses. Tune this in
- combination of `spark.akka.heartbeat.interval` and `spark.akka.failure-detector.threshold`
- if you need to.
+ This is set to a larger value to disable the transport failure detector that comes built in to Akka.
+ It can be enabled again, if you plan to use this feature (Not recommended). Acceptable heart
+ beat pause in seconds for Akka. This can be used to control sensitivity to GC pauses. Tune
+ this along with `spark.akka.heartbeat.interval` if you need to.
-
spark.akka.failure-detector.threshold
-
300.0
-
- This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
- enabled again, if you plan to use this feature (Not recommended). This maps to akka's
- `akka.remote.transport-failure-detector.threshold`. Tune this in combination of
- `spark.akka.heartbeat.pauses` and `spark.akka.heartbeat.interval` if you need to.
-
-
-
-
spark.akka.heartbeat.interval
-
1000
-
- This is set to a larger value to disable failure detector that comes inbuilt akka. It can be
- enabled again, if you plan to use this feature (Not recommended). A larger interval value in
- seconds reduces network overhead and a smaller value ( ~ 1 s) might be more informative for
- akka's failure detector. Tune this in combination of `spark.akka.heartbeat.pauses` and
- `spark.akka.failure-detector.threshold` if you need to. Only positive use case for using
- failure detector can be, a sensistive failure detector can help evict rogue executors really
- quick. However this is usually not the case as gc pauses and network lags are expected in a
- real Spark cluster. Apart from that enabling this leads to a lot of exchanges of heart beats
- between nodes leading to flooding the network with those.
-
-
-
-
spark.shuffle.io.preferDirectBufs
-
true
+
spark.akka.threads
+
4
- (Netty only) Off-heap buffers are used to reduce garbage collection during shuffle and cache
- block transfer. For environments where off-heap memory is tightly limited, users may wish to
- turn this off to force all allocations from Netty to be on-heap.
+ Number of actor threads to use for communication. Can be useful to increase on large clusters
+ when the driver has a lot of CPU cores.
-
spark.shuffle.io.numConnectionsPerPeer
-
1
+
spark.akka.timeout
+
100
- (Netty only) Connections between hosts are reused in order to reduce connection buildup for
- large clusters. For clusters with many hard disks and few hosts, this may result in insufficient
- concurrency to saturate all disks, and so users may consider increasing this value.
+ Communication timeout between Spark nodes, in seconds.
-
spark.shuffle.io.maxRetries
-
3
+
spark.blockManager.port
+
(random)
- (Netty only) Fetches that fail due to IO-related exceptions are automatically retried if this is
- set to a non-zero value. This retry logic helps stabilize large shuffles in the face of long GC
- pauses or transient network connectivity issues.
+ Port for all block managers to listen on. These exist on both the driver and the executors.
-
spark.shuffle.io.retryWait
-
5
+
spark.broadcast.port
+
(random)
- (Netty only) Seconds to wait between retries of fetches. The maximum delay caused by retrying
- is simply maxRetries * retryWait, by default 15 seconds.
+ Port for the driver's HTTP broadcast server to listen on.
+ This is not relevant for torrent broadcast.
-
-
-#### Scheduling
-
-
Property Name
Default
Meaning
-
spark.task.cpus
-
1
+
spark.driver.host
+
(local hostname)
- Number of cores to allocate for each task.
+ Hostname or IP address for the driver to listen on.
+ This is used for communicating with the executors and the standalone Master.
-
spark.task.maxFailures
-
4
+
spark.driver.port
+
(random)
- Number of individual task failures before giving up on the job.
- Should be greater than or equal to 1. Number of allowed retries = this value - 1.
+ Port for the driver to listen on.
+ This is used for communicating with the executors and the standalone Master.
-
spark.scheduler.mode
-
FIFO
+
spark.executor.port
+
(random)
- The scheduling mode between
- jobs submitted to the same SparkContext. Can be set to FAIR
- to use fair sharing instead of queueing jobs one after another. Useful for
- multi-user services.
+ Port for the executor to listen on. This is used for communicating with the driver.
-
spark.cores.max
-
(not set)
-
- When running on a standalone deploy cluster or a
- Mesos cluster in "coarse-grained"
- sharing mode, the maximum amount of CPU cores to request for the application from
- across the cluster (not from each machine). If not set, the default will be
- spark.deploy.defaultCores on Spark's standalone cluster manager, or
- infinite (all available cores) on Mesos.
+
spark.fileserver.port
+
(random)
+
+ Port for the driver's HTTP file server to listen on.
-
spark.mesos.coarse
-
false
+
spark.network.timeout
+
120
- If set to "true", runs over Mesos clusters in
- "coarse-grained" sharing mode,
- where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per
- Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use
- for the whole duration of the Spark job.
+ Default timeout for all network interactions, in seconds. This config will be used in
+ place of spark.core.connection.ack.wait.timeout, spark.akka.timeout,
+ spark.storage.blockManagerSlaveTimeoutMs or
+ spark.shuffle.io.connectionTimeout, if they are not configured.
-
spark.speculation
-
false
+
spark.port.maxRetries
+
16
- If set to "true", performs speculative execution of tasks. This means if one or more tasks are
- running slowly in a stage, they will be re-launched.
+ Default maximum number of retries when binding to a port before giving up.
-
spark.speculation.interval
-
100
+
spark.replClassServer.port
+
(random)
- How often Spark will check for tasks to speculate, in milliseconds.
+ Port for the driver's HTTP class server to listen on.
+ This is only relevant for the Spark shell.
+
+
+#### Scheduling
+
+
Property Name
Default
Meaning
-
spark.speculation.quantile
-
0.75
+
spark.cores.max
+
(not set)
- Percentage of tasks which must be complete before speculation is enabled for a particular stage.
+ When running on a standalone deploy cluster or a
+ Mesos cluster in "coarse-grained"
+ sharing mode, the maximum amount of CPU cores to request for the application from
+ across the cluster (not from each machine). If not set, the default will be
+ spark.deploy.defaultCores on Spark's standalone cluster manager, or
+ infinite (all available cores) on Mesos.
-
spark.speculation.multiplier
-
1.5
+
spark.localExecution.enabled
+
false
- How many times slower a task is than the median to be considered for speculation.
+ Enables Spark to run certain jobs, such as first() or take() on the driver, without sending
+ tasks to the cluster. This can make certain jobs execute very quickly, but may require
+ shipping a whole partition of data to the driver.
@@ -1037,19 +986,19 @@ Apart from these, the following properties are also available, and may be useful
-
spark.locality.wait.process
+
spark.locality.wait.node
spark.locality.wait
- Customize the locality wait for process locality. This affects tasks that attempt to access
- cached data in a particular executor process.
+ Customize the locality wait for node locality. For example, you can set this to 0 to skip
+ node locality and search immediately for rack locality (if your cluster has rack information).
-
spark.locality.wait.node
+
spark.locality.wait.process
spark.locality.wait
- Customize the locality wait for node locality. For example, you can set this to 0 to skip
- node locality and search immediately for rack locality (if your cluster has rack information).
+ Customize the locality wait for process locality. This affects tasks that attempt to access
+ cached data in a particular executor process.
@@ -1060,16 +1009,16 @@ Apart from these, the following properties are also available, and may be useful
-
spark.scheduler.revive.interval
-
1000
+
spark.scheduler.maxRegisteredResourcesWaitingTime
+
30000
- The interval length for the scheduler to revive the worker resource offers to run tasks
+ Maximum amount of time to wait for resources to register before scheduling begins
(in milliseconds).
-
+
spark.scheduler.minRegisteredResourcesRatio
-
0.0 for Mesos and Standalone mode, 0.8 for YARN
+
0.8 for YARN mode; 0.0 otherwise
The minimum ratio of registered resources (registered resources / total expected resources)
(resources are executors in yarn mode, CPU cores in standalone mode)
@@ -1080,25 +1029,70 @@ Apart from these, the following properties are also available, and may be useful
-
spark.scheduler.maxRegisteredResourcesWaitingTime
-
30000
+
spark.scheduler.mode
+
FIFO
- Maximum amount of time to wait for resources to register before scheduling begins
+ The scheduling mode between
+ jobs submitted to the same SparkContext. Can be set to FAIR
+ to use fair sharing instead of queueing jobs one after another. Useful for
+ multi-user services.
+
+
+
+
spark.scheduler.revive.interval
+
1000
+
+ The interval length for the scheduler to revive the worker resource offers to run tasks
(in milliseconds).
-
spark.localExecution.enabled
+
spark.speculation
false
- Enables Spark to run certain jobs, such as first() or take() on the driver, without sending
- tasks to the cluster. This can make certain jobs execute very quickly, but may require
- shipping a whole partition of data to the driver.
+ If set to "true", performs speculative execution of tasks. This means if one or more tasks are
+ running slowly in a stage, they will be re-launched.
+
+
+
+
spark.speculation.interval
+
100
+
+ How often Spark will check for tasks to speculate, in milliseconds.
+
+
+
+
spark.speculation.multiplier
+
1.5
+
+ How many times slower a task is than the median to be considered for speculation.
+
+
+
+
spark.speculation.quantile
+
0.75
+
+ Percentage of tasks which must be complete before speculation is enabled for a particular stage.
+
+
+
+
spark.task.cpus
+
1
+
+ Number of cores to allocate for each task.
+
+
+
+
spark.task.maxFailures
+
4
+
+ Number of individual task failures before giving up on the job.
+ Should be greater than or equal to 1. Number of allowed retries = this value - 1.
-#### Dynamic allocation
+#### Dynamic Allocation
Property Name
Default
Meaning
@@ -1118,10 +1112,19 @@ Apart from these, the following properties are also available, and may be useful
+
spark.dynamicAllocation.executorIdleTimeout
+
600
+
+ If dynamic allocation is enabled and an executor has been idle for more than this duration
+ (in seconds), the executor will be removed. For more detail, see this
+ description.
+
+
+
+
spark.dynamicAllocation.initialExecutors
spark.dynamicAllocation.minExecutors
-
0
- Lower bound for the number of executors if dynamic allocation is enabled.
+ Initial number of executors to run if dynamic allocation is enabled.
@@ -1132,10 +1135,10 @@ Apart from these, the following properties are also available, and may be useful
-
spark.dynamicAllocation.maxExecutors
spark.dynamicAllocation.minExecutors
+
0
- Initial number of executors to run if dynamic allocation is enabled.
+ Lower bound for the number of executors if dynamic allocation is enabled.
@@ -1156,20 +1159,30 @@ Apart from these, the following properties are also available, and may be useful
description.
-
-
spark.dynamicAllocation.executorIdleTimeout
-
600
-
- If dynamic allocation is enabled and an executor has been idle for more than this duration
- (in seconds), the executor will be removed. For more detail, see this
- description.
-
-
#### Security
Property Name
Default
Meaning
+
+
spark.acls.enable
+
false
+
+ Whether Spark acls are enabled. If enabled, this checks to see if the user has
+ access permissions to view or modify the job. Note this requires the user to be known,
+ so if the user comes across as null no checks are done. Filters can be used with the UI
+ to authenticate and set the user.
+
+
+
+
spark.admin.acls
+
Empty
+
+ Comma separated list of users/administrators that have view and modify access to all Spark jobs.
+ This can be used if you run on a shared cluster and have a set of administrators or devs who
+ help debug when things do not work.
+
+
spark.authenticate
false
@@ -1186,6 +1199,15 @@ Apart from these, the following properties are also available, and may be useful
not running on YARN and authentication is enabled.
+
+
spark.core.connection.ack.wait.timeout
+
60
+
+ Number of seconds for the connection to wait for an ack to occur before timing
+ out and giving up. To avoid unwanted timeouts caused by long pauses such as GC,
+ you can set a larger value.
+
+
spark.core.connection.auth.wait.timeout
30
@@ -1195,12 +1217,11 @@ Apart from these, the following properties are also available, and may be useful
-
spark.core.connection.ack.wait.timeout
-
60
+
spark.modify.acls
+
Empty
- Number of seconds for the connection to wait for ack to occur before timing
- out and giving up. To avoid unwilling timeout caused by long pause like GC,
- you can set larger value.
+ Comma separated list of users that have modify access to the Spark job. By default only the
+ user that started the Spark job has access to modify it (kill it for example).
@@ -1217,16 +1238,6 @@ Apart from these, the following properties are also available, and may be useful
-Dspark.com.test.filter1.params='param1=foo,param2=testing'
-
-
spark.acls.enable
-
false
-
- Whether Spark acls should are enabled. If enabled, this checks to see if the user has
- access permissions to view or modify the job. Note this requires the user to be known,
- so if the user comes across as null no checks are done. Filters can be used with the UI
- to authenticate and set the user.
-
-
spark.ui.view.acls
Empty
@@ -1235,23 +1246,6 @@ Apart from these, the following properties are also available, and may be useful
user that started the Spark job has view access.
-
-
spark.modify.acls
-
Empty
-
- Comma separated list of users that have modify access to the Spark job. By default only the
- user that started the Spark job has access to modify it (kill it for example).
-
-
-
-
spark.admin.acls
-
Empty
-
- Comma separated list of users/administrators that have view and modify access to all Spark jobs.
- This can be used if you run on a shared cluster and have a set of administrators or devs who
- help debug when things work.
-
-
#### Encryption
@@ -1275,6 +1269,23 @@ Apart from these, the following properties are also available, and may be useful
file server.
+
+
spark.ssl.enabledAlgorithms
+
Empty
+
+ A comma-separated list of ciphers. The specified ciphers must be supported by the JVM.
+ The reference list of protocols can be found on
+ this
+ page.
+
+
+
+
spark.ssl.keyPassword
+
None
+
+ A password to the private key in key-store.
+
+
spark.ssl.keyStore
None
@@ -1291,10 +1302,12 @@ Apart from these, the following properties are also available, and may be useful
-
spark.ssl.keyPassword
+
spark.ssl.protocol
None
- A password to the private key in key-store.
+ A protocol name. The protocol must be supported by the JVM. The reference list of protocols
+ can be found on this
+ page.
@@ -1312,25 +1325,6 @@ Apart from these, the following properties are also available, and may be useful
A password to the trust-store.
-
-
spark.ssl.protocol
-
None
-
- A protocol name. The protocol must be supported by JVM. The reference list of protocols
- one can find on this
- page.
-
-
-
-
spark.ssl.enabledAlgorithms
-
Empty
-
- A comma separated list of ciphers. The specified ciphers must be supported by JVM.
- The reference list of protocols one can find on
- this
- page.
-
-
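
One of the properties documented above, `spark.extraListeners`, instantiates each listed class through a single-argument `SparkConf` constructor when one exists, falling back to a zero-argument constructor. A sketch of a listener that satisfies that contract; the class name and config lookup are illustrative:

```scala
import org.apache.spark.SparkConf
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}

// Sketch of a listener compatible with the spark.extraListeners contract described above.
// The class name and config lookup are illustrative.
class AppEndLogger(conf: SparkConf) extends SparkListener {

  // Zero-argument fallback; the SparkConf constructor is preferred when both exist.
  def this() = this(new SparkConf())

  override def onApplicationEnd(end: SparkListenerApplicationEnd): Unit = {
    println(s"Application ${conf.get("spark.app.name", "unknown")} ended at ${end.time}")
  }
}

// Registered through configuration rather than code, e.g. in spark-defaults.conf:
//   spark.extraListeners   com.example.AppEndLogger
```
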
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index 826f6d8f371c7..28bdf81ca0ca5 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -538,7 +538,7 @@ val joinedGraph = graph.joinVertices(uniqueCosts,
## Neighborhood Aggregation
-A key step in may graph analytics tasks is aggregating information about the neighborhood of each
+A key step in many graph analytics tasks is aggregating information about the neighborhood of each
vertex.
For example, we might want to know the number of followers each user has or the average age of the
followers of each user. Many iterative graph algorithms (e.g., PageRank, Shortest Path, and
@@ -634,7 +634,7 @@ avgAgeOfOlderFollowers.collect.foreach(println(_))
### Map Reduce Triplets Transition Guide (Legacy)
-In earlier versions of GraphX we neighborhood aggregation was accomplished using the
+In earlier versions of GraphX neighborhood aggregation was accomplished using the
[`mapReduceTriplets`][Graph.mapReduceTriplets] operator:
{% highlight scala %}
@@ -682,8 +682,8 @@ val result = graph.aggregateMessages[String](msgFun, reduceFun)
### Computing Degree Information
A common aggregation task is computing the degree of each vertex: the number of edges adjacent to
-each vertex. In the context of directed graphs it often necessary to know the in-degree, out-
-degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a
+each vertex. In the context of directed graphs it is often necessary to know the in-degree,
+out-degree, and the total degree of each vertex. The [`GraphOps`][GraphOps] class contains a
collection of operators to compute the degrees of each vertex. For example in the following we
compute the max in, out, and total degrees:
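
The degree operators mentioned in the corrected paragraph live on `GraphOps` and return `VertexRDD[Int]`s. A short sketch of the max-degree computation the guide describes, assuming a `graph: Graph[_, _]` is already built:

```scala
import org.apache.spark.graphx.{Graph, VertexId}

// Sketch of the degree computation the guide describes, assuming a graph is already built.
// inDegrees, outDegrees and degrees come from GraphOps and are VertexRDD[Int]s.
def maxDegrees(graph: Graph[_, _]): ((VertexId, Int), (VertexId, Int), (VertexId, Int)) = {
  // Keep whichever vertex has the higher degree.
  def max(a: (VertexId, Int), b: (VertexId, Int)): (VertexId, Int) =
    if (a._2 > b._2) a else b

  (graph.inDegrees.reduce(max), graph.outDegrees.reduce(max), graph.degrees.reduce(max))
}
```
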
diff --git a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png
deleted file mode 100644
index ed9adad11d03a..0000000000000
Binary files a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png and /dev/null differ
diff --git a/docs/ml-guide.md b/docs/ml-guide.md
index be178d7689fdd..da6aef7f14c4c 100644
--- a/docs/ml-guide.md
+++ b/docs/ml-guide.md
@@ -23,13 +23,13 @@ to `spark.ml`.
Spark ML standardizes APIs for machine learning algorithms to make it easier to combine multiple algorithms into a single pipeline, or workflow. This section covers the key concepts introduced by the Spark ML API.
-* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL as a dataset which can hold a variety of data types.
+* **[ML Dataset](ml-guide.html#ml-dataset)**: Spark ML uses the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL as a dataset which can hold a variety of data types.
E.g., a dataset could have different columns storing text, feature vectors, true labels, and predictions.
-* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `SchemaRDD` into another `SchemaRDD`.
+* **[`Transformer`](ml-guide.html#transformers)**: A `Transformer` is an algorithm which can transform one `DataFrame` into another `DataFrame`.
E.g., an ML model is a `Transformer` which transforms an RDD with features into an RDD with predictions.
-* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `SchemaRDD` to produce a `Transformer`.
+* **[`Estimator`](ml-guide.html#estimators)**: An `Estimator` is an algorithm which can be fit on a `DataFrame` to produce a `Transformer`.
E.g., a learning algorithm is an `Estimator` which trains on a dataset and produces a model.
* **[`Pipeline`](ml-guide.html#pipeline)**: A `Pipeline` chains multiple `Transformer`s and `Estimator`s together to specify an ML workflow.
@@ -39,20 +39,20 @@ E.g., a learning algorithm is an `Estimator` which trains on a dataset and produ
## ML Dataset
Machine learning can be applied to a wide variety of data types, such as vectors, text, images, and structured data.
-Spark ML adopts the [`SchemaRDD`](api/scala/index.html#org.apache.spark.sql.SchemaRDD) from Spark SQL in order to support a variety of data types under a unified Dataset concept.
+Spark ML adopts the [`DataFrame`](api/scala/index.html#org.apache.spark.sql.DataFrame) from Spark SQL in order to support a variety of data types under a unified Dataset concept.
-`SchemaRDD` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types.
-In addition to the types listed in the Spark SQL guide, `SchemaRDD` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types.
+`DataFrame` supports many basic and structured types; see the [Spark SQL datatype reference](sql-programming-guide.html#spark-sql-datatype-reference) for a list of supported types.
+In addition to the types listed in the Spark SQL guide, `DataFrame` can use ML [`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) types.
-A `SchemaRDD` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples.
+A `DataFrame` can be created either implicitly or explicitly from a regular `RDD`. See the code examples below and the [Spark SQL programming guide](sql-programming-guide.html) for examples.
-Columns in a `SchemaRDD` are named. The code examples below use names such as "text," "features," and "label."
+Columns in a `DataFrame` are named. The code examples below use names such as "text," "features," and "label."
## ML Algorithms
### Transformers
-A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `SchemaRDD` into another, generally by appending one or more columns.
+A [`Transformer`](api/scala/index.html#org.apache.spark.ml.Transformer) is an abstraction which includes feature transformers and learned models. Technically, a `Transformer` implements a method `transform()` which converts one `DataFrame` into another, generally by appending one or more columns.
For example:
* A feature transformer might take a dataset, read a column (e.g., text), convert it into a new column (e.g., feature vectors), append the new column to the dataset, and output the updated dataset.
@@ -60,7 +60,7 @@ For example:
### Estimators
-An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `SchemaRDD` and produces a `Transformer`.
+An [`Estimator`](api/scala/index.html#org.apache.spark.ml.Estimator) abstracts the concept of a learning algorithm or any algorithm which fits or trains on data. Technically, an `Estimator` implements a method `fit()` which accepts a `DataFrame` and produces a `Transformer`.
For example, a learning algorithm such as `LogisticRegression` is an `Estimator`, and calling `fit()` trains a `LogisticRegressionModel`, which is a `Transformer`.
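+
+As a minimal sketch (assuming an existing `DataFrame` named `training` with "label" and "features" columns, and an unlabeled `DataFrame` named `test`), an `Estimator` and the `Transformer` it produces are used as follows:
+
+{% highlight scala %}
+import org.apache.spark.ml.classification.LogisticRegression
+
+val lr = new LogisticRegression()        // an Estimator
+val model = lr.fit(training)             // fit() produces a Transformer (a LogisticRegressionModel)
+val predictions = model.transform(test)  // transform() appends prediction columns
+{% endhighlight %}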
### Properties of ML Algorithms
@@ -101,7 +101,7 @@ We illustrate this for the simple text document workflow. The figure below is f
Above, the top row represents a `Pipeline` with three stages.
The first two (`Tokenizer` and `HashingTF`) are `Transformer`s (blue), and the third (`LogisticRegression`) is an `Estimator` (red).
-The bottom row represents data flowing through the pipeline, where cylinders indicate `SchemaRDD`s.
+The bottom row represents data flowing through the pipeline, where cylinders indicate `DataFrame`s.
The `Pipeline.fit()` method is called on the original dataset which has raw text documents and labels.
The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words into the dataset.
The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the dataset.
@@ -130,7 +130,7 @@ Each stage's `transform()` method updates the dataset and passes it to the next
*DAG `Pipeline`s*: A `Pipeline`'s stages are specified as an ordered array. The examples given here are all for linear `Pipeline`s, i.e., `Pipeline`s in which each stage uses data produced by the previous stage. It is possible to create non-linear `Pipeline`s as long as the data flow graph forms a Directed Acyclic Graph (DAG). This graph is currently specified implicitly based on the input and output column names of each stage (generally specified as parameters). If the `Pipeline` forms a DAG, then the stages must be specified in topological order.
-*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `SchemaRDD`.
+*Runtime checking*: Since `Pipeline`s can operate on datasets with varied types, they cannot use compile-time type checking. `Pipeline`s and `PipelineModel`s instead do runtime checking before actually running the `Pipeline`. This type checking is done using the dataset *schema*, a description of the data types of columns in the `DataFrame`.
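+
+For example, the schema used by this runtime check can be inspected directly (a minimal sketch, assuming an existing `DataFrame` named `training`):
+
+{% highlight scala %}
+// Prints the column names and types of the dataset, i.e., the schema the stages are checked against.
+training.printSchema()
+{% endhighlight %}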
## Parameters
@@ -171,12 +171,12 @@ import org.apache.spark.sql.{Row, SQLContext}
val conf = new SparkConf().setAppName("SimpleParamsExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
-import sqlContext._
+import sqlContext.implicits._
// Prepare training data.
// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
-// into SchemaRDDs, where it uses the case class metadata to infer the schema.
-val training = sparkContext.parallelize(Seq(
+// into DataFrames, where it uses the case class metadata to infer the schema.
+val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
@@ -192,7 +192,7 @@ lr.setMaxIter(10)
.setRegParam(0.01)
// Learn a LogisticRegression model. This uses the parameters stored in lr.
-val model1 = lr.fit(training)
+val model1 = lr.fit(training.toDF)
// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
@@ -203,33 +203,35 @@ println("Model 1 was fit using parameters: " + model1.fittingParamMap)
// which supports several methods for specifying parameters.
val paramMap = ParamMap(lr.maxIter -> 20)
paramMap.put(lr.maxIter, 30) // Specify 1 Param. This overwrites the original maxIter.
-paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.5) // Specify multiple Params.
+paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
// One can also combine ParamMaps.
-val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Changes output column name.
+val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
-val model2 = lr.fit(training, paramMapCombined)
+val model2 = lr.fit(training.toDF, paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
-// Prepare test documents.
-val test = sparkContext.parallelize(Seq(
+// Prepare test data.
+val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
-// Make predictions on test documents using the Transformer.transform() method.
+// Make predictions on test data using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
-// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-// column since we renamed the lr.scoreCol parameter previously.
-model2.transform(test)
- .select('features, 'label, 'probability, 'prediction)
+// Note that model2.transform() outputs a 'myProbability' column instead of the usual
+// 'probability' column since we renamed the lr.probabilityCol parameter previously.
+model2.transform(test.toDF)
+ .select("features", "label", "myProbability", "prediction")
.collect()
- .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
- println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+ .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
+ println("($features, $label) -> prob=$prob, prediction=$prediction")
}
+
+sc.stop()
{% endhighlight %}
@@ -244,23 +246,23 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Row;
SparkConf conf = new SparkConf().setAppName("JavaSimpleParamsExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
-JavaSQLContext jsql = new JavaSQLContext(jsc);
+SQLContext jsql = new SQLContext(jsc);
// Prepare training data.
-// We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
-// into SchemaRDDs, where it uses the case class metadata to infer the schema.
+// We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans
+// into DataFrames, where it uses the bean metadata to infer the schema.
List<LabeledPoint> localTraining = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
-JavaSchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -281,13 +283,13 @@ System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
-paramMap.put(lr.maxIter(), 20); // Specify 1 Param.
+paramMap.put(lr.maxIter().w(20)); // Specify 1 Param.
paramMap.put(lr.maxIter(), 30); // This overwrites the original maxIter.
-paramMap.put(lr.regParam(), 0.1);
+paramMap.put(lr.regParam().w(0.1), lr.threshold().w(0.55)); // Specify multiple Params.
// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
-paramMap2.put(lr.scoreCol(), "probability"); // Changes output column name.
+paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
@@ -300,19 +302,19 @@ List localTest = Lists.newArrayList(
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
-// Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
-// column since we renamed the lr.scoreCol parameter previously.
-model2.transform(test).registerAsTable("results");
-JavaSchemaRDD results =
- jsql.sql("SELECT features, label, probability, prediction FROM results");
-for (Row r: results.collect()) {
+// Note that model2.transform() outputs a 'myProbability' column instead of the usual
+// 'probability' column since we renamed the lr.probabilityCol parameter previously.
+DataFrame results = model2.transform(test);
+for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+jsc.stop();
{% endhighlight %}
@@ -330,6 +332,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
// Labeled and unlabeled instance types.
@@ -337,14 +340,14 @@ import org.apache.spark.sql.{Row, SQLContext}
case class LabeledDocument(id: Long, text: String, label: Double)
case class Document(id: Long, text: String)
-// Set up contexts. Import implicit conversions to SchemaRDD from sqlContext.
+// Set up contexts. Import implicit conversions to DataFrame from sqlContext.
val conf = new SparkConf().setAppName("SimpleTextClassificationPipeline")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
-import sqlContext._
+import sqlContext.implicits._
// Prepare training documents, which are labeled.
-val training = sparkContext.parallelize(Seq(
+val training = sc.parallelize(Seq(
LabeledDocument(0L, "a b c d e spark", 1.0),
LabeledDocument(1L, "b d", 0.0),
LabeledDocument(2L, "spark f g h", 1.0),
@@ -365,30 +368,32 @@ val pipeline = new Pipeline()
.setStages(Array(tokenizer, hashingTF, lr))
// Fit the pipeline to training documents.
-val model = pipeline.fit(training)
+val model = pipeline.fit(training.toDF)
// Prepare test documents, which are unlabeled.
-val test = sparkContext.parallelize(Seq(
+val test = sc.parallelize(Seq(
Document(4L, "spark i j k"),
Document(5L, "l m n"),
Document(6L, "mapreduce spark"),
Document(7L, "apache hadoop")))
// Make predictions on test documents.
-model.transform(test)
- .select('id, 'text, 'score, 'prediction)
+model.transform(test.toDF)
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println("($id, $text) --> prob=$prob, prediction=$prediction")
}
+
+sc.stop()
{% endhighlight %}
{% highlight java %}
-import java.io.Serializable;
import java.util.List;
import com.google.common.collect.Lists;
+import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineModel;
@@ -396,10 +401,9 @@ import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
-import org.apache.spark.SparkConf;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
// Labeled and unlabeled instance types.
// Spark SQL can infer schema from Java Beans.
@@ -434,7 +438,7 @@ public class LabeledDocument extends Document implements Serializable {
// Set up contexts.
SparkConf conf = new SparkConf().setAppName("JavaSimpleTextClassificationPipeline");
JavaSparkContext jsc = new JavaSparkContext(conf);
-JavaSQLContext jsql = new JavaSQLContext(jsc);
+SQLContext jsql = new SQLContext(jsc);
// Prepare training documents, which are labeled.
List<LabeledDocument> localTraining = Lists.newArrayList(
@@ -442,8 +446,7 @@ List localTraining = Lists.newArrayList(
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
-JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -468,16 +471,62 @@ List localTest = Lists.newArrayList(
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
-JavaSchemaRDD test =
- jsql.applySchema(jsc.parallelize(localTest), Document.class);
+DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
-model.transform(test).registerAsTable("prediction");
-JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
-for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+DataFrame predictions = model.transform(test);
+for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+jsc.stop();
+{% endhighlight %}
+
+
+
+{% highlight python %}
+from pyspark import SparkContext
+from pyspark.ml import Pipeline
+from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import Row, SQLContext
+
+sc = SparkContext(appName="SimpleTextClassificationPipeline")
+sqlCtx = SQLContext(sc)
+
+# Prepare training documents, which are labeled.
+LabeledDocument = Row("id", "text", "label")
+training = sc.parallelize([(0L, "a b c d e spark", 1.0),
+ (1L, "b d", 0.0),
+ (2L, "spark f g h", 1.0),
+ (3L, "hadoop mapreduce", 0.0)]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
+
+# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
+tokenizer = Tokenizer(inputCol="text", outputCol="words")
+hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+lr = LogisticRegression(maxIter=10, regParam=0.01)
+pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
+
+# Fit the pipeline to training documents.
+model = pipeline.fit(training)
+
+# Prepare test documents, which are unlabeled.
+Document = Row("id", "text")
+test = sc.parallelize([(4L, "spark i j k"),
+ (5L, "l m n"),
+ (6L, "mapreduce spark"),
+ (7L, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
+
+# Make predictions on test documents and print columns of interest.
+prediction = model.transform(test)
+selected = prediction.select("id", "text", "prediction")
+for row in selected.collect():
+ print row
+
+sc.stop()
{% endhighlight %}
@@ -508,21 +557,21 @@ However, it is also a well-established method for choosing parameters which is m
{% highlight scala %}
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
val conf = new SparkConf().setAppName("CrossValidatorExample")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
-import sqlContext._
+import sqlContext.implicits._
// Prepare training documents, which are labeled.
-val training = sparkContext.parallelize(Seq(
+val training = sc.parallelize(Seq(
LabeledDocument(0L, "a b c d e spark", 1.0),
LabeledDocument(1L, "b d", 0.0),
LabeledDocument(2L, "spark f g h", 1.0),
@@ -565,24 +614,24 @@ crossval.setEstimatorParamMaps(paramGrid)
crossval.setNumFolds(2) // Use 3+ in practice
// Run cross-validation, and choose the best set of parameters.
-val cvModel = crossval.fit(training)
-// Get the best LogisticRegression model (with the best set of parameters from paramGrid).
-val lrModel = cvModel.bestModel
+val cvModel = crossval.fit(training.toDF)
// Prepare test documents, which are unlabeled.
-val test = sparkContext.parallelize(Seq(
+val test = sc.parallelize(Seq(
Document(4L, "spark i j k"),
Document(5L, "l m n"),
Document(6L, "mapreduce spark"),
Document(7L, "apache hadoop")))
// Make predictions on test documents. cvModel uses the best model found (lrModel).
-cvModel.transform(test)
- .select('id, 'text, 'score, 'prediction)
+cvModel.transform(test.toDF)
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
+
+sc.stop()
{% endhighlight %}
@@ -592,7 +641,6 @@ import java.util.List;
import com.google.common.collect.Lists;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.ml.Model;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
@@ -603,13 +651,13 @@ import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import org.apache.spark.sql.api.java.Row;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
SparkConf conf = new SparkConf().setAppName("JavaCrossValidatorExample");
JavaSparkContext jsc = new JavaSparkContext(conf);
-JavaSQLContext jsql = new JavaSQLContext(jsc);
+SQLContext jsql = new SQLContext(jsc);
// Prepare training documents, which are labeled.
List<LabeledDocument> localTraining = Lists.newArrayList(
@@ -625,8 +673,7 @@ List localTraining = Lists.newArrayList(
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
-JavaSchemaRDD training =
- jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -660,8 +707,6 @@ crossval.setNumFolds(2); // Use 3+ in practice
// Run cross-validation, and choose the best set of parameters.
CrossValidatorModel cvModel = crossval.fit(training);
-// Get the best LogisticRegression model (with the best set of parameters from paramGrid).
-Model lrModel = cvModel.bestModel();
// Prepare test documents, which are unlabeled.
List<Document> localTest = Lists.newArrayList(
@@ -669,15 +714,16 @@ List localTest = Lists.newArrayList(
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
-JavaSchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
-cvModel.transform(test).registerAsTable("prediction");
-JavaSchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
-for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+DataFrame predictions = cvModel.transform(test);
+for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+jsc.stop();
{% endhighlight %}
@@ -686,6 +732,21 @@ for (Row r: predictions.collect()) {
# Dependencies
Spark ML currently depends on MLlib and has the same dependencies.
-Please see the [MLlib Dependencies guide](mllib-guide.html#Dependencies) for more info.
+Please see the [MLlib Dependencies guide](mllib-guide.html#dependencies) for more info.
Spark ML also depends upon Spark SQL, but the relevant parts of Spark SQL do not bring additional dependencies.
+
+# Migration Guide
+
+## From 1.2 to 1.3
+
+The main API changes are from Spark SQL. We list the most important changes here:
+
+* The old [SchemaRDD](http://spark.apache.org/docs/1.2.1/api/scala/index.html#org.apache.spark.sql.SchemaRDD) has been replaced with [DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame) with a somewhat modified API. All algorithms in Spark ML which used to use SchemaRDD now use DataFrame.
+* In Spark 1.2, we used implicit conversions from `RDD`s of `LabeledPoint` into `SchemaRDD`s by calling `import sqlContext._` where `sqlContext` was an instance of `SQLContext`. These implicits have been moved, so we now call `import sqlContext.implicits._` (see the sketch after this list).
+* Java APIs for SQL have also changed accordingly. Please see the examples above and the [Spark SQL Programming Guide](sql-programming-guide.html) for details.
+
+Other changes were in `LogisticRegression`:
+
+* The `scoreCol` output column (with default value "score") was renamed to be `probabilityCol` (with default value "probability"). The type was originally `Double` (for the probability of class 1.0), but it is now `Vector` (for the probability of each class, to support multiclass classification in the future).
+* In Spark 1.2, `LogisticRegressionModel` did not include an intercept. In Spark 1.3, it includes an intercept; however, it will always be 0.0 since it uses the default settings for [spark.mllib.LogisticRegressionWithLBFGS](api/scala/index.html#org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS). The option to use an intercept will be added in the future.
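+
+A minimal sketch of the 1.3-style conversion (assuming `sqlContext` is an existing `SQLContext` and `rdd` is a hypothetical `RDD` of case-class instances such as `LabeledPoint`):
+
+{% highlight scala %}
+// Spark 1.2: `import sqlContext._` converted RDDs to SchemaRDDs implicitly.
+// Spark 1.3: import the implicits explicitly and convert with toDF().
+import sqlContext.implicits._
+val df = rdd.toDF()
+{% endhighlight %}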
diff --git a/docs/mllib-classification-regression.md b/docs/mllib-classification-regression.md
index 719cc95767b00..8e91d62f4a907 100644
--- a/docs/mllib-classification-regression.md
+++ b/docs/mllib-classification-regression.md
@@ -17,13 +17,13 @@ the supported algorithms for each type of problem.
-      <td>Binary Classification</td><td>linear SVMs, logistic regression, decision trees, naive Bayes</td>
+      <td>Binary Classification</td><td>linear SVMs, logistic regression, decision trees, random forests, gradient-boosted trees, naive Bayes</td>
-      <td>Multiclass Classification</td><td>decision trees, naive Bayes</td>
+      <td>Multiclass Classification</td><td>decision trees, random forests, naive Bayes</td>
-      <td>Regression</td><td>linear least squares, Lasso, ridge regression, decision trees</td>
+      <td>Regression</td><td>linear least squares, Lasso, ridge regression, decision trees, random forests, gradient-boosted trees, isotonic regression</td>
@@ -34,4 +34,8 @@ More details for these methods can be found here:
* [binary classification (SVMs, logistic regression)](mllib-linear-methods.html#binary-classification)
* [linear regression (least squares, Lasso, ridge)](mllib-linear-methods.html#linear-least-squares-lasso-and-ridge-regression)
* [Decision trees](mllib-decision-tree.html)
+* [Ensembles of decision trees](mllib-ensembles.html)
+ * [random forests](mllib-ensembles.html#random-forests)
+ * [gradient-boosted trees](mllib-ensembles.html#gradient-boosted-trees-gbts)
* [Naive Bayes](mllib-naive-bayes.html)
+* [Isotonic regression](mllib-isotonic-regression.html)
diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md
index 99ed6b60e3f00..0b6db4fcb7b1f 100644
--- a/docs/mllib-clustering.md
+++ b/docs/mllib-clustering.md
@@ -4,28 +4,25 @@ title: Clustering - MLlib
displayTitle: MLlib - Clustering
---
-* Table of contents
-{:toc}
-
-
-## Clustering
-
Clustering is an unsupervised learning problem whereby we aim to group subsets
of entities with one another based on some notion of similarity. Clustering is
often used for exploratory analysis and/or as a component of a hierarchical
supervised learning pipeline (in which distinct classifiers or regression
-models are trained for each cluster).
+models are trained for each cluster).
MLlib supports the following models:
-### k-means
+* Table of contents
+{:toc}
+
+## K-means
[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the
most commonly used clustering algorithms that clusters the data points into a
predefined number of clusters. The MLlib implementation includes a parallelized
variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method
called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf).
-The implementation in MLlib has the following parameters:
+The implementation in MLlib has the following parameters:
* *k* is the number of desired clusters.
* *maxIterations* is the maximum number of iterations to run.
@@ -35,74 +32,9 @@ initialization via k-means\|\|.
guaranteed to find a globally optimal solution, and when run multiple times on
a given dataset, the algorithm returns the best clustering result).
* *initializationSteps* determines the number of steps in the k-means\|\| algorithm.
-* *epsilon* determines the distance threshold within which we consider k-means to have converged.
-
-### Gaussian mixture
-
-A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model)
-represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions,
-each with its own probability. The MLlib implementation uses the
-[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
- algorithm to induce the maximum-likelihood model given a set of samples. The implementation
-has the following parameters:
-
-* *k* is the number of desired clusters.
-* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved.
-* *maxIterations* is the maximum number of iterations to perform without reaching convergence.
-* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data.
-
-### Power Iteration Clustering
-
-Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm:
-
-* accepts a [Graph](api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points.
-* calculates the principal eigenvalue and eigenvector
-* Clusters each of the input points according to their principal eigenvector component value
-
-Details of this algorithm are found within [Power Iteration Clustering, Lin and Cohen]{www.icml2010.org/papers/387.pdf}
-
-Example outputs for a dataset inspired by the paper - but with five clusters instead of three- have he following output from our implementation:
-
-
-
-
-
-
-### Latent Dirichlet Allocation (LDA)
-
-[Latent Dirichlet Allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)
-is a topic model which infers topics from a collection of text documents.
-LDA can be thought of as a clustering algorithm as follows:
-
-* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset.
-* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts.
-* Rather than estimating a clustering using a traditional distance, LDA uses a function based
- on a statistical model of how text documents are generated.
+* *epsilon* determines the distance threshold within which we consider k-means to have converged.
-LDA takes in a collection of documents as vectors of word counts.
-It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
-on the likelihood function. After fitting on the documents, LDA provides:
-
-* Topics: Inferred topics, each of which is a probability distribution over terms (words).
-* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics.
-
-LDA takes the following parameters:
-
-* `k`: Number of topics (i.e., cluster centers)
-* `maxIterations`: Limit on the number of iterations of EM used for learning
-* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions.
-* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions.
-* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery.
-
-*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet
-support prediction on new documents, and it does not have a Python API. These will be added in the future.
-
-### Examples
-
-#### k-means
+**Examples**
@@ -216,13 +148,27 @@ print("Within Set Sum of Squared Error = " + str(WSSSE))
-#### GaussianMixture
+## Gaussian mixture
+
+A [Gaussian Mixture Model](http://en.wikipedia.org/wiki/Mixture_model#Multivariate_Gaussian_mixture_model)
+represents a composite distribution whereby points are drawn from one of *k* Gaussian sub-distributions,
+each with its own probability. The MLlib implementation uses the
+[expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
+ algorithm to induce the maximum-likelihood model given a set of samples. The implementation
+has the following parameters:
+
+* *k* is the number of desired clusters.
+* *convergenceTol* is the maximum change in log-likelihood at which we consider convergence achieved.
+* *maxIterations* is the maximum number of iterations to perform without reaching convergence.
+* *initialModel* is an optional starting point from which to start the EM algorithm. If this parameter is omitted, a random starting point will be constructed from the data.
+
+**Examples**
In the following example after loading and parsing data, we use a
-[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture)
-object to cluster the data into two clusters. The number of desired clusters is passed
+[GaussianMixture](api/scala/index.html#org.apache.spark.mllib.clustering.GaussianMixture)
+object to cluster the data into two clusters. The number of desired clusters is passed
to the algorithm. We then output the parameters of the mixture model.
{% highlight scala %}
@@ -238,7 +184,7 @@ val gmm = new GaussianMixture().setK(2).run(parsedData)
// output parameters of max-likelihood model
for (i <- 0 until gmm.k) {
- println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
(gmm.weights(i), gmm.gaussians(i).mu, gmm.gaussians(i).sigma))
}
@@ -298,7 +244,7 @@ public class GaussianMixtureExample {
In the following example after loading and parsing data, we use a
[GaussianMixture](api/python/pyspark.mllib.html#pyspark.mllib.clustering.GaussianMixture)
-object to cluster the data into two clusters. The number of desired clusters is passed
+object to cluster the data into two clusters. The number of desired clusters is passed
to the algorithm. We then output the parameters of the mixture model.
{% highlight python %}
@@ -322,11 +268,129 @@ for i in range(2):
-#### Latent Dirichlet Allocation (LDA) Example
+## Power iteration clustering (PIC)
+
+Power iteration clustering (PIC) is a scalable and efficient algorithm for clustering vertices of a
+graph given pairwise similarities as edge properties,
+described in [Lin and Cohen, Power Iteration Clustering](http://www.icml2010.org/papers/387.pdf).
+It computes a pseudo-eigenvector of the normalized affinity matrix of the graph via
+[power iteration](http://en.wikipedia.org/wiki/Power_iteration) and uses it to cluster vertices.
+MLlib includes an implementation of PIC using GraphX as its backend.
+It takes an `RDD` of `(srcId, dstId, similarity)` tuples and outputs a model with the clustering assignments.
+The similarities must be nonnegative.
+PIC assumes that the similarity measure is symmetric.
+A pair `(srcId, dstId)`, regardless of ordering, should appear at most once in the input data.
+If a pair is missing from the input, its similarity is treated as zero.
+MLlib's PIC implementation takes the following (hyper-)parameters:
+
+* `k`: number of clusters
+* `maxIterations`: maximum number of power iterations
+* `initializationMode`: initialization mode. This can be either "random", which is the default,
+ to use a random vector as vertex properties, or "degree" to use normalized sum similarities.
+
+**Examples**
+
+In the following, we show code snippets to demonstrate how to use PIC in MLlib.
+
+
+
+
+[`PowerIterationClustering`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClustering)
+implements the PIC algorithm.
+It takes an `RDD` of `(srcId: Long, dstId: Long, similarity: Double)` tuples representing the
+affinity matrix.
+Calling `PowerIterationClustering.run` returns a
+[`PowerIterationClusteringModel`](api/scala/index.html#org.apache.spark.mllib.clustering.PowerIterationClusteringModel),
+which contains the computed clustering assignments.
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.PowerIterationClustering
+import org.apache.spark.rdd.RDD
+
+val similarities: RDD[(Long, Long, Double)] = ...
+
+val pic = new PowerIterationClustering()
+ .setK(3)
+ .setMaxIterations(20)
+val model = pic.run(similarities)
+
+model.assignments.foreach { a =>
+ println(s"${a.id} -> ${a.cluster}")
+}
+{% endhighlight %}
+
+A full example that produces the experiment described in the PIC paper can be found under
+[`examples/`](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala).
+
+
+
+## Latent Dirichlet allocation (LDA)
+
+[Latent Dirichlet allocation (LDA)](http://en.wikipedia.org/wiki/Latent_Dirichlet_allocation)
+is a topic model which infers topics from a collection of text documents.
+LDA can be thought of as a clustering algorithm as follows:
+
+* Topics correspond to cluster centers, and documents correspond to examples (rows) in a dataset.
+* Topics and documents both exist in a feature space, where feature vectors are vectors of word counts.
+* Rather than estimating a clustering using a traditional distance, LDA uses a function based
+ on a statistical model of how text documents are generated.
+
+LDA takes in a collection of documents as vectors of word counts.
+It learns clustering using [expectation-maximization](http://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm)
+on the likelihood function. After fitting on the documents, LDA provides:
+
+* Topics: Inferred topics, each of which is a probability distribution over terms (words).
+* Topic distributions for documents: For each document in the training set, LDA gives a probability distribution over topics.
+
+LDA takes the following parameters:
+
+* `k`: Number of topics (i.e., cluster centers)
+* `maxIterations`: Limit on the number of iterations of EM used for learning
+* `docConcentration`: Hyperparameter for prior over documents' distributions over topics. Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `topicConcentration`: Hyperparameter for prior over topics' distributions over terms (words). Currently must be > 1, where larger values encourage smoother inferred distributions.
+* `checkpointInterval`: If using checkpointing (set in the Spark configuration), this parameter specifies the frequency with which checkpoints will be created. If `maxIterations` is large, using checkpointing can help reduce shuffle file sizes on disk and help with failure recovery.
+
+*Note*: LDA is a new feature with some missing functionality. In particular, it does not yet
+support prediction on new documents, and it does not have a Python API. These will be added in the future.
+
+**Examples**
In the following example, we load word count vectors representing a corpus of documents.
We then use [LDA](api/scala/index.html#org.apache.spark.mllib.clustering.LDA)
-to infer three topics from the documents. The number of desired clusters is passed
+to infer three topics from the documents. The number of desired clusters is passed
to the algorithm. We then output the topics, represented as probability distributions over words.
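+
+A minimal sketch of such an example (assuming `sc` is an existing `SparkContext`; the input path is illustrative and should contain one document per line as space-separated word counts):
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.LDA
+import org.apache.spark.mllib.linalg.Vectors
+
+// Load and parse word count vectors, one document per line.
+val data = sc.textFile("data/mllib/sample_lda_data.txt")
+val parsedData = data.map(s => Vectors.dense(s.trim.split(' ').map(_.toDouble)))
+// LDA expects an RDD of (documentId, wordCountVector) pairs.
+val corpus = parsedData.zipWithIndex.map(_.swap).cache()
+
+// Infer three topics from the corpus.
+val ldaModel = new LDA().setK(3).run(corpus)
+
+// topicsMatrix is vocabSize x k; column j is topic j's distribution over the vocabulary.
+println("Vocabulary size: " + ldaModel.vocabSize)
+println(ldaModel.topicsMatrix)
+{% endhighlight %}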
@@ -419,42 +483,35 @@ public class JavaLDAExample {
+## Streaming k-means
-In order to run the above application, follow the instructions
-provided in the [Self-Contained Applications](quick-start.html#self-contained-applications)
-section of the Spark
-Quick Start guide. Be sure to also include *spark-mllib* to your build file as
-a dependency.
-
-## Streaming clustering
-
-When data arrive in a stream, we may want to estimate clusters dynamically,
-updating them as new data arrive. MLlib provides support for streaming k-means clustering,
-with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
-uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
+When data arrive in a stream, we may want to estimate clusters dynamically,
+updating them as new data arrive. MLlib provides support for streaming k-means clustering,
+with parameters to control the decay (or "forgetfulness") of the estimates. The algorithm
+uses a generalization of the mini-batch k-means update rule. For each batch of data, we assign
all points to their nearest cluster, compute new cluster centers, then update each cluster using:
`\begin{equation}
c_{t+1} = \frac{c_tn_t\alpha + x_tm_t}{n_t\alpha+m_t}
\end{equation}`
`\begin{equation}
- n_{t+1} = n_t + m_t
+ n_{t+1} = n_t + m_t
\end{equation}`
-Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
-to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
-is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
-can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
-with `$\alpha$=0` only the most recent data will be used. This is analogous to an
-exponentially-weighted moving average.
+Where `$c_t$` is the previous center for the cluster, `$n_t$` is the number of points assigned
+to the cluster thus far, `$x_t$` is the new cluster center from the current batch, and `$m_t$`
+is the number of points added to the cluster in the current batch. The decay factor `$\alpha$`
+can be used to ignore the past: with `$\alpha$=1` all data will be used from the beginning;
+with `$\alpha$=0` only the most recent data will be used. This is analogous to an
+exponentially-weighted moving average.
-The decay can be specified using a `halfLife` parameter, which determines the
+The decay can be specified using a `halfLife` parameter, which determines the
correct decay factor `a` such that, for data acquired
at time `t`, its contribution by time `t + halfLife` will have dropped to 0.5.
The unit of time can be specified either as `batches` or `points` and the update rule
will be adjusted accordingly.
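+
+A minimal sketch of configuring the decay by half-life (parameter values are illustrative):
+
+{% highlight scala %}
+import org.apache.spark.mllib.clustering.StreamingKMeans
+
+// Halve each cluster's past contribution every 5 batches; use "points" to decay per data point instead.
+val model = new StreamingKMeans()
+  .setK(3)
+  .setHalfLife(5.0, "batches")
+  .setRandomCenters(3, 0.0)
+{% endhighlight %}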
-### Examples
+**Examples**
This example shows how to estimate clusters on streaming data.
@@ -472,9 +529,9 @@ import org.apache.spark.mllib.clustering.StreamingKMeans
{% endhighlight %}
-Then we make an input stream of vectors for training, as well as a stream of labeled data
-points for testing. We assume a StreamingContext `ssc` has been created, see
-[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
+Then we make an input stream of vectors for training, as well as a stream of labeled data
+points for testing. We assume a StreamingContext `ssc` has been created, see
+[Spark Streaming Programming Guide](streaming-programming-guide.html#initializing) for more info.
{% highlight scala %}
@@ -496,24 +553,24 @@ val model = new StreamingKMeans()
{% endhighlight %}
-Now register the streams for training and testing and start the job, printing
+Now register the streams for training and testing and start the job, printing
the predicted cluster assignments on new data points as they arrive.
{% highlight scala %}
model.trainOn(trainingData)
-model.predictOnValues(testData).print()
+model.predictOnValues(testData.map(lp => (lp.label, lp.features))).print()
ssc.start()
ssc.awaitTermination()
-
+
{% endhighlight %}
-As you add new text files with data the cluster centers will update. Each training
+As you add new text files with data the cluster centers will update. Each training
point should be formatted as `[x1, x2, x3]`, and each test data point
-should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
-(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`
-the model will update. Anytime a text file is placed in `/testing/data/dir`
+should be formatted as `(y, [x1, x2, x3])`, where `y` is some useful label or identifier
+(e.g. a true category assignment). Anytime a text file is placed in `/training/data/dir`
+the model will update. Anytime a text file is placed in `/testing/data/dir`
you will see predictions. With new data, the cluster centers will change!
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md
index ef18cec9371d6..27aa4d38b7617 100644
--- a/docs/mllib-collaborative-filtering.md
+++ b/docs/mllib-collaborative-filtering.md
@@ -66,6 +66,7 @@ recommendation model by measuring the Mean Squared Error of rating prediction.
{% highlight scala %}
import org.apache.spark.mllib.recommendation.ALS
+import org.apache.spark.mllib.recommendation.MatrixFactorizationModel
import org.apache.spark.mllib.recommendation.Rating
// Load and parse the data
@@ -95,6 +96,10 @@ val MSE = ratesAndPreds.map { case ((user, product), (r1, r2)) =>
err * err
}.mean()
println("Mean Squared Error = " + MSE)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = MatrixFactorizationModel.load(sc, "myModelPath")
{% endhighlight %}
If the rating matrix is derived from another source of information (e.g., it is inferred from
@@ -181,6 +186,10 @@ public class CollaborativeFiltering {
}
).rdd()).mean();
System.out.println("Mean Squared Error = " + MSE);
+
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ MatrixFactorizationModel sameModel = MatrixFactorizationModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -191,6 +200,8 @@ In the following example we load rating data. Each row consists of a user, a pro
We use the default ALS.train() method which assumes ratings are explicit. We evaluate the
recommendation by measuring the Mean Squared Error of rating prediction.
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.recommendation import ALS, Rating
diff --git a/docs/mllib-data-types.md b/docs/mllib-data-types.md
index 101dc2f8695f3..fe6c1bf7bfd99 100644
--- a/docs/mllib-data-types.md
+++ b/docs/mllib-data-types.md
@@ -296,6 +296,70 @@ backed by an RDD of its entries.
The underlying RDDs of a distributed matrix must be deterministic, because we cache the matrix size.
In general the use of non-deterministic RDDs can lead to errors.
+### BlockMatrix
+
+A `BlockMatrix` is a distributed matrix backed by an RDD of `MatrixBlock`s, where a `MatrixBlock` is
+a tuple of `((Int, Int), Matrix)`, where the `(Int, Int)` is the index of the block, and `Matrix` is
+the sub-matrix at the given index with size `rowsPerBlock` x `colsPerBlock`.
+`BlockMatrix` supports methods such as `add` and `multiply` with another `BlockMatrix`.
+`BlockMatrix` also has a helper function `validate` which can be used to check whether the
+`BlockMatrix` is set up properly.
+
+
+
+
+A [`BlockMatrix`](api/scala/index.html#org.apache.spark.mllib.linalg.distributed.BlockMatrix) can be
+most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`.
+`toBlockMatrix` creates blocks of size 1024 x 1024 by default.
+Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`.
+
+{% highlight scala %}
+import org.apache.spark.mllib.linalg.distributed.{BlockMatrix, CoordinateMatrix, MatrixEntry}
+
+val entries: RDD[MatrixEntry] = ... // an RDD of (i, j, v) matrix entries
+// Create a CoordinateMatrix from an RDD[MatrixEntry].
+val coordMat: CoordinateMatrix = new CoordinateMatrix(entries)
+// Transform the CoordinateMatrix to a BlockMatrix
+val matA: BlockMatrix = coordMat.toBlockMatrix().cache()
+
+// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid.
+// Nothing happens if it is valid.
+matA.validate()
+
+// Calculate A^T A.
+val ata = matA.transpose.multiply(matA)
+{% endhighlight %}
+
+
+
+
+A [`BlockMatrix`](api/java/org/apache/spark/mllib/linalg/distributed/BlockMatrix.html) can be
+most easily created from an `IndexedRowMatrix` or `CoordinateMatrix` by calling `toBlockMatrix`.
+`toBlockMatrix` creates blocks of size 1024 x 1024 by default.
+Users may change the block size by supplying the values through `toBlockMatrix(rowsPerBlock, colsPerBlock)`.
+
+{% highlight java %}
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.linalg.distributed.BlockMatrix;
+import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
+import org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix;
+
+JavaRDD<MatrixEntry> entries = ... // a JavaRDD of (i, j, v) Matrix Entries
+// Create a CoordinateMatrix from a JavaRDD.
+CoordinateMatrix coordMat = new CoordinateMatrix(entries.rdd());
+// Transform the CoordinateMatrix to a BlockMatrix
+BlockMatrix matA = coordMat.toBlockMatrix().cache();
+
+// Validate whether the BlockMatrix is set up properly. Throws an Exception when it is not valid.
+// Nothing happens if it is valid.
+matA.validate();
+
+// Calculate A^T A.
+BlockMatrix ata = matA.transpose().multiply(matA);
+{% endhighlight %}
+
+
+
### RowMatrix
A `RowMatrix` is a row-oriented distributed matrix without meaningful row indices, backed by an RDD
diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index d1537def851e7..8e478ab035582 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -54,8 +54,8 @@ impurity measure for regression (variance).
      <td>Variance</td>
      <td>Regression</td>
-     <td>$\frac{1}{N} \sum_{i=1}^{N} (x_i - \mu)^2$</td><td>$y_i$ is label for an instance,
-     $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N x_i$.</td>
+     <td>$\frac{1}{N} \sum_{i=1}^{N} (y_i - \mu)^2$</td><td>$y_i$ is label for an instance,
+     $N$ is the number of instances and $\mu$ is the mean given by $\frac{1}{N} \sum_{i=1}^N y_i$.</td>
@@ -194,6 +194,7 @@ maximum tree depth of 5. The test error is calculated to measure the algorithm a
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -221,6 +222,10 @@ val labelAndPreds = testData.map { point =>
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification tree model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -279,10 +284,17 @@ Double testErr =
}).count() / testData.count();
System.out.println("Test Error: " + testErr);
System.out.println("Learned classification tree model:\n" + model.toDebugString());
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
@@ -324,6 +336,7 @@ depth of 5. The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %}
import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -350,6 +363,10 @@ val labelsAndPredictions = testData.map { point =>
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression tree model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = DecisionTreeModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -414,10 +431,17 @@ Double testMSE =
}) / data.count();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression tree model:\n" + model.toDebugString());
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+DecisionTreeModel sameModel = DecisionTreeModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
diff --git a/docs/mllib-ensembles.md b/docs/mllib-ensembles.md
index 23ede04b62d5b..ec1ef38b453d3 100644
--- a/docs/mllib-ensembles.md
+++ b/docs/mllib-ensembles.md
@@ -98,6 +98,7 @@ The test error is calculated to measure the algorithm accuracy.
{% highlight scala %}
import org.apache.spark.mllib.tree.RandomForest
+import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -127,6 +128,10 @@ val labelAndPreds = testData.map { point =>
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification forest model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils
@@ -235,6 +247,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %}
import org.apache.spark.mllib.tree.RandomForest
+import org.apache.spark.mllib.tree.model.RandomForestModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -264,6 +277,10 @@ val labelsAndPredictions = testData.map { point =>
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression forest model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = RandomForestModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -328,10 +345,17 @@ Double testMSE =
}) / testData.count();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression forest model:\n" + model.toDebugString());
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+RandomForestModel sameModel = RandomForestModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
+
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.tree import RandomForest
from pyspark.mllib.util import MLUtils
@@ -427,10 +451,19 @@ We omit some decision tree parameters since those are covered in the [decision t
* **`algo`**: The algorithm or task (classification vs. regression) is set using the tree [Strategy] parameter.
+#### Validation while training
-### Examples
+Gradient boosting can overfit when trained with more trees. In order to prevent overfitting, it is useful to validate while
+training. The method `runWithValidation` has been provided to make use of this option. It takes a pair of RDDs as arguments, the
+first one being the training dataset and the second being the validation dataset.
-GBTs currently have APIs in Scala and Java. Examples in both languages are shown below.
+The training is stopped when the improvement in the validation error is not more than a certain tolerance
+(supplied by the `validationTol` argument in `BoostingStrategy`). In practice, the validation error
+decreases initially and later increases. There might be cases in which the validation error does not change monotonically,
+and the user is advised to set a large enough negative tolerance and examine the validation curve to tune the number of
+iterations.
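+
+A minimal sketch of using `runWithValidation` (assuming `sc` is an existing `SparkContext`; the data path and parameter values are illustrative):
+
+{% highlight scala %}
+import org.apache.spark.mllib.tree.GradientBoostedTrees
+import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.util.MLUtils
+
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+val Array(trainingData, validationData) = data.randomSplit(Array(0.8, 0.2))
+
+val boostingStrategy = BoostingStrategy.defaultParams("Classification")
+boostingStrategy.numIterations = 100 // an upper bound; validation may stop boosting earlier
+boostingStrategy.treeStrategy.numClasses = 2
+boostingStrategy.validationTol = 1e-5 // stop when validation error improves by less than this
+
+val model = new GradientBoostedTrees(boostingStrategy)
+  .runWithValidation(trainingData, validationData)
+{% endhighlight %}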
+
+### Examples
#### Classification
@@ -446,6 +479,7 @@ The test error is calculated to measure the algorithm accuracy.
{% highlight scala %}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -458,7 +492,7 @@ val (trainingData, testData) = (splits(0), splits(1))
// The defaultParams for Classification use LogLoss by default.
val boostingStrategy = BoostingStrategy.defaultParams("Classification")
boostingStrategy.numIterations = 3 // Note: Use more iterations in practice.
-boostingStrategy.treeStrategy.numClassesForClassification = 2
+boostingStrategy.treeStrategy.numClasses = 2
boostingStrategy.treeStrategy.maxDepth = 5
// Empty categoricalFeaturesInfo indicates all features are continuous.
boostingStrategy.treeStrategy.categoricalFeaturesInfo = Map[Int, Int]()
@@ -473,6 +507,10 @@ val labelAndPreds = testData.map { point =>
val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count()
println("Test Error = " + testErr)
println("Learned classification GBT model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
+
+Note that the Python API does not yet support model save/load but will in the future.
+
+{% highlight python %}
+from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.util import MLUtils
+
+# Load and parse the data file.
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
+
+# Train a GradientBoostedTrees model.
+# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous.
+# (b) Use more iterations in practice.
+model = GradientBoostedTrees.trainClassifier(trainingData,
+ categoricalFeaturesInfo={}, numIterations=3)
+
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() / float(testData.count())
+print('Test Error = ' + str(testErr))
+print('Learned classification GBT model:')
+print(model.toDebugString())
{% endhighlight %}
@@ -554,6 +625,7 @@ The Mean Squared Error (MSE) is computed at the end to evaluate
{% highlight scala %}
import org.apache.spark.mllib.tree.GradientBoostedTrees
import org.apache.spark.mllib.tree.configuration.BoostingStrategy
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLUtils
// Load and parse the data file.
@@ -580,6 +652,10 @@ val labelsAndPredictions = testData.map { point =>
val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean()
println("Test Mean Squared Error = " + testMSE)
println("Learned regression GBT model:\n" + model.toDebugString)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = GradientBoostedTreesModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -647,6 +723,39 @@ Double testMSE =
}) / data.count();
System.out.println("Test Mean Squared Error: " + testMSE);
System.out.println("Learned regression GBT model:\n" + model.toDebugString());
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+GradientBoostedTreesModel sameModel = GradientBoostedTreesModel.load(sc.sc(), "myModelPath");
+{% endhighlight %}
+
+
+
+
+Note that the Python API does not yet support model save/load but will in the future.
+
+{% highlight python %}
+from pyspark.mllib.tree import GradientBoostedTrees
+from pyspark.mllib.util import MLUtils
+
+# Load and parse the data file.
+data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+# Split the data into training and test sets (30% held out for testing)
+(trainingData, testData) = data.randomSplit([0.7, 0.3])
+
+# Train a GradientBoostedTrees model.
+# Notes: (a) Empty categoricalFeaturesInfo indicates all features are continuous.
+# (b) Use more iterations in practice.
+model = GradientBoostedTrees.trainRegressor(trainingData,
+ categoricalFeaturesInfo={}, numIterations=3)
+
+# Evaluate model on test instances and compute test error
+predictions = model.predict(testData.map(lambda x: x.features))
+labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions)
+testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() / float(testData.count())
+print('Test Mean Squared Error = ' + str(testMSE))
+print('Learned regression GBT model:')
+print(model.toDebugString())
{% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md
index d4a61a7fbf3d7..80842b27effd8 100644
--- a/docs/mllib-feature-extraction.md
+++ b/docs/mllib-feature-extraction.md
@@ -375,3 +375,105 @@ data2 = labels.zip(normalizer2.transform(features))
{% endhighlight %}
+
+## Feature selection
+[Feature selection](http://en.wikipedia.org/wiki/Feature_selection) allows selecting the most relevant features for use in model construction. Feature selection reduces the size of the vector space and, in turn, the complexity of any subsequent operation with vectors. The number of features to select can be tuned using a held-out validation set.
+
+### ChiSqSelector
+[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) stands for Chi-Squared feature selection. It operates on labeled data with categorical features. `ChiSqSelector` orders features based on a Chi-Squared test of independence from the class, and then filters (selects) the top features which are most closely related to the label.
+
+#### Model Fitting
+
+[`ChiSqSelector`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) has the
+following parameters in the constructor:
+
+* `numTopFeatures`: number of top features that the selector will select (filter).
+
+We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.ChiSqSelector) method in
+`ChiSqSelector` which can take an input of `RDD[LabeledPoint]` with categorical features, learn the summary statistics, and then
+return a `ChiSqSelectorModel` which can transform an input dataset into the reduced feature space.
+
+This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer)
+which can apply the Chi-Squared feature selection on a `Vector` to produce a reduced `Vector` or on
+an `RDD[Vector]` to produce a reduced `RDD[Vector]`.
+
+Note that the user can also construct a `ChiSqSelectorModel` by hand by providing an array of selected feature indices (which must be sorted in ascending order).
+
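+For illustration, one might construct such a model by hand as follows (the chosen indices are arbitrary):
+
+{% highlight scala %}
+import org.apache.spark.mllib.feature.ChiSqSelectorModel
+
+// Keep only features 1, 9 and 205; the indices must be sorted in ascending order.
+val manualModel = new ChiSqSelectorModel(Array(1, 9, 205))
+{% endhighlight %}
+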
+#### Example
+
+The following example shows the basic use of ChiSqSelector.
+
+
+
+{% highlight scala %}
+import org.apache.spark.SparkContext._
+import org.apache.spark.mllib.feature.ChiSqSelector
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.util.MLUtils
+
+// Load some data in libsvm format
+val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt")
+// Discretize data in 16 equal bins since ChiSqSelector requires categorical features
+val discretizedData = data.map { lp =>
+  LabeledPoint(lp.label, Vectors.dense(lp.features.toArray.map { x => (x / 16).floor }))
+}
+// Create ChiSqSelector that will select 50 features
+val selector = new ChiSqSelector(50)
+// Create ChiSqSelector model (selecting features)
+val transformer = selector.fit(discretizedData)
+// Filter the top 50 features from each feature vector
+val filteredData = discretizedData.map { lp =>
+ LabeledPoint(lp.label, transformer.transform(lp.features))
+}
+{% endhighlight %}
+
+
+
+{% highlight java %}
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.feature.ChiSqSelector;
+import org.apache.spark.mllib.feature.ChiSqSelectorModel;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.mllib.util.MLUtils;
+
+SparkConf sparkConf = new SparkConf().setAppName("JavaChiSqSelector");
+JavaSparkContext sc = new JavaSparkContext(sparkConf);
+JavaRDD<LabeledPoint> points = MLUtils.loadLibSVMFile(sc.sc(),
+  "data/mllib/sample_libsvm_data.txt").toJavaRDD().cache();
+
+// Discretize data in 16 equal bins since ChiSqSelector requires categorical features
+JavaRDD<LabeledPoint> discretizedData = points.map(
+  new Function<LabeledPoint, LabeledPoint>() {
+    @Override
+    public LabeledPoint call(LabeledPoint lp) {
+      final double[] discretizedFeatures = new double[lp.features().size()];
+      for (int i = 0; i < lp.features().size(); ++i) {
+        discretizedFeatures[i] = Math.floor(lp.features().apply(i) / 16);
+      }
+      return new LabeledPoint(lp.label(), Vectors.dense(discretizedFeatures));
+    }
+  });
+
+// Create ChiSqSelector that will select 50 features
+ChiSqSelector selector = new ChiSqSelector(50);
+// Create ChiSqSelector model (selecting features)
+final ChiSqSelectorModel transformer = selector.fit(discretizedData.rdd());
+// Filter the top 50 features from each feature vector
+JavaRDD<LabeledPoint> filteredData = discretizedData.map(
+  new Function<LabeledPoint, LabeledPoint>() {
+ @Override
+ public LabeledPoint call(LabeledPoint lp) {
+ return new LabeledPoint(lp.label(), transformer.transform(lp.features()));
+ }
+ }
+);
+
+sc.stop();
+{% endhighlight %}
+
+
+
diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md
new file mode 100644
index 0000000000000..9fd9be0dd01b1
--- /dev/null
+++ b/docs/mllib-frequent-pattern-mining.md
@@ -0,0 +1,98 @@
+---
+layout: global
+title: Frequent Pattern Mining - MLlib
+displayTitle: MLlib - Frequent Pattern Mining
+---
+
+Mining frequent items, itemsets, subsequences, or other substructures is usually among the
+first steps in analyzing a large-scale dataset, and has been an active research topic in
+data mining for years.
+We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning)
+for more information.
+MLlib provides a parallel implementation of FP-growth,
+a popular algorithm for mining frequent itemsets.
+
+## FP-growth
+
+The FP-growth algorithm is described in the paper
+[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372),
+where "FP" stands for frequent pattern.
+Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items.
+Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose,
+the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without explicitly generating
+candidate sets, which are usually expensive to generate.
+After the second step, the frequent itemsets can be extracted from the FP-tree.
+In MLlib, we implemented a parallel version of FP-growth called PFP,
+as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027).
+PFP distributes the work of growing FP-trees based on the suffixes of transactions,
+and is hence more scalable than a single-machine implementation.
+We refer users to the papers for more details.
+
+MLlib's FP-growth implementation takes the following (hyper-)parameters:
+
+* `minSupport`: the minimum support for an itemset to be identified as frequent.
+  For example, if an item appears in 3 out of 5 transactions, it has a support of 3/5=0.6.
+* `numPartitions`: the number of partitions used to distribute the work.
+
+**Examples**
+
+
+
+
+[`FPGrowth`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowth) implements the
+FP-growth algorithm.
+It takes an `RDD` of transactions, where each transaction is an `Array` of items of a generic type.
+Calling `FPGrowth.run` with transactions returns an
+[`FPGrowthModel`](api/scala/index.html#org.apache.spark.mllib.fpm.FPGrowthModel)
+that stores the frequent itemsets with their frequencies.
+
+{% highlight scala %}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
+
+val transactions: RDD[Array[String]] = ...
+
+val fpg = new FPGrowth()
+ .setMinSupport(0.2)
+ .setNumPartitions(10)
+val model = fpg.run(transactions)
+
+model.freqItemsets.collect().foreach { itemset =>
+ println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
+}
+{% endhighlight %}
+
+
+
+
+
+[`FPGrowth`](api/java/org/apache/spark/mllib/fpm/FPGrowth.html) implements the
+FP-growth algorithm.
+It takes a `JavaRDD` of transactions, where each transaction is an `Iterable` of items of a generic type.
+Calling `FPGrowth.run` with transactions returns an
+[`FPGrowthModel`](api/java/org/apache/spark/mllib/fpm/FPGrowthModel.html)
+that stores the frequent itemsets with their frequencies.
+
+{% highlight java %}
+import java.util.List;
+
+import com.google.common.base.Joiner;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.fpm.FPGrowth;
+import org.apache.spark.mllib.fpm.FPGrowthModel;
+
+JavaRDD<List<String>> transactions = ...
+
+FPGrowth fpg = new FPGrowth()
+ .setMinSupport(0.2)
+ .setNumPartitions(10);
+FPGrowthModel<String> model = fpg.run(transactions);
+
+for (FPGrowth.FreqItemset<String> itemset: model.freqItemsets().toJavaRDD().collect()) {
+   System.out.println("[" + Joiner.on(",").join(itemset.javaItems()) + "], " + itemset.freq());
+}
+{% endhighlight %}
+
+
+
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index 3d32d03e35c62..4c7a7d9115ca1 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -21,16 +21,21 @@ filtering, dimensionality reduction, as well as underlying optimization primitiv
* [naive Bayes](mllib-naive-bayes.html)
* [decision trees](mllib-decision-tree.html)
* [ensembles of trees](mllib-ensembles.html) (Random Forests and Gradient-Boosted Trees)
+ * [isotonic regression](mllib-isotonic-regression.html)
* [Collaborative filtering](mllib-collaborative-filtering.html)
* alternating least squares (ALS)
* [Clustering](mllib-clustering.html)
- * k-means
- * Gaussian mixture
- * power iteration
+ * [k-means](mllib-clustering.html#k-means)
+ * [Gaussian mixture](mllib-clustering.html#gaussian-mixture)
+ * [power iteration clustering (PIC)](mllib-clustering.html#power-iteration-clustering-pic)
+ * [latent Dirichlet allocation (LDA)](mllib-clustering.html#latent-dirichlet-allocation-lda)
+ * [streaming k-means](mllib-clustering.html#streaming-k-means)
* [Dimensionality reduction](mllib-dimensionality-reduction.html)
* singular value decomposition (SVD)
* principal component analysis (PCA)
* [Feature extraction and transformation](mllib-feature-extraction.html)
+* [Frequent pattern mining](mllib-frequent-pattern-mining.html)
+ * FP-growth
* [Optimization (developer)](mllib-optimization.html)
* stochastic gradient descent
* limited-memory BFGS (L-BFGS)
@@ -41,7 +46,7 @@ and the migration guide below will explain all changes between releases.
# spark.ml: high-level APIs for ML pipelines
-Spark 1.2 includes a new package called `spark.ml`, which aims to provide a uniform set of
+Spark 1.2 introduced a new package called `spark.ml`, which aims to provide a uniform set of
high-level APIs that help users create and tune practical machine learning pipelines.
It is currently an alpha component, and we would like to hear back from the community about
how it fits real-world use cases and how it could be improved.
@@ -87,125 +92,22 @@ version 1.4 or newer.
# Migration Guide
-## From 1.1 to 1.2
+For the `spark.ml` package, please see the [spark.ml Migration Guide](ml-guide.html#migration-guide).
-The only API changes in MLlib v1.2 are in
-[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
-which continues to be an experimental API in MLlib 1.2:
+## From 1.2 to 1.3
-1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number
-of classes. In MLlib v1.1, this argument was called `numClasses` in Python and
-`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`.
-This `numClasses` parameter is specified either via
-[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
-or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
-static `trainClassifier` and `trainRegressor` methods.
+In the `spark.mllib` package, there were several breaking changes. The first change (in `ALS`) is the only one in a component not marked as Alpha or Experimental.
-2. *(Breaking change)* The API for
-[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed.
-This should generally not affect user code, unless the user manually constructs decision trees
-(instead of using the `trainClassifier` or `trainRegressor` methods).
-The tree `Node` now includes more information, including the probability of the predicted label
-(for classification).
+* *(Breaking change)* In [`ALS`](api/scala/index.html#org.apache.spark.mllib.recommendation.ALS), the extraneous method `solveLeastSquares` has been removed. The `DeveloperApi` method `analyzeBlocks` was also removed.
+* *(Breaking change)* [`StandardScalerModel`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScalerModel) remains an Alpha component. In it, the `variance` method has been replaced with the `std` method. To compute the column variance values returned by the original `variance` method, simply square the standard deviation values returned by `std` (see the sketch after this list).
+* *(Breaking change)* [`StreamingLinearRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.StreamingLinearRegressionWithSGD) remains an Experimental component. In it, there were two changes:
+  * The constructor taking arguments was removed in favor of a builder pattern using the default constructor plus parameter setter methods.
+ * Variable `model` is no longer public.
+* *(Breaking change)* [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree) remains an Experimental component. In it and its associated classes, there were several changes:
+ * In `DecisionTree`, the deprecated class method `train` has been removed. (The object/static `train` methods remain.)
+ * In `Strategy`, the `checkpointDir` parameter has been removed. Checkpointing is still supported, but the checkpoint directory must be set before calling tree and tree ensemble training.
+* `PythonMLlibAPI` (the interface between Scala/Java and Python for MLlib) was a public API but is now private, declared `private[python]`. This was never meant for external use.
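+
+As a small sketch of the `std` change above (assuming an existing fitted `StandardScalerModel` named `scaler`):
+
+{% highlight scala %}
+// scaler is an org.apache.spark.mllib.feature.StandardScalerModel fitted elsewhere.
+val stds = scaler.std.toArray
+// Column variances, as previously returned by the old `variance` method.
+val variances = stds.map(s => s * s)
+{% endhighlight %}
+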
-3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`.
+## Previous Spark Versions
-Examples in the Spark distribution and examples in the
-[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly.
-
-## From 1.0 to 1.1
-
-The only API changes in MLlib v1.1 are in
-[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
-which continues to be an experimental API in MLlib 1.1:
-
-1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match
-the implementations of trees in
-[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree)
-and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html).
-In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes.
-In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes.
-This depth is specified by the `maxDepth` parameter in
-[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
-or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
-static `trainClassifier` and `trainRegressor` methods.
-
-2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor`
-methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
-rather than using the old parameter class `Strategy`. These new training methods explicitly
-separate classification and regression, and they replace specialized parameter types with
-simple `String` types.
-
-Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the
-[Decision Trees Guide](mllib-decision-tree.html#examples).
-
-## From 0.9 to 1.0
-
-In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few
-breaking changes. If your data is sparse, please store it in a sparse format instead of dense to
-take advantage of sparsity in both storage and computation. Details are described below.
-
-
-
-
-We used to represent a feature vector by `Array[Double]`, which is replaced by
-[`Vector`](api/scala/index.html#org.apache.spark.mllib.linalg.Vector) in v1.0. Algorithms that used
-to accept `RDD[Array[Double]]` now take
-`RDD[Vector]`. [`LabeledPoint`](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint)
-is now a wrapper of `(Double, Vector)` instead of `(Double, Array[Double])`. Converting
-`Array[Double]` to `Vector` is straightforward:
-
-{% highlight scala %}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
-
-val array: Array[Double] = ... // a double array
-val vector: Vector = Vectors.dense(array) // a dense vector
-{% endhighlight %}
-
-[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to create sparse vectors.
-
-*Note*: Scala imports `scala.collection.immutable.Vector` by default, so you have to import `org.apache.spark.mllib.linalg.Vector` explicitly to use MLlib's `Vector`.
-
-
-
-
-
-We used to represent a feature vector by `double[]`, which is replaced by
-[`Vector`](api/java/index.html?org/apache/spark/mllib/linalg/Vector.html) in v1.0. Algorithms that used
-to accept `RDD` now take
-`RDD`. [`LabeledPoint`](api/java/index.html?org/apache/spark/mllib/regression/LabeledPoint.html)
-is now a wrapper of `(double, Vector)` instead of `(double, double[])`. Converting `double[]` to
-`Vector` is straightforward:
-
-{% highlight java %}
-import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.Vectors;
-
-double[] array = ... // a double array
-Vector vector = Vectors.dense(array); // a dense vector
-{% endhighlight %}
-
-[`Vectors`](api/scala/index.html#org.apache.spark.mllib.linalg.Vectors$) provides factory methods to
-create sparse vectors.
-
-
-
-
-
-We used to represent a labeled feature vector in a NumPy array, where the first entry corresponds to
-the label and the rest are features. This representation is replaced by class
-[`LabeledPoint`](api/python/pyspark.mllib.regression.LabeledPoint-class.html), which takes both
-dense and sparse feature vectors.
-
-{% highlight python %}
-from pyspark.mllib.linalg import SparseVector
-from pyspark.mllib.regression import LabeledPoint
-
-# Create a labeled point with a positive label and a dense feature vector.
-pos = LabeledPoint(1.0, [1.0, 0.0, 3.0])
-
-# Create a labeled point with a negative label and a sparse feature vector.
-neg = LabeledPoint(0.0, SparseVector(3, [0, 2], [1.0, 3.0]))
-{% endhighlight %}
-
-
+Earlier migration guides are archived [on this page](mllib-migration-guides.html).
diff --git a/docs/mllib-isotonic-regression.md b/docs/mllib-isotonic-regression.md
new file mode 100644
index 0000000000000..12fb29d426741
--- /dev/null
+++ b/docs/mllib-isotonic-regression.md
@@ -0,0 +1,155 @@
+---
+layout: global
+title: Isotonic regression - MLlib
+displayTitle: MLlib - Isotonic Regression
+---
+
+## Isotonic regression
+[Isotonic regression](http://en.wikipedia.org/wiki/Isotonic_regression)
+belongs to the family of regression algorithms. Formally, isotonic regression is the following problem:
+given a finite set of real numbers `$Y = {y_1, y_2, ..., y_n}$` representing observed responses
+and `$X = {x_1, x_2, ..., x_n}$` the unknown response values to be fitted,
+find a function that minimises
+
+`\begin{equation}
+  f(x) = \sum_{i=1}^n w_i (y_i - x_i)^2
+\end{equation}`
+
+subject to the complete order
+`$x_1\le x_2\le ...\le x_n$`, where the `$w_i$` are positive weights.
+The resulting function is called isotonic regression and it is unique.
+It can be viewed as a least squares problem under an order restriction.
+Essentially isotonic regression is a
+[monotonic function](http://en.wikipedia.org/wiki/Monotonic_function)
+best fitting the original data points.
+
+MLlib supports a
+[pool adjacent violators algorithm](http://doi.org/10.1198/TECH.2010.10111)
+which uses an approach to
+[parallelizing isotonic regression](http://doi.org/10.1007/978-3-642-99789-1_10).
+The training input is an RDD of tuples of three double values that represent
+label, feature and weight, in this order. Additionally, the `IsotonicRegression` algorithm has one
+optional parameter called `isotonic`, which defaults to true.
+This argument specifies whether the isotonic regression is
+isotonic (monotonically increasing) or antitonic (monotonically decreasing).
+
+Training returns an `IsotonicRegressionModel` that can be used to predict
+labels for both known and unknown features. The result of isotonic regression
+is treated as a piecewise linear function. The rules for prediction therefore are:
+
+* If the prediction input exactly matches a training feature
+  then the associated prediction is returned. If there are multiple predictions with the same
+  feature then one of them is returned; which one is undefined
+  (same as `java.util.Arrays.binarySearch`).
+* If the prediction input is lower or higher than all training features
+  then the prediction with the lowest or highest feature is returned, respectively.
+  If there are multiple predictions with the same feature
+  then the lowest or highest is returned, respectively.
+* If the prediction input falls between two training features then the prediction is treated
+  as a piecewise linear function and the interpolated value is calculated from the
+  predictions of the two closest features. If there are multiple values
+  with the same feature then the same rules as in the previous point are used.
+
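+As a toy illustration of these rules, one could construct a model directly from hypothetical boundary and prediction arrays:
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.IsotonicRegressionModel
+
+// Piecewise linear model with boundaries 1.0, 2.0, 3.0 and fitted predictions 1.0, 4.0, 6.0.
+val model = new IsotonicRegressionModel(Array(1.0, 2.0, 3.0), Array(1.0, 4.0, 6.0), isotonic = true)
+
+model.predict(0.5) // below all boundaries: returns the lowest prediction, 1.0
+model.predict(2.0) // exact match: returns 4.0
+model.predict(2.5) // between 2.0 and 3.0: linear interpolation gives 5.0
+model.predict(9.0) // above all boundaries: returns the highest prediction, 6.0
+{% endhighlight %}
+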
+### Examples
+
+
+
+Data are read from a file where each line has the format label,feature
+(e.g. 4710.28,500.00). The data are split into training and test sets.
+A model is created using the training set, and a mean squared error is calculated from the predicted
+labels and real labels in the test set.
+
+{% highlight scala %}
+import org.apache.spark.mllib.regression.IsotonicRegression
+
+val data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt")
+
+// Create label, feature, weight tuples from input data with weight set to default value 1.0.
+val parsedData = data.map { line =>
+ val parts = line.split(',').map(_.toDouble)
+ (parts(0), parts(1), 1.0)
+}
+
+// Split data into training (60%) and test (40%) sets.
+val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
+val training = splits(0)
+val test = splits(1)
+
+// Create isotonic regression model from training data.
+// Isotonic parameter defaults to true so it is only shown for demonstration
+val model = new IsotonicRegression().setIsotonic(true).run(training)
+
+// Create tuples of predicted and real labels.
+val predictionAndLabel = test.map { point =>
+ val predictedLabel = model.predict(point._2)
+ (predictedLabel, point._1)
+}
+
+// Calculate mean squared error between predicted and real labels.
+val meanSquaredError = predictionAndLabel.map{case(p, l) => math.pow((p - l), 2)}.mean()
+println("Mean Squared Error = " + meanSquaredError)
+{% endhighlight %}
+
+
+
+Data are read from a file where each line has the format label,feature
+(e.g. 4710.28,500.00). The data are split into training and test sets.
+A model is created using the training set, and a mean squared error is calculated from the predicted
+labels and real labels in the test set.
+
+{% highlight java %}
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaDoubleRDD;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.mllib.regression.IsotonicRegression;
+import org.apache.spark.mllib.regression.IsotonicRegressionModel;
+import scala.Tuple2;
+import scala.Tuple3;
+
+JavaRDD<String> data = sc.textFile("data/mllib/sample_isotonic_regression_data.txt");
+
+// Create label, feature, weight tuples from input data with weight set to default value 1.0.
+JavaRDD<Tuple3<Double, Double, Double>> parsedData = data.map(
+  new Function<String, Tuple3<Double, Double, Double>>() {
+    public Tuple3<Double, Double, Double> call(String line) {
+ String[] parts = line.split(",");
+ return new Tuple3<>(new Double(parts[0]), new Double(parts[1]), 1.0);
+ }
+ }
+);
+
+// Split data into training (60%) and test (40%) sets.
+JavaRDD<Tuple3<Double, Double, Double>>[] splits = parsedData.randomSplit(new double[] {0.6, 0.4}, 11L);
+JavaRDD<Tuple3<Double, Double, Double>> training = splits[0];
+JavaRDD<Tuple3<Double, Double, Double>> test = splits[1];
+
+// Create isotonic regression model from training data.
+// Isotonic parameter defaults to true so it is only shown for demonstration
+final IsotonicRegressionModel model = new IsotonicRegression().setIsotonic(true).run(training);
+
+// Create tuples of predicted and real labels.
+JavaPairRDD<Double, Double> predictionAndLabel = test.mapToPair(
+  new PairFunction<Tuple3<Double, Double, Double>, Double, Double>() {
+    @Override public Tuple2<Double, Double> call(Tuple3<Double, Double, Double> point) {
+      Double predictedLabel = model.predict(point._2());
+      return new Tuple2<Double, Double>(predictedLabel, point._1());
+ }
+ }
+);
+
+// Calculate mean squared error between predicted and real labels.
+Double meanSquaredError = new JavaDoubleRDD(predictionAndLabel.map(
+  new Function<Tuple2<Double, Double>, Object>() {
+    @Override public Object call(Tuple2<Double, Double> pl) {
+ return Math.pow(pl._1() - pl._2(), 2);
+ }
+ }
+).rdd()).mean();
+
+System.out.println("Mean Squared Error = " + meanSquaredError);
+{% endhighlight %}
+
+
\ No newline at end of file
diff --git a/docs/mllib-linear-methods.md b/docs/mllib-linear-methods.md
index 44b7f67c57734..ffbd7ef1bff51 100644
--- a/docs/mllib-linear-methods.md
+++ b/docs/mllib-linear-methods.md
@@ -190,7 +190,7 @@ error.
{% highlight scala %}
import org.apache.spark.SparkContext
-import org.apache.spark.mllib.classification.SVMWithSGD
+import org.apache.spark.mllib.classification.{SVMModel, SVMWithSGD}
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.linalg.Vectors
@@ -222,6 +222,10 @@ val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val auROC = metrics.areaUnderROC()
println("Area under ROC = " + auROC)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = SVMModel.load(sc, "myModelPath")
{% endhighlight %}
The `SVMWithSGD.train()` method by default performs L2 regularization with the
@@ -304,6 +308,10 @@ public class SVMClassifier {
double auROC = metrics.areaUnderROC();
System.out.println("Area under ROC = " + auROC);
+
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ SVMModel sameModel = SVMModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -338,6 +346,8 @@ a dependency.
The following example shows how to load a sample dataset, build Logistic Regression model,
and make predictions with the resulting model to compute the training error.
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.regression import LabeledPoint
@@ -391,8 +401,9 @@ values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
{% highlight scala %}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.regression.LinearRegressionModel
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
import org.apache.spark.mllib.linalg.Vectors
// Load and parse the data
@@ -413,6 +424,10 @@ val valuesAndPreds = parsedData.map { point =>
}
val MSE = valuesAndPreds.map{case(v, p) => math.pow((v - p), 2)}.mean()
println("training Mean Squared Error = " + MSE)
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = LinearRegressionModel.load(sc, "myModelPath")
{% endhighlight %}
[`RidgeRegressionWithSGD`](api/scala/index.html#org.apache.spark.mllib.regression.RidgeRegressionWithSGD)
@@ -483,6 +498,10 @@ public class LinearRegression {
}
).rdd()).mean();
System.out.println("training Mean Squared Error = " + MSE);
+
+ // Save and load model
+ model.save(sc.sc(), "myModelPath");
+ LinearRegressionModel sameModel = LinearRegressionModel.load(sc.sc(), "myModelPath");
}
}
{% endhighlight %}
@@ -494,6 +513,8 @@ The example then uses LinearRegressionWithSGD to build a simple linear model to
values. We compute the mean squared error at the end to evaluate
[goodness of fit](http://en.wikipedia.org/wiki/Goodness_of_fit).
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD
from numpy import array
diff --git a/docs/mllib-migration-guides.md b/docs/mllib-migration-guides.md
new file mode 100644
index 0000000000000..4de2d9491ac2b
--- /dev/null
+++ b/docs/mllib-migration-guides.md
@@ -0,0 +1,67 @@
+---
+layout: global
+title: Old Migration Guides - MLlib
+displayTitle: MLlib - Old Migration Guides
+description: MLlib migration guides from before Spark SPARK_VERSION_SHORT
+---
+
+The migration guide for the current Spark version is kept on the [MLlib Programming Guide main page](mllib-guide.html#migration-guide).
+
+## From 1.1 to 1.2
+
+The only API changes in MLlib v1.2 are in
+[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
+which continues to be an experimental API in MLlib 1.2:
+
+1. *(Breaking change)* The Scala API for classification takes a named argument specifying the number
+of classes. In MLlib v1.1, this argument was called `numClasses` in Python and
+`numClassesForClassification` in Scala. In MLlib v1.2, the names are both set to `numClasses`.
+This `numClasses` parameter is specified either via
+[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
+or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
+static `trainClassifier` and `trainRegressor` methods.
+
+2. *(Breaking change)* The API for
+[`Node`](api/scala/index.html#org.apache.spark.mllib.tree.model.Node) has changed.
+This should generally not affect user code, unless the user manually constructs decision trees
+(instead of using the `trainClassifier` or `trainRegressor` methods).
+The tree `Node` now includes more information, including the probability of the predicted label
+(for classification).
+
+3. Printing methods' output has changed. The `toString` (Scala/Java) and `__repr__` (Python) methods used to print the full model; they now print a summary. For the full model, use `toDebugString`.
+
+Examples in the Spark distribution and examples in the
+[Decision Trees Guide](mllib-decision-tree.html#examples) have been updated accordingly.
+
+## From 1.0 to 1.1
+
+The only API changes in MLlib v1.1 are in
+[`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
+which continues to be an experimental API in MLlib 1.1:
+
+1. *(Breaking change)* The meaning of tree depth has been changed by 1 in order to match
+the implementations of trees in
+[scikit-learn](http://scikit-learn.org/stable/modules/classes.html#module-sklearn.tree)
+and in [rpart](http://cran.r-project.org/web/packages/rpart/index.html).
+In MLlib v1.0, a depth-1 tree had 1 leaf node, and a depth-2 tree had 1 root node and 2 leaf nodes.
+In MLlib v1.1, a depth-0 tree has 1 leaf node, and a depth-1 tree has 1 root node and 2 leaf nodes.
+This depth is specified by the `maxDepth` parameter in
+[`Strategy`](api/scala/index.html#org.apache.spark.mllib.tree.configuration.Strategy)
+or via [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree)
+static `trainClassifier` and `trainRegressor` methods.
+
+2. *(Non-breaking change)* We recommend using the newly added `trainClassifier` and `trainRegressor`
+methods to build a [`DecisionTree`](api/scala/index.html#org.apache.spark.mllib.tree.DecisionTree),
+rather than using the old parameter class `Strategy`. These new training methods explicitly
+separate classification and regression, and they replace specialized parameter types with
+simple `String` types.
+
+Examples of the new, recommended `trainClassifier` and `trainRegressor` are given in the
+[Decision Trees Guide](mllib-decision-tree.html#examples).
+
+## From 0.9 to 1.0
+
+In MLlib v1.0, we support both dense and sparse input in a unified way, which introduces a few
+breaking changes. If your data is sparse, please store it in a sparse format instead of dense to
+take advantage of sparsity in both storage and computation. Details are described below.
+
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index d5b044d94fdd7..5224a0b49a991 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -37,7 +37,7 @@ smoothing parameter `lambda` as input, and output a
can be used for evaluation and prediction.
{% highlight scala %}
-import org.apache.spark.mllib.classification.NaiveBayes
+import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
@@ -55,6 +55,10 @@ val model = NaiveBayes.train(training, lambda = 1.0)
val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
+
+// Save and load model
+model.save(sc, "myModelPath")
+val sameModel = NaiveBayesModel.load(sc, "myModelPath")
{% endhighlight %}
@@ -93,6 +97,10 @@ double accuracy = predictionAndLabel.filter(new Function,
return pl._1().equals(pl._2());
}
}).count() / (double) test.count();
+
+// Save and load model
+model.save(sc.sc(), "myModelPath");
+NaiveBayesModel sameModel = NaiveBayesModel.load(sc.sc(), "myModelPath");
{% endhighlight %}
@@ -105,6 +113,8 @@ smoothing parameter `lambda` as input, and output a
[NaiveBayesModel](api/python/pyspark.mllib.classification.NaiveBayesModel-class.html), which can be
used for evaluation and prediction.
+Note that the Python API does not yet support model save/load but will in the future.
+
{% highlight python %}
from pyspark.mllib.regression import LabeledPoint
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 7a5cadc171d6d..6816671ffbf46 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -86,7 +86,7 @@ follows:
   <tr>
-    <td>spark.history.fs.updateInterval</td>
+    <td>spark.history.fs.update.interval.seconds</td>
     <td>10</td>
     <td>
       The period, in seconds, at which information displayed by this history server is updated.
@@ -145,11 +145,36 @@ follows:
If disabled, no access control checks are made.
+  <tr>
+    <td>spark.history.fs.cleaner.enabled</td>
+    <td>false</td>
+    <td>
+      Specifies whether the History Server should periodically clean up event logs from storage.
+    </td>
+  </tr>
+  <tr>
+    <td>spark.history.fs.cleaner.interval.seconds</td>
+    <td>86400</td>
+    <td>
+      How often the job history cleaner checks for files to delete, in seconds. Defaults to 86400 (one day).
+      Files are only deleted if they are older than spark.history.fs.cleaner.maxAge.seconds.
+    </td>
+  </tr>
+  <tr>
+    <td>spark.history.fs.cleaner.maxAge.seconds</td>
+    <td>3600 * 24 * 7</td>
+    <td>
+      Job history files older than this many seconds will be deleted when the history cleaner runs.
+      Defaults to 3600 * 24 * 7 (1 week).
+    </td>
+  </tr>
Note that in all of these UIs, the tables are sortable by clicking their headers,
making it easy to identify slow tasks, data skew, etc.
+Note that the history server only displays completed Spark jobs. One way to signal the completion of a Spark job is to stop the Spark Context explicitly (`sc.stop()`); in Python, you can use the `with SparkContext() as sc:` construct to handle Spark Context setup and teardown, and the job history will still be shown in the UI.
+
# Metrics
Spark has a configurable metrics system based on the
@@ -176,6 +201,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the
* `JmxSink`: Registers metrics for viewing in a JMX console.
* `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data.
* `GraphiteSink`: Sends metrics to a Graphite node.
+* `Slf4jSink`: Sends metrics to slf4j as log entries.
Spark also supports a Ganglia sink which is not included in the default build due to
licensing restrictions:
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 6b365e83fb56d..7b0701828878e 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -173,8 +173,11 @@ in-process.
In the Spark shell, a special interpreter-aware SparkContext is already created for you, in the
variable called `sc`. Making your own SparkContext will not work. You can set which master the
context connects to using the `--master` argument, and you can add JARs to the classpath
-by passing a comma-separated list to the `--jars` argument.
-For example, to run `bin/spark-shell` on exactly four cores, use:
+by passing a comma-separated list to the `--jars` argument. You can also add dependencies
+(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates
+to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. Sonatype)
+can be passed to the `--repositories` argument. For example, to run `bin/spark-shell` on exactly
+four cores, use:
{% highlight bash %}
$ ./bin/spark-shell --master local[4]
@@ -186,6 +189,12 @@ Or, to also add `code.jar` to its classpath, use:
$ ./bin/spark-shell --master local[4] --jars code.jar
{% endhighlight %}
+To include a dependency using maven coordinates:
+
+{% highlight bash %}
+$ ./bin/spark-shell --master local[4] --packages "org.example:example:0.1"
+{% endhighlight %}
+
For a complete list of options, run `spark-shell --help`. Behind the scenes,
`spark-shell` invokes the more general [`spark-submit` script](submitting-applications.html).
@@ -196,7 +205,11 @@ For a complete list of options, run `spark-shell --help`. Behind the scenes,
In the PySpark shell, a special interpreter-aware SparkContext is already created for you, in the
variable called `sc`. Making your own SparkContext will not work. You can set which master the
context connects to using the `--master` argument, and you can add Python .zip, .egg or .py files
-to the runtime path by passing a comma-separated list to `--py-files`.
+to the runtime path by passing a comma-separated list to `--py-files`. You can also add dependencies
+(e.g. Spark Packages) to your shell session by supplying a comma-separated list of maven coordinates
+to the `--packages` argument. Any additional repositories where dependencies might exist (e.g. Sonatype)
+can be passed to the `--repositories` argument. Any Python dependencies a Spark Package has (listed in
+the requirements.txt of that package) must be manually installed using pip when necessary.
For example, to run `bin/pyspark` on exactly four cores, use:
{% highlight bash %}
@@ -322,7 +335,7 @@ Apart from text files, Spark's Scala API also supports several other data format
* For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html). In addition, Spark allows you to specify native types for a few common Writables; for example, `sequenceFile[Int, String]` will automatically read IntWritables and Texts.
-* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`).
+* For other Hadoop InputFormats, you can use the `SparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `SparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`).
* `RDD.saveAsObjectFile` and `SparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD.
@@ -354,7 +367,7 @@ Apart from text files, Spark's Java API also supports several other data formats
* For [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), use SparkContext's `sequenceFile[K, V]` method where `K` and `V` are the types of key and values in the file. These should be subclasses of Hadoop's [Writable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Writable.html) interface, like [IntWritable](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/IntWritable.html) and [Text](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/io/Text.html).
-* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`).
+* For other Hadoop InputFormats, you can use the `JavaSparkContext.hadoopRDD` method, which takes an arbitrary `JobConf` and input format class, key class and value class. Set these the same way you would for a Hadoop job with your input source. You can also use `JavaSparkContext.newAPIHadoopRDD` for InputFormats based on the "new" MapReduce API (`org.apache.hadoop.mapreduce`).
* `JavaRDD.saveAsObjectFile` and `JavaSparkContext.objectFile` support saving an RDD in a simple format consisting of serialized Java objects. While this is not as efficient as specialized formats like Avro, it offers an easy way to save any RDD.
@@ -975,7 +988,7 @@ for details.
    <td> take(n) </td>
-   <td> Return an array with the first n elements of the dataset.  Note that this is currently not executed in parallel. Instead, the driver program computes all the elements. </td>
+   <td> Return an array with the first n elements of the dataset. </td>
  </tr>
  <tr>
    <td> takeSample(withReplacement, num, [seed]) </td>
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 78358499fd01f..db1173a06b0b1 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -197,7 +197,11 @@ See the [configuration page](configuration.html) for information on Spark config
   <tr>
     <td>spark.mesos.coarse</td>
     <td>false</td>
     <td>
-      Set the run mode for Spark on Mesos. For more information about the run mode, refer to #Mesos Run Mode section above.
+      If set to "true", runs over Mesos clusters in
+      "coarse-grained" sharing mode,
+      where Spark acquires one long-lived Mesos task on each machine instead of one Mesos task per
+      Spark task. This gives lower-latency scheduling for short queries, but leaves resources in use
+      for the whole duration of the Spark job.
     </td>
   </tr>
@@ -211,19 +215,23 @@ See the [configuration page](configuration.html) for information on Spark config
   <tr>
     <td>spark.mesos.executor.home</td>
-    <td>SPARK_HOME</td>
+    <td>driver side SPARK_HOME</td>
     <td>
-      The location where the mesos executor will look for Spark binaries to execute, and uses the SPARK_HOME setting on default.
-      This variable is only used when no spark.executor.uri is provided, and assumes Spark is installed on the specified location
-      on each slave.
+      Set the directory in which Spark is installed on the executors in Mesos. By default, the
+      executors will simply use the driver's Spark home directory, which may not be visible to
+      them. Note that this is only relevant if a Spark binary package is not specified through
+      spark.executor.uri.
     </td>
   </tr>
   <tr>
     <td>spark.mesos.executor.memoryOverhead</td>
-    <td>384</td>
+    <td>executor memory * 0.07, with minimum of 384</td>
     <td>
-      The amount of memory that Mesos executor will request for the task to account for the overhead of running the executor itself.
-      The final total amount of memory allocated is the maximum value between executor memory plus memoryOverhead, and overhead fraction (1.07) plus the executor memory.
+      This value is an additive for spark.executor.memory, specified in MiB,
+      which is used to calculate the total Mesos task memory. A value of 384
+      implies a 384MiB overhead. Additionally, there is a hard-coded 7% minimum
+      overhead. The final overhead will be the larger of either
+      `spark.mesos.executor.memoryOverhead` or 7% of `spark.executor.memory`.
     </td>
   </tr>
diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md
index 5c6084fb46255..74d8653a8b845 100644
--- a/docs/spark-standalone.md
+++ b/docs/spark-standalone.md
@@ -222,8 +222,7 @@ SPARK_WORKER_OPTS supports the following system properties:
     <td>false</td>
     <td>
       Enable periodic cleanup of worker / application directories.  Note that this only affects standalone
-      mode, as YARN works differently. Applications directories are cleaned up regardless of whether
-      the application is still running.
+      mode, as YARN works differently. Only the directories of stopped applications are cleaned up.
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 38f617d0c836c..0146a4ed1b745 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -14,10 +14,10 @@ title: Spark SQL
Spark SQL allows relational queries expressed in SQL, HiveQL, or Scala to be executed using
Spark. At the core of this component is a new type of RDD,
-[SchemaRDD](api/scala/index.html#org.apache.spark.sql.SchemaRDD). SchemaRDDs are composed of
+[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of
[Row](api/scala/index.html#org.apache.spark.sql.package@Row:org.apache.spark.sql.catalyst.expressions.Row.type) objects, along with
-a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table
-in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
+a schema that describes the data types of each column in the row. A DataFrame is similar to a table
+in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
All of the examples on this page use sample data included in the Spark distribution and can be run in the `spark-shell`.
@@ -27,10 +27,10 @@ All of the examples on this page use sample data included in the Spark distribut
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using
Spark. At the core of this component is a new type of RDD,
-[JavaSchemaRDD](api/scala/index.html#org.apache.spark.sql.api.java.JavaSchemaRDD). JavaSchemaRDDs are composed of
+[DataFrame](api/scala/index.html#org.apache.spark.sql.DataFrame). DataFrames are composed of
[Row](api/scala/index.html#org.apache.spark.sql.api.java.Row) objects, along with
-a schema that describes the data types of each column in the row. A JavaSchemaRDD is similar to a table
-in a traditional relational database. A JavaSchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
+a schema that describes the data types of each column in the row. A DataFrame is similar to a table
+in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
@@ -38,10 +38,10 @@ file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](
Spark SQL allows relational queries expressed in SQL or HiveQL to be executed using
Spark. At the core of this component is a new type of RDD,
-[SchemaRDD](api/python/pyspark.sql.SchemaRDD-class.html). SchemaRDDs are composed of
+[DataFrame](api/python/pyspark.sql.html#pyspark.sql.DataFrame). DataFrames are composed of
[Row](api/python/pyspark.sql.Row-class.html) objects, along with
-a schema that describes the data types of each column in the row. A SchemaRDD is similar to a table
-in a traditional relational database. A SchemaRDD can be created from an existing RDD, a [Parquet](http://parquet.io)
+a schema that describes the data types of each column in the row. A DataFrame is similar to a table
+in a traditional relational database. A DataFrame can be created from an existing RDD, a [Parquet](http://parquet.io)
file, a JSON dataset, or by running HiveQL against data stored in [Apache Hive](http://hive.apache.org/).
All of the examples on this page use sample data included in the Spark distribution and can be run in the `pyspark` shell.
@@ -65,8 +65,8 @@ descendants. To create a basic SQLContext, all you need is a SparkContext.
val sc: SparkContext // An existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD.
-import sqlContext.createSchemaRDD
+// this is used to implicitly convert an RDD to a DataFrame.
+import sqlContext.implicits._
{% endhighlight %}
In addition to the basic SQLContext, you can also create a HiveContext, which provides a
@@ -84,12 +84,12 @@ feature parity with a HiveContext.
The entry point into all relational functionality in Spark is the
-[JavaSQLContext](api/scala/index.html#org.apache.spark.sql.api.java.JavaSQLContext) class, or one
-of its descendants. To create a basic JavaSQLContext, all you need is a JavaSparkContext.
+[SQLContext](api/scala/index.html#org.apache.spark.sql.SQLContext) class, or one
+of its descendants. To create a basic SQLContext, all you need is a JavaSparkContext.
{% highlight java %}
JavaSparkContext sc = ...; // An existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
+SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
{% endhighlight %}
In addition to the basic SQLContext, you can also create a HiveContext, which provides a strict
@@ -138,39 +138,39 @@ default is "hiveql", though "sql" is also available. Since the HiveQL parser is
# Data Sources
-Spark SQL supports operating on a variety of data sources through the `SchemaRDD` interface.
-A SchemaRDD can be operated on as normal RDDs and can also be registered as a temporary table.
-Registering a SchemaRDD as a table allows you to run SQL queries over its data. This section
-describes the various methods for loading data into a SchemaRDD.
+Spark SQL supports operating on a variety of data sources through the `DataFrame` interface.
+A DataFrame can be operated on as a normal RDD and can also be registered as a temporary table.
+Registering a DataFrame as a table allows you to run SQL queries over its data. This section
+describes the various methods for loading data into a DataFrame.
## RDDs
-Spark SQL supports two different methods for converting existing RDDs into SchemaRDDs. The first
+Spark SQL supports two different methods for converting existing RDDs into DataFrames. The first
method uses reflection to infer the schema of an RDD that contains specific types of objects. This
reflection based approach leads to more concise code and works well when you already know the schema
while writing your Spark application.
-The second method for creating SchemaRDDs is through a programmatic interface that allows you to
+The second method for creating DataFrames is through a programmatic interface that allows you to
construct a schema and then apply it to an existing RDD. While this method is more verbose, it allows
-you to construct SchemaRDDs when the columns and their types are not known until runtime.
+you to construct DataFrames when the columns and their types are not known until runtime.
### Inferring the Schema Using Reflection
-The Scala interaface for Spark SQL supports automatically converting an RDD containing case classes
-to a SchemaRDD. The case class
+The Scala interface for Spark SQL supports automatically converting an RDD containing case classes
+to a DataFrame. The case class
defines the schema of the table. The names of the arguments to the case class are read using
reflection and become the names of the columns. Case classes can also be nested or contain complex
-types such as Sequences or Arrays. This RDD can be implicitly converted to a SchemaRDD and then be
+types such as Sequences or Arrays. This RDD can be implicitly converted to a DataFrame and then be
registered as a table. Tables can be used in subsequent SQL statements.
{% highlight scala %}
// sc is an existing SparkContext.
val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD.
-import sqlContext.createSchemaRDD
+// this is used to implicitly convert an RDD to a DataFrame.
+import sqlContext.implicits._
// Define the schema using a case class.
// Note: Case classes in Scala 2.10 can support only up to 22 fields. To work around this limit,
@@ -184,7 +184,7 @@ people.registerTempTable("people")
// SQL statements can be run by using the sql methods provided by sqlContext.
val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
-// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
{% endhighlight %}
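The hunk above elides the middle of this example, so here is a condensed, illustrative sketch of the reflection-based path; it assumes the `sqlContext.implicits._` import introduced above and a `people.txt` of `name, age` lines, and is not the guide's exact listing.
{% highlight scala %}
// Define the schema with a case class; toDF() is provided by sqlContext.implicits._.
case class Person(name: String, age: Int)

val people = sc.textFile("examples/src/main/resources/people.txt")
  .map(_.split(","))
  .map(p => Person(p(0), p(1).trim.toInt))
  .toDF()
people.registerTempTable("people")
{% endhighlight %}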
@@ -194,7 +194,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
Spark SQL supports automatically converting an RDD of [JavaBeans](http://stackoverflow.com/questions/3295496/what-is-a-javabean-exactly)
-into a Schema RDD. The BeanInfo, obtained using reflection, defines the schema of the table.
+into a DataFrame. The BeanInfo, obtained using reflection, defines the schema of the table.
Currently, Spark SQL does not support JavaBeans that contain
nested or complex types such as Lists or Arrays. You can create a JavaBean by creating a
class that implements Serializable and has getters and setters for all of its fields.
@@ -225,12 +225,12 @@ public static class Person implements Serializable {
{% endhighlight %}
-A schema can be applied to an existing RDD by calling `applySchema` and providing the Class object
+A schema can be applied to an existing RDD by calling `createDataFrame` and providing the Class object
for the JavaBean.
{% highlight java %}
// sc is an existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
+SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// Load a text file and convert each line to a JavaBean.
JavaRDD<Person> people = sc.textFile("examples/src/main/resources/people.txt").map(
@@ -247,13 +247,13 @@ JavaRDD people = sc.textFile("examples/src/main/resources/people.txt").m
});
// Apply a schema to an RDD of JavaBeans and register it as a table.
-JavaSchemaRDD schemaPeople = sqlContext.applySchema(people, Person.class);
+DataFrame schemaPeople = sqlContext.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
-JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
+DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
-// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
public String call(Row row) {
@@ -267,7 +267,7 @@ List teenagerNames = teenagers.map(new Function() {
-Spark SQL can convert an RDD of Row objects to a SchemaRDD, inferring the datatypes. Rows are constructed by passing a list of
+Spark SQL can convert an RDD of Row objects to a DataFrame, inferring the datatypes. Rows are constructed by passing a list of
key/value pairs as kwargs to the Row class. The keys of this list define the column names of the table,
and the types are inferred by looking at the first row. Since we currently only look at the first
row, it is important that there is no missing data in the first row of the RDD. In future versions we
@@ -284,11 +284,11 @@ lines = sc.textFile("examples/src/main/resources/people.txt")
parts = lines.map(lambda l: l.split(","))
people = parts.map(lambda p: Row(name=p[0], age=int(p[1])))
-# Infer the schema, and register the SchemaRDD as a table.
+# Infer the schema, and register the DataFrame as a table.
schemaPeople = sqlContext.inferSchema(people)
schemaPeople.registerTempTable("people")
-# SQL can be run over SchemaRDDs that have been registered as a table.
+# SQL can be run over DataFrames that have been registered as a table.
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
# The results of SQL queries are RDDs and support all the normal RDD operations.
@@ -310,12 +310,12 @@ for teenName in teenNames.collect():
When case classes cannot be defined ahead of time (for example,
the structure of records is encoded in a string, or a text dataset will be parsed
and fields will be projected differently for different users),
-a `SchemaRDD` can be created programmatically with three steps.
+a `DataFrame` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
by `SQLContext`.
For example:
@@ -341,15 +341,15 @@ val schema =
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))
// Apply the schema to the RDD.
-val peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema)
+val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)
-// Register the SchemaRDD as a table.
-peopleSchemaRDD.registerTempTable("people")
+// Register the DataFrames as a table.
+peopleDataFrame.registerTempTable("people")
// SQL statements can be run by using the sql methods provided by sqlContext.
val results = sqlContext.sql("SELECT name FROM people")
-// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
results.map(t => "Name: " + t(0)).collect().foreach(println)
{% endhighlight %}
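Because the diff shows only fragments of this example, here is a minimal end-to-end sketch of the three steps, assuming the post-1.3 layout where the type classes live in `org.apache.spark.sql.types`; treat it as illustrative rather than the guide's exact listing.
{% highlight scala %}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{StructType, StructField, StringType}

// Step 1: create an RDD of Rows from the raw text file.
val people = sc.textFile("examples/src/main/resources/people.txt")
val rowRDD = people.map(_.split(",")).map(p => Row(p(0), p(1).trim))

// Step 2: build the schema as a StructType matching the Rows above.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", StringType, nullable = true)))

// Step 3: apply the schema to the RDD of Rows.
val peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema)
peopleDataFrame.registerTempTable("people")
{% endhighlight %}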
@@ -362,13 +362,13 @@ results.map(t => "Name: " + t(0)).collect().foreach(println)
When JavaBean classes cannot be defined ahead of time (for example,
the structure of records is encoded in a string, or a text dataset will be parsed and
fields will be projected differently for different users),
-a `SchemaRDD` can be created programmatically with three steps.
+a `DataFrame` can be created programmatically with three steps.
1. Create an RDD of `Row`s from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
`Row`s in the RDD created in Step 1.
-3. Apply the schema to the RDD of `Row`s via `applySchema` method provided
-by `JavaSQLContext`.
+3. Apply the schema to the RDD of `Row`s via `createDataFrame` method provided
+by `SQLContext`.
For example:
{% highlight java %}
@@ -381,7 +381,7 @@ import org.apache.spark.sql.api.java.StructField
import org.apache.spark.sql.api.java.Row
// sc is an existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
+SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// Load a text file and convert each line to a JavaBean.
JavaRDD<String> people = sc.textFile("examples/src/main/resources/people.txt");
@@ -406,15 +406,15 @@ JavaRDD rowRDD = people.map(
});
// Apply the schema to the RDD.
-JavaSchemaRDD peopleSchemaRDD = sqlContext.applySchema(rowRDD, schema);
+DataFrame peopleDataFrame = sqlContext.createDataFrame(rowRDD, schema);
-// Register the SchemaRDD as a table.
-peopleSchemaRDD.registerTempTable("people");
+// Register the DataFrame as a table.
+peopleDataFrame.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
-JavaSchemaRDD results = sqlContext.sql("SELECT name FROM people");
+DataFrame results = sqlContext.sql("SELECT name FROM people");
-// The results of SQL queries are SchemaRDDs and support all the normal RDD operations.
+// The results of SQL queries are DataFrames and support all the normal RDD operations.
// The columns of a row in the result can be accessed by ordinal.
List<String> names = results.map(new Function<Row, String>() {
public String call(Row row) {
@@ -431,12 +431,12 @@ List names = results.map(new Function() {
When a dictionary of kwargs cannot be defined ahead of time (for example,
the structure of records is encoded in a string, or a text dataset will be parsed and
fields will be projected differently for different users),
-a `SchemaRDD` can be created programmatically with three steps.
+a `DataFrame` can be created programmatically with three steps.
1. Create an RDD of tuples or lists from the original RDD;
2. Create the schema represented by a `StructType` matching the structure of
tuples or lists in the RDD created in step 1.
-3. Apply the schema to the RDD via `applySchema` method provided by `SQLContext`.
+3. Apply the schema to the RDD via `createDataFrame` method provided by `SQLContext`.
For example:
{% highlight python %}
@@ -458,12 +458,12 @@ fields = [StructField(field_name, StringType(), True) for field_name in schemaSt
schema = StructType(fields)
# Apply the schema to the RDD.
-schemaPeople = sqlContext.applySchema(people, schema)
+schemaPeople = sqlContext.createDataFrame(people, schema)
-# Register the SchemaRDD as a table.
+# Register the DataFrame as a table.
schemaPeople.registerTempTable("people")
-# SQL can be run over SchemaRDDs that have been registered as a table.
+# SQL can be run over DataFrames that have been registered as a table.
results = sqlContext.sql("SELECT name FROM people")
# The results of SQL queries are RDDs and support all the normal RDD operations.
@@ -493,16 +493,16 @@ Using the data from the above example:
{% highlight scala %}
// sqlContext from the previous example is used in this example.
-// createSchemaRDD is used to implicitly convert an RDD to a SchemaRDD.
-import sqlContext.createSchemaRDD
+// This is used to implicitly convert an RDD to a DataFrame.
+import sqlContext.implicits._
val people: RDD[Person] = ... // An RDD of case class objects, from the previous example.
-// The RDD is implicitly converted to a SchemaRDD by createSchemaRDD, allowing it to be stored using Parquet.
+// The RDD is implicitly converted to a DataFrame by implicits, allowing it to be stored using Parquet.
people.saveAsParquetFile("people.parquet")
// Read in the parquet file created above. Parquet files are self-describing so the schema is preserved.
-// The result of loading a Parquet file is also a SchemaRDD.
+// The result of loading a Parquet file is also a DataFrame.
val parquetFile = sqlContext.parquetFile("people.parquet")
//Parquet files can also be registered as tables and then used in SQL statements.
@@ -518,18 +518,18 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
{% highlight java %}
// sqlContext from the previous example is used in this example.
-JavaSchemaRDD schemaPeople = ... // The JavaSchemaRDD from the previous example.
+DataFrame schemaPeople = ... // The DataFrame from the previous example.
-// JavaSchemaRDDs can be saved as Parquet files, maintaining the schema information.
+// DataFrames can be saved as Parquet files, maintaining the schema information.
schemaPeople.saveAsParquetFile("people.parquet");
// Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
-// The result of loading a parquet file is also a JavaSchemaRDD.
-JavaSchemaRDD parquetFile = sqlContext.parquetFile("people.parquet");
+// The result of loading a parquet file is also a DataFrame.
+DataFrame parquetFile = sqlContext.parquetFile("people.parquet");
//Parquet files can also be registered as tables and then used in SQL statements.
parquetFile.registerTempTable("parquetFile");
-JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
+DataFrame teenagers = sqlContext.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19");
List<String> teenagerNames = teenagers.map(new Function<Row, String>() {
public String call(Row row) {
return "Name: " + row.getString(0);
@@ -544,13 +544,13 @@ List teenagerNames = teenagers.map(new Function() {
{% highlight python %}
# sqlContext from the previous example is used in this example.
-schemaPeople # The SchemaRDD from the previous example.
+schemaPeople # The DataFrame from the previous example.
-# SchemaRDDs can be saved as Parquet files, maintaining the schema information.
+# DataFrames can be saved as Parquet files, maintaining the schema information.
schemaPeople.saveAsParquetFile("people.parquet")
# Read in the Parquet file created above. Parquet files are self-describing so the schema is preserved.
-# The result of loading a parquet file is also a SchemaRDD.
+# The result of loading a parquet file is also a DataFrame.
parquetFile = sqlContext.parquetFile("people.parquet")
# Parquet files can also be registered as tables and then used in SQL statements.
@@ -629,7 +629,7 @@ Configuration of Parquet can be done using the `setConf` method on SQLContext or
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD.
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
This conversion can be done using one of two methods in a SQLContext:
* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
@@ -646,7 +646,7 @@ val sqlContext = new org.apache.spark.sql.SQLContext(sc)
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
val path = "examples/src/main/resources/people.json"
-// Create a SchemaRDD from the file(s) pointed to by path
+// Create a DataFrame from the file(s) pointed to by path
val people = sqlContext.jsonFile(path)
// The inferred schema can be visualized using the printSchema() method.
@@ -655,13 +655,13 @@ people.printSchema()
// |-- age: integer (nullable = true)
// |-- name: string (nullable = true)
-// Register this SchemaRDD as a table.
+// Register this DataFrame as a table.
people.registerTempTable("people")
// SQL statements can be run by using the sql methods provided by sqlContext.
val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
-// Alternatively, a SchemaRDD can be created for a JSON dataset represented by
+// Alternatively, a DataFrame can be created for a JSON dataset represented by
// an RDD[String] storing one JSON object per string.
val anotherPeopleRDD = sc.parallelize(
"""{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}""" :: Nil)
@@ -671,8 +671,8 @@ val anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD)
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a JavaSchemaRDD.
-This conversion can be done using one of two methods in a JavaSQLContext :
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
+This conversion can be done using one of two methods in a SQLContext:
* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
* `jsonRDD` - loads data from an existing RDD where each element of the RDD is a string containing a JSON object.
@@ -683,13 +683,13 @@ a regular multi-line JSON file will most often fail.
{% highlight java %}
// sc is an existing JavaSparkContext.
-JavaSQLContext sqlContext = new org.apache.spark.sql.api.java.JavaSQLContext(sc);
+SQLContext sqlContext = new org.apache.spark.sql.SQLContext(sc);
// A JSON dataset is pointed to by path.
// The path can be either a single text file or a directory storing text files.
String path = "examples/src/main/resources/people.json";
-// Create a JavaSchemaRDD from the file(s) pointed to by path
-JavaSchemaRDD people = sqlContext.jsonFile(path);
+// Create a DataFrame from the file(s) pointed to by path
+DataFrame people = sqlContext.jsonFile(path);
// The inferred schema can be visualized using the printSchema() method.
people.printSchema();
@@ -697,23 +697,23 @@ people.printSchema();
// |-- age: integer (nullable = true)
// |-- name: string (nullable = true)
-// Register this JavaSchemaRDD as a table.
+// Register this DataFrame as a table.
people.registerTempTable("people");
// SQL statements can be run by using the sql methods provided by sqlContext.
-JavaSchemaRDD teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
+DataFrame teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19");
-// Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by
+// Alternatively, a DataFrame can be created for a JSON dataset represented by
// an RDD[String] storing one JSON object per string.
List<String> jsonData = Arrays.asList(
"{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}");
JavaRDD<String> anotherPeopleRDD = sc.parallelize(jsonData);
-JavaSchemaRDD anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD);
+DataFrame anotherPeople = sqlContext.jsonRDD(anotherPeopleRDD);
{% endhighlight %}
-Spark SQL can automatically infer the schema of a JSON dataset and load it as a SchemaRDD.
+Spark SQL can automatically infer the schema of a JSON dataset and load it as a DataFrame.
This conversion can be done using one of two methods in a SQLContext:
* `jsonFile` - loads data from a directory of JSON files where each line of the files is a JSON object.
@@ -731,7 +731,7 @@ sqlContext = SQLContext(sc)
# A JSON dataset is pointed to by path.
# The path can be either a single text file or a directory storing text files.
path = "examples/src/main/resources/people.json"
-# Create a SchemaRDD from the file(s) pointed to by path
+# Create a DataFrame from the file(s) pointed to by path
people = sqlContext.jsonFile(path)
# The inferred schema can be visualized using the printSchema() method.
@@ -740,13 +740,13 @@ people.printSchema()
# |-- age: integer (nullable = true)
# |-- name: string (nullable = true)
-# Register this SchemaRDD as a table.
+# Register this DataFrame as a table.
people.registerTempTable("people")
# SQL statements can be run by using the sql methods provided by sqlContext.
teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
-# Alternatively, a SchemaRDD can be created for a JSON dataset represented by
+# Alternatively, a DataFrame can be created for a JSON dataset represented by
# an RDD[String] storing one JSON object per string.
anotherPeopleRDD = sc.parallelize([
'{"name":"Yin","address":{"city":"Columbus","state":"Ohio"}}'])
@@ -792,14 +792,14 @@ sqlContext.sql("FROM src SELECT key, value").collect().foreach(println)
-When working with Hive one must construct a `JavaHiveContext`, which inherits from `JavaSQLContext`, and
+When working with Hive one must construct a `HiveContext`, which inherits from `SQLContext`, and
adds support for finding tables in the MetaStore and writing queries using HiveQL. In addition to
-the `sql` method a `JavaHiveContext` also provides an `hql` methods, which allows queries to be
+the `sql` method, a `HiveContext` also provides an `hql` method, which allows queries to be
expressed in HiveQL.
{% highlight java %}
// sc is an existing JavaSparkContext.
-JavaHiveContext sqlContext = new org.apache.spark.sql.hive.api.java.HiveContext(sc);
+HiveContext sqlContext = new org.apache.spark.sql.hive.HiveContext(sc);
sqlContext.sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)");
sqlContext.sql("LOAD DATA LOCAL INPATH 'examples/src/main/resources/kv1.txt' INTO TABLE src");
@@ -841,7 +841,7 @@ turning on some experimental options.
## Caching Data In Memory
-Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `schemaRDD.cache()`.
+Spark SQL can cache tables using an in-memory columnar format by calling `sqlContext.cacheTable("tableName")` or `dataFrame.cache()`.
Then Spark SQL will scan only required columns and will automatically tune compression to minimize
memory usage and GC pressure. You can call `sqlContext.uncacheTable("tableName")` to remove the table from memory.
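As a minimal sketch (assuming the `people` temporary table registered in the earlier examples):
{% highlight scala %}
// Cache the table in the in-memory columnar format, query it, then release the memory.
sqlContext.cacheTable("people")
val teenagers = sqlContext.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19")
teenagers.collect()
sqlContext.uncacheTable("people")
{% endhighlight %}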
@@ -1161,7 +1161,7 @@ teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
The DSL uses Scala symbols to represent columns in the underlying table, which are identifiers
prefixed with a tick (`'`). Implicit conversions turn these symbols into expressions that are
evaluated by the SQL execution engine. A full list of the functions supported can be found in the
-[ScalaDoc](api/scala/index.html#org.apache.spark.sql.SchemaRDD).
+[ScalaDoc](api/scala/index.html#org.apache.spark.sql.DataFrame).
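For illustration only, a query over the earlier `people` DataFrame written with the symbol-based DSL might look like the following; the exact set of operators is documented in the ScalaDoc linked above.
{% highlight scala %}
// 'age and 'name are Scala symbols; implicit conversions turn them into column expressions.
val teenagers = people.where('age >= 13).where('age <= 19).select('name)
teenagers.map(t => "Name: " + t(0)).collect().foreach(println)
{% endhighlight %}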
diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md
index ac01dd3d8019a..40e17246fea83 100644
--- a/docs/streaming-flume-integration.md
+++ b/docs/streaming-flume-integration.md
@@ -64,7 +64,7 @@ configuring Flume agents.
3. **Deploying:** Package `spark-streaming-flume_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide).
-## Approach 2 (Experimental): Pull-based Approach using a Custom Sink
+## Approach 2: Pull-based Approach using a Custom Sink
Instead of Flume pushing data directly to Spark Streaming, this approach runs a custom Flume sink that allows the following.
- Flume pushes data into the sink, and the data stays buffered.
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 96fb12ce5e0b9..815c98713b738 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -878,6 +878,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
{% endhighlight %}
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Scala code, take a look at the example
+[StatefulNetworkWordCount.scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala).
+
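As a pointer (the linked example is authoritative), the update function described here is typically along these lines:
{% highlight scala %}
// newValues holds the new 1's for this batch; runningCount is the previous total, if any.
def updateFunction(newValues: Seq[Int], runningCount: Option[Int]): Option[Int] = {
  Some(newValues.sum + runningCount.getOrElse(0))
}
{% endhighlight %}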
@@ -899,6 +905,12 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
JavaPairDStream<String, Integer> runningCounts = pairs.updateStateByKey(updateFunction);
{% endhighlight %}
+The update function will be called for each word, with `newValues` having a sequence of 1's (from
+the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
+Java code, take a look at the example
+[JavaStatefulNetworkWordCount.java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java).
+
@@ -916,14 +928,14 @@ This is applied on a DStream containing words (say, the `pairs` DStream containi
runningCounts = pairs.updateStateByKey(updateFunction)
{% endhighlight %}
-
-
-
The update function will be called for each word, with `newValues` having a sequence of 1's (from
the `(word, 1)` pairs) and the `runningCount` having the previous count. For the complete
-Scala code, take a look at the example
+Python code, take a look at the example
[stateful_network_wordcount.py]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/stateful_network_wordcount.py).
+
+
+
Note that using `updateStateByKey` requires the checkpoint directory to be configured, which is
discussed in detail in the [checkpointing](#checkpointing) section.
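A minimal sketch of that prerequisite, with an illustrative (not prescribed) checkpoint directory and assuming `ssc` is the StreamingContext:
{% highlight scala %}
ssc.checkpoint("hdfs:///tmp/streaming-checkpoints")  // any fault-tolerant directory works
val runningCounts = pairs.updateStateByKey[Int](updateFunction _)
{% endhighlight %}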
diff --git a/docs/submitting-applications.md b/docs/submitting-applications.md
index 14a87f8436984..57b074778f2b0 100644
--- a/docs/submitting-applications.md
+++ b/docs/submitting-applications.md
@@ -174,6 +174,11 @@ This can use up a significant amount of space over time and will need to be clea
is handled automatically, and with Spark standalone, automatic cleanup can be configured with the
`spark.worker.cleanup.appDataTtl` property.
+Users may also include any other dependencies by supplying a comma-delimited list of Maven coordinates
+with `--packages`. All transitive dependencies will be handled when using this flag. Additional
+repositories (or resolvers in SBT) can be added in a comma-delimited fashion with the flag `--repositories`.
+These flags can be used with `pyspark`, `spark-shell`, and `spark-submit` to include Spark Packages.
+
For Python, the equivalent `--py-files` option can be used to distribute `.egg`, `.zip` and `.py` libraries
to executors.
diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py
index 3e4c49c0e1db6..c59ab565c6862 100755
--- a/ec2/spark_ec2.py
+++ b/ec2/spark_ec2.py
@@ -42,7 +42,7 @@
from optparse import OptionParser
from sys import stderr
-SPARK_EC2_VERSION = "1.2.0"
+SPARK_EC2_VERSION = "1.2.1"
SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__))
VALID_SPARK_VERSIONS = set([
@@ -58,6 +58,7 @@
"1.1.0",
"1.1.1",
"1.2.0",
+ "1.2.1",
])
DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION
@@ -112,6 +113,7 @@ def parse_args():
version="%prog {v}".format(v=SPARK_EC2_VERSION),
usage="%prog [options] <action> <cluster_name>\n\n"
+ "<action> can be: launch, destroy, login, stop, start, get-master, reboot-slaves")
+
parser.add_option(
"-s", "--slaves", type="int", default=1,
help="Number of slaves to launch (default: %default)")
@@ -133,13 +135,15 @@ def parse_args():
help="Master instance type (leave empty for same as instance-type)")
parser.add_option(
"-r", "--region", default="us-east-1",
- help="EC2 region zone to launch instances in")
+ help="EC2 region used to launch instances in, or to find them in")
parser.add_option(
"-z", "--zone", default="",
help="Availability zone to launch instances in, or 'all' to spread " +
"slaves across multiple (an additional $0.01/Gb for bandwidth" +
"between zones applies) (default: a single zone chosen at random)")
- parser.add_option("-a", "--ami", help="Amazon Machine Image ID to use")
+ parser.add_option(
+ "-a", "--ami",
+ help="Amazon Machine Image ID to use")
parser.add_option(
"-v", "--spark-version", default=DEFAULT_SPARK_VERSION,
help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)")
@@ -179,10 +183,11 @@ def parse_args():
"Only possible on EBS-backed AMIs. " +
"EBS volumes are only attached if --ebs-vol-size > 0." +
"Only support up to 8 EBS volumes.")
- parser.add_option("--placement-group", type="string", default=None,
- help="Which placement group to try and launch " +
- "instances into. Assumes placement group is already " +
- "created.")
+ parser.add_option(
+ "--placement-group", type="string", default=None,
+ help="Which placement group to try and launch " +
+ "instances into. Assumes placement group is already " +
+ "created.")
parser.add_option(
"--swap", metavar="SWAP", type="int", default=1024,
help="Swap space to set up per node, in MB (default: %default)")
@@ -226,9 +231,11 @@ def parse_args():
"--copy-aws-credentials", action="store_true", default=False,
help="Add AWS credentials to hadoop configuration to allow Spark to access S3")
parser.add_option(
- "--subnet-id", default=None, help="VPC subnet to launch instances in")
+ "--subnet-id", default=None,
+ help="VPC subnet to launch instances in")
parser.add_option(
- "--vpc-id", default=None, help="VPC to launch instances in")
+ "--vpc-id", default=None,
+ help="VPC to launch instances in")
(opts, args) = parser.parse_args()
if len(args) != 2:
@@ -290,52 +297,54 @@ def is_active(instance):
return (instance.state in ['pending', 'running', 'stopping', 'stopped'])
-# Attempt to resolve an appropriate AMI given the architecture and region of the request.
# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/
# Last Updated: 2014-06-20
# For easy maintainability, please keep this manually-inputted dictionary sorted by key.
+EC2_INSTANCE_TYPES = {
+ "c1.medium": "pvm",
+ "c1.xlarge": "pvm",
+ "c3.2xlarge": "pvm",
+ "c3.4xlarge": "pvm",
+ "c3.8xlarge": "pvm",
+ "c3.large": "pvm",
+ "c3.xlarge": "pvm",
+ "cc1.4xlarge": "hvm",
+ "cc2.8xlarge": "hvm",
+ "cg1.4xlarge": "hvm",
+ "cr1.8xlarge": "hvm",
+ "hi1.4xlarge": "pvm",
+ "hs1.8xlarge": "pvm",
+ "i2.2xlarge": "hvm",
+ "i2.4xlarge": "hvm",
+ "i2.8xlarge": "hvm",
+ "i2.xlarge": "hvm",
+ "m1.large": "pvm",
+ "m1.medium": "pvm",
+ "m1.small": "pvm",
+ "m1.xlarge": "pvm",
+ "m2.2xlarge": "pvm",
+ "m2.4xlarge": "pvm",
+ "m2.xlarge": "pvm",
+ "m3.2xlarge": "hvm",
+ "m3.large": "hvm",
+ "m3.medium": "hvm",
+ "m3.xlarge": "hvm",
+ "r3.2xlarge": "hvm",
+ "r3.4xlarge": "hvm",
+ "r3.8xlarge": "hvm",
+ "r3.large": "hvm",
+ "r3.xlarge": "hvm",
+ "t1.micro": "pvm",
+ "t2.medium": "hvm",
+ "t2.micro": "hvm",
+ "t2.small": "hvm",
+}
+
+
+# Attempt to resolve an appropriate AMI given the architecture and region of the request.
def get_spark_ami(opts):
- instance_types = {
- "c1.medium": "pvm",
- "c1.xlarge": "pvm",
- "c3.2xlarge": "pvm",
- "c3.4xlarge": "pvm",
- "c3.8xlarge": "pvm",
- "c3.large": "pvm",
- "c3.xlarge": "pvm",
- "cc1.4xlarge": "hvm",
- "cc2.8xlarge": "hvm",
- "cg1.4xlarge": "hvm",
- "cr1.8xlarge": "hvm",
- "hi1.4xlarge": "pvm",
- "hs1.8xlarge": "pvm",
- "i2.2xlarge": "hvm",
- "i2.4xlarge": "hvm",
- "i2.8xlarge": "hvm",
- "i2.xlarge": "hvm",
- "m1.large": "pvm",
- "m1.medium": "pvm",
- "m1.small": "pvm",
- "m1.xlarge": "pvm",
- "m2.2xlarge": "pvm",
- "m2.4xlarge": "pvm",
- "m2.xlarge": "pvm",
- "m3.2xlarge": "hvm",
- "m3.large": "hvm",
- "m3.medium": "hvm",
- "m3.xlarge": "hvm",
- "r3.2xlarge": "hvm",
- "r3.4xlarge": "hvm",
- "r3.8xlarge": "hvm",
- "r3.large": "hvm",
- "r3.xlarge": "hvm",
- "t1.micro": "pvm",
- "t2.medium": "hvm",
- "t2.micro": "hvm",
- "t2.small": "hvm",
- }
- if opts.instance_type in instance_types:
- instance_type = instance_types[opts.instance_type]
+ if opts.instance_type in EC2_INSTANCE_TYPES:
+ instance_type = EC2_INSTANCE_TYPES[opts.instance_type]
else:
instance_type = "pvm"
print >> stderr,\
@@ -605,10 +614,9 @@ def launch_cluster(conn, opts, cluster_name):
# Get the EC2 instances in an existing cluster if available.
# Returns a tuple of lists of EC2 instance objects for the masters and slaves
-
-
def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
- print "Searching for existing cluster " + cluster_name + "..."
+ print "Searching for existing cluster " + cluster_name + " in region " \
+ + opts.region + "..."
reservations = conn.get_all_reservations()
master_nodes = []
slave_nodes = []
@@ -626,9 +634,11 @@ def get_existing_cluster(conn, opts, cluster_name, die_on_error=True):
return (master_nodes, slave_nodes)
else:
if master_nodes == [] and slave_nodes != []:
- print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name + "-master"
+ print >> sys.stderr, "ERROR: Could not find master in group " + cluster_name \
+ + "-master" + " in region " + opts.region
else:
- print >> sys.stderr, "ERROR: Could not find any existing cluster"
+ print >> sys.stderr, "ERROR: Could not find any existing cluster" \
+ + " in region " + opts.region
sys.exit(1)
@@ -1050,6 +1060,30 @@ def real_main():
print >> stderr, 'You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file)
sys.exit(1)
+ if opts.instance_type not in EC2_INSTANCE_TYPES:
+ print >> stderr, "Warning: Unrecognized EC2 instance type for instance-type: {t}".format(
+ t=opts.instance_type)
+
+ if opts.master_instance_type != "":
+ if opts.master_instance_type not in EC2_INSTANCE_TYPES:
+ print >> stderr, \
+ "Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format(
+ t=opts.master_instance_type)
+ # Since we try instance types even if we can't resolve them, we check if they resolve first
+ # and, if they do, see if they resolve to the same virtualization type.
+ if opts.instance_type in EC2_INSTANCE_TYPES and \
+ opts.master_instance_type in EC2_INSTANCE_TYPES:
+ if EC2_INSTANCE_TYPES[opts.instance_type] != \
+ EC2_INSTANCE_TYPES[opts.master_instance_type]:
+ print >> stderr, \
+ "Error: spark-ec2 currently does not support having a master and slaves with " + \
+ "different AMI virtualization types."
+ print >> stderr, "master instance virtualization type: {t}".format(
+ t=EC2_INSTANCE_TYPES[opts.master_instance_type])
+ print >> stderr, "slave instance virtualization type: {t}".format(
+ t=EC2_INSTANCE_TYPES[opts.instance_type])
+ sys.exit(1)
+
if opts.ebs_vol_num > 8:
print >> stderr, "ebs-vol-num cannot be greater than 8"
sys.exit(1)
@@ -1140,11 +1174,12 @@ def real_main():
time.sleep(30) # Yes, it does have to be this long :-(
for group in groups:
try:
- conn.delete_security_group(group.name)
- print "Deleted security group " + group.name
+ # Use group_id so this also works with VPC security groups
+ conn.delete_security_group(group_id=group.id)
+ print "Deleted security group %s" % group.name
except boto.exception.EC2ResponseError:
success = False
- print "Failed to delete security group " + group.name
+ print "Failed to delete security group %s" % group.name
# Unfortunately, group.revoke() returns True even if a rule was not
# deleted, so this needs to be rerun if something fails
diff --git a/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
new file mode 100644
index 0000000000000..bab9f2478e779
--- /dev/null
+++ b/examples/scala-2.10/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Arrays;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Lists;
+import kafka.serializer.StringDecoder;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.streaming.api.java.*;
+import org.apache.spark.streaming.kafka.KafkaUtils;
+import org.apache.spark.streaming.Durations;
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: JavaDirectKafkaWordCount <brokers> <topics>
+ *   <brokers> is a list of one or more Kafka brokers
+ *   <topics> is a list of one or more kafka topics to consume from
+ *
+ * Example:
+ * $ bin/run-example streaming.JavaDirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2
+ */
+
+public final class JavaDirectKafkaWordCount {
+ private static final Pattern SPACE = Pattern.compile(" ");
+
+ public static void main(String[] args) {
+ if (args.length < 2) {
+ System.err.println("Usage: JavaDirectKafkaWordCount <brokers> <topics>\n" +
+ "  <brokers> is a list of one or more Kafka brokers\n" +
+ "  <topics> is a list of one or more kafka topics to consume from\n\n");
+ System.exit(1);
+ }
+
+ StreamingExamples.setStreamingLogLevels();
+
+ String brokers = args[0];
+ String topics = args[1];
+
+ // Create context with 2 second batch interval
+ SparkConf sparkConf = new SparkConf().setAppName("JavaDirectKafkaWordCount");
+ JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, Durations.seconds(2));
+
+ HashSet<String> topicsSet = new HashSet<String>(Arrays.asList(topics.split(",")));
+ HashMap<String, String> kafkaParams = new HashMap<String, String>();
+ kafkaParams.put("metadata.broker.list", brokers);
+
+ // Create direct kafka stream with brokers and topics
+ JavaPairInputDStream<String, String> messages = KafkaUtils.createDirectStream(
+ jssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topicsSet
+ );
+
+ // Get the lines, split them into words, count the words and print
+ JavaDStream<String> lines = messages.map(new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(Tuple2<String, String> tuple2) {
+ return tuple2._2();
+ }
+ });
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(SPACE.split(x));
+ }
+ });
+ JavaPairDStream<String, Integer> wordCounts = words.mapToPair(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ }).reduceByKey(
+ new Function2<Integer, Integer, Integer>() {
+ @Override
+ public Integer call(Integer i1, Integer i2) {
+ return i1 + i2;
+ }
+ });
+ wordCounts.print();
+
+ // Start the computation
+ jssc.start();
+ jssc.awaitTermination();
+ }
+}
diff --git a/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
new file mode 100644
index 0000000000000..deb08fd57b8c7
--- /dev/null
+++ b/examples/scala-2.10/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming
+
+import kafka.serializer.StringDecoder
+
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.kafka._
+import org.apache.spark.SparkConf
+
+/**
+ * Consumes messages from one or more topics in Kafka and does wordcount.
+ * Usage: DirectKafkaWordCount <brokers> <topics>
+ *   <brokers> is a list of one or more Kafka brokers
+ *   <topics> is a list of one or more kafka topics to consume from
+ *
+ * Example:
+ * $ bin/run-example streaming.DirectKafkaWordCount broker1-host:port,broker2-host:port topic1,topic2
+ */
+object DirectKafkaWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println(s"""
+ |Usage: DirectKafkaWordCount <brokers> <topics>
+ |  <brokers> is a list of one or more Kafka brokers
+ |  <topics> is a list of one or more kafka topics to consume from
+ |
+ """.stripMargin)
+ System.exit(1)
+ }
+
+ StreamingExamples.setStreamingLogLevels()
+
+ val Array(brokers, topics) = args
+
+ // Create context with 2 second batch interval
+ val sparkConf = new SparkConf().setAppName("DirectKafkaWordCount")
+ val ssc = new StreamingContext(sparkConf, Seconds(2))
+
+ // Create direct kafka stream with brokers and topics
+ val topicsSet = topics.split(",").toSet
+ val kafkaParams = Map[String, String]("metadata.broker.list" -> brokers)
+ val messages = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, topicsSet)
+
+ // Get the lines, split them into words, count the words and print
+ val lines = messages.map(_._2)
+ val words = lines.flatMap(_.split(" "))
+ val wordCounts = words.map(x => (x, 1L)).reduceByKey(_ + _)
+ wordCounts.print()
+
+ // Start the computation
+ ssc.start()
+ ssc.awaitTermination()
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 5041e0b6d34b0..9bbc14ea40875 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -34,8 +34,8 @@
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
/**
* A simple example demonstrating model selection using CrossValidator.
@@ -71,7 +71,7 @@ public static void main(String[] args) {
new LabeledDocument(9L, "a e c l", 0.0),
new LabeledDocument(10L, "spark compile", 1.0),
new LabeledDocument(11L, "hadoop software", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -112,12 +112,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
- for (Row r: predictions.collect()) {
+ DataFrame predictions = cvModel.transform(test);
+ for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 4d9dad9f23038..19d0eb216848e 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -62,7 +62,7 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
@@ -80,7 +80,7 @@ public static void main(String[] args) throws Exception {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents. cvModel uses the best model found (lrModel).
DataFrame results = model.transform(test);
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index cc69e6315fdda..4e02acce696e6 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -29,8 +29,8 @@
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
/**
* A simple example demonstrating ways to specify parameters for Estimators and Transformers.
@@ -54,7 +54,7 @@ public static void main(String[] args) {
new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledPoint.class);
// Create a LogisticRegression instance. This instance is an Estimator.
LogisticRegression lr = new LogisticRegression();
@@ -94,16 +94,14 @@ public static void main(String[] args) {
new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), LabeledPoint.class);
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
- model2.transform(test).registerTempTable("results");
- DataFrame results =
- jsql.sql("SELECT features, label, myProbability, prediction FROM results");
- for (Row r: results.collect()) {
+ DataFrame results = model2.transform(test);
+ for (Row r: results.select("features", "label", "myProbability", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index d929f1ad2014a..ef1ec103a879f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -30,8 +30,8 @@
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.Tokenizer;
import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
/**
* A simple text classification pipeline that recognizes "spark" from input text. It uses the Java
@@ -54,7 +54,7 @@ public static void main(String[] args) {
new LabeledDocument(1L, "b d", 0.0),
new LabeledDocument(2L, "spark f g h", 1.0),
new LabeledDocument(3L, "hadoop mapreduce", 0.0));
- DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class);
+ DataFrame training = jsql.createDataFrame(jsc.parallelize(localTraining), LabeledDocument.class);
// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
Tokenizer tokenizer = new Tokenizer()
@@ -79,12 +79,11 @@ public static void main(String[] args) {
new Document(5L, "l m n"),
new Document(6L, "mapreduce spark"),
new Document(7L, "apache hadoop"));
- DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class);
+ DataFrame test = jsql.createDataFrame(jsc.parallelize(localTest), Document.class);
// Make predictions on test documents.
- model.transform(test).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
- for (Row r: predictions.collect()) {
+ DataFrame predictions = model.transform(test);
+ for (Row r: predictions.select("id", "text", "probability", "prediction").collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
new file mode 100644
index 0000000000000..36baf5868736c
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaFPGrowthExample.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import java.util.ArrayList;
+
+import com.google.common.base.Joiner;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.mllib.fpm.FPGrowth;
+import org.apache.spark.mllib.fpm.FPGrowthModel;
+
+/**
+ * Java example for mining frequent itemsets using FP-growth.
+ * Example usage: ./bin/run-example mllib.JavaFPGrowthExample ./data/mllib/sample_fpgrowth.txt
+ */
+public class JavaFPGrowthExample {
+
+ public static void main(String[] args) {
+ String inputFile;
+ double minSupport = 0.3;
+ int numPartition = -1;
+ if (args.length < 1) {
+ System.err.println(
+ "Usage: JavaFPGrowth [minSupport] [numPartition]");
+ System.exit(1);
+ }
+ inputFile = args[0];
+ if (args.length >= 2) {
+ minSupport = Double.parseDouble(args[1]);
+ }
+ if (args.length >= 3) {
+ numPartition = Integer.parseInt(args[2]);
+ }
+
+ SparkConf sparkConf = new SparkConf().setAppName("JavaFPGrowthExample");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ JavaRDD<ArrayList<String>> transactions = sc.textFile(inputFile).map(
+ new Function<String, ArrayList<String>>() {
+ @Override
+ public ArrayList<String> call(String s) {
+ return Lists.newArrayList(s.split(" "));
+ }
+ }
+ );
+
+ FPGrowthModel<String> model = new FPGrowth()
+ .setMinSupport(minSupport)
+ .setNumPartitions(numPartition)
+ .run(transactions);
+
+ for (FPGrowth.FreqItemset<String> s: model.freqItemsets().toJavaRDD().collect()) {
+ System.out.println("[" + Joiner.on(",").join(s.javaItems()) + "], " + s.freq());
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
index 75bc3dd788ac0..c0d1a622ffad8 100644
--- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaLDAExample.java
@@ -71,5 +71,6 @@ public Tuple2 call(Tuple2 doc_id) {
}
System.out.println();
}
+ sc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
new file mode 100644
index 0000000000000..6c6f9768f015e
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib;
+
+import scala.Tuple3;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.clustering.PowerIterationClustering;
+import org.apache.spark.mllib.clustering.PowerIterationClusteringModel;
+
+/**
+ * Java example for graph clustering using power iteration clustering (PIC).
+ */
+public class JavaPowerIterationClusteringExample {
+ public static void main(String[] args) {
+ SparkConf sparkConf = new SparkConf().setAppName("JavaPowerIterationClusteringExample");
+ JavaSparkContext sc = new JavaSparkContext(sparkConf);
+
+ @SuppressWarnings("unchecked")
+ JavaRDD<Tuple3<Long, Long, Double>> similarities = sc.parallelize(Lists.newArrayList(
+ new Tuple3<Long, Long, Double>(0L, 1L, 0.9),
+ new Tuple3<Long, Long, Double>(1L, 2L, 0.9),
+ new Tuple3<Long, Long, Double>(2L, 3L, 0.9),
+ new Tuple3<Long, Long, Double>(3L, 4L, 0.1),
+ new Tuple3<Long, Long, Double>(4L, 5L, 0.9)));
+
+ PowerIterationClustering pic = new PowerIterationClustering()
+ .setK(2)
+ .setMaxIterations(10);
+ PowerIterationClusteringModel model = pic.run(similarities);
+
+ for (PowerIterationClustering.Assignment a: model.assignments().toJavaRDD().collect()) {
+ System.out.println(a.id() + " -> " + a.cluster());
+ }
+
+ sc.stop();
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index 8defb769ffaaf..dee794840a3e1 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -74,7 +74,7 @@ public Person call(String line) {
});
// Apply a schema to an RDD of Java Beans and register it as a table.
- DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class);
+ DataFrame schemaPeople = sqlCtx.createDataFrame(people, Person.class);
schemaPeople.registerTempTable("people");
// SQL can be run over RDDs that have been registered as tables.
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
new file mode 100644
index 0000000000000..d46c7107c7a21
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.regex.Pattern;
+
+import scala.Tuple2;
+
+import com.google.common.base.Optional;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.HashPartitioner;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.StorageLevels;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+/**
+ * Counts words cumulatively in UTF8 encoded, '\n' delimited text received from the network every
+ * second, starting from an initial value of word counts.
+ * Usage: JavaStatefulNetworkWordCount <hostname> <port>
+ *   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive
+ *   data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ bin/run-example
+ * org.apache.spark.examples.streaming.JavaStatefulNetworkWordCount localhost 9999`
+ */
+public class JavaStatefulNetworkWordCount {
+ private static final Pattern SPACE = Pattern.compile(" ");
+
+ public static void main(String[] args) {
+ if (args.length < 2) {
+ System.err.println("Usage: JavaStatefulNetworkWordCount ");
+ System.exit(1);
+ }
+
+ StreamingExamples.setStreamingLogLevels();
+
+ // Update the cumulative count function
+ final Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
+ new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
+ @Override
+ public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
+ Integer newSum = state.or(0);
+ for (Integer value : values) {
+ newSum += value;
+ }
+ return Optional.of(newSum);
+ }
+ };
+
+ // Create the context with a 1 second batch size
+ SparkConf sparkConf = new SparkConf().setAppName("JavaStatefulNetworkWordCount");
+ JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, Durations.seconds(1));
+ ssc.checkpoint(".");
+
+ // Initial RDD input to updateStateByKey
+ List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
+ new Tuple2<String, Integer>("world", 1));
+ JavaPairRDD<String, Integer> initialRDD = ssc.sc().parallelizePairs(tuples);
+
+ JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
+ args[0], Integer.parseInt(args[1]), StorageLevels.MEMORY_AND_DISK_SER_2);
+
+ JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterable<String> call(String x) {
+ return Lists.newArrayList(SPACE.split(x));
+ }
+ });
+
+ JavaPairDStream<String, Integer> wordsDstream = words.mapToPair(
+ new PairFunction<String, String, Integer>() {
+ @Override
+ public Tuple2<String, Integer> call(String s) {
+ return new Tuple2<String, Integer>(s, 1);
+ }
+ });
+
+ // This will give a DStream made of state (which is the cumulative count of the words)
+ JavaPairDStream<String, Integer> stateDstream = wordsDstream.updateStateByKey(updateFunction,
+ new HashPartitioner(ssc.sc().defaultParallelism()), initialRDD);
+
+ stateDstream.print();
+ ssc.start();
+ ssc.awaitTermination();
+ }
+}
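
For comparison, here is a minimal sketch (not part of this patch) of the same cumulative-count update function written against the Scala DStream API, where updateStateByKey expects a function of shape (Seq[V], Option[S]) => Option[S]:

    // New state = previous running total (or 0) plus the values received in this batch.
    val updateFunction = (values: Seq[Int], state: Option[Int]) =>
      Some(state.getOrElse(0) + values.sum)
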
diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
index c7df3d7b74767..d281f4fa44282 100644
--- a/examples/src/main/python/ml/simple_text_classification_pipeline.py
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -16,10 +16,10 @@
#
from pyspark import SparkContext
-from pyspark.sql import SQLContext, Row
from pyspark.ml import Pipeline
-from pyspark.ml.feature import HashingTF, Tokenizer
from pyspark.ml.classification import LogisticRegression
+from pyspark.ml.feature import HashingTF, Tokenizer
+from pyspark.sql import Row, SQLContext
"""
@@ -36,43 +36,33 @@
sqlCtx = SQLContext(sc)
# Prepare training documents, which are labeled.
- LabeledDocument = Row('id', 'text', 'label')
- training = sqlCtx.inferSchema(
- sc.parallelize([(0L, "a b c d e spark", 1.0),
- (1L, "b d", 0.0),
- (2L, "spark f g h", 1.0),
- (3L, "hadoop mapreduce", 0.0)])
- .map(lambda x: LabeledDocument(*x)))
+ LabeledDocument = Row("id", "text", "label")
+ training = sc.parallelize([(0L, "a b c d e spark", 1.0),
+ (1L, "b d", 0.0),
+ (2L, "spark f g h", 1.0),
+ (3L, "hadoop mapreduce", 0.0)]) \
+ .map(lambda x: LabeledDocument(*x)).toDF()
# Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr.
- tokenizer = Tokenizer() \
- .setInputCol("text") \
- .setOutputCol("words")
- hashingTF = HashingTF() \
- .setInputCol(tokenizer.getOutputCol()) \
- .setOutputCol("features")
- lr = LogisticRegression() \
- .setMaxIter(10) \
- .setRegParam(0.01)
- pipeline = Pipeline() \
- .setStages([tokenizer, hashingTF, lr])
+ tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ lr = LogisticRegression(maxIter=10, regParam=0.01)
+ pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
# Fit the pipeline to training documents.
model = pipeline.fit(training)
# Prepare test documents, which are unlabeled.
- Document = Row('id', 'text')
- test = sqlCtx.inferSchema(
- sc.parallelize([(4L, "spark i j k"),
- (5L, "l m n"),
- (6L, "mapreduce spark"),
- (7L, "apache hadoop")])
- .map(lambda x: Document(*x)))
+ Document = Row("id", "text")
+ test = sc.parallelize([(4L, "spark i j k"),
+ (5L, "l m n"),
+ (6L, "mapreduce spark"),
+ (7L, "apache hadoop")]) \
+ .map(lambda x: Document(*x)).toDF()
# Make predictions on test documents and print columns of interest.
prediction = model.transform(test)
- prediction.registerTempTable("prediction")
- selected = sqlCtx.sql("SELECT id, text, prediction from prediction")
+ selected = prediction.select("id", "text", "prediction")
for row in selected.collect():
print row
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index 7f5c68e3d0fe2..47202fde7510b 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -31,7 +31,7 @@
Row(name="Smith", age=23),
Row(name="Sarah", age=18)])
# Infer schema from the first row, create a DataFrame and print the schema
- some_df = sqlContext.inferSchema(some_rdd)
+ some_df = sqlContext.createDataFrame(some_rdd)
some_df.printSchema()
# Another RDD is created from a list of tuples
@@ -40,7 +40,7 @@
schema = StructType([StructField("person_name", StringType(), False),
StructField("person_age", IntegerType(), False)])
# Create a DataFrame by applying the schema to the RDD and print the schema
- another_df = sqlContext.applySchema(another_rdd, schema)
+ another_df = sqlContext.createDataFrame(another_rdd, schema)
another_df.printSchema()
# root
# |-- age: integer (nullable = true)
diff --git a/examples/src/main/python/status_api_demo.py b/examples/src/main/python/status_api_demo.py
new file mode 100644
index 0000000000000..a33bdc475a06d
--- /dev/null
+++ b/examples/src/main/python/status_api_demo.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+import threading
+import Queue
+
+from pyspark import SparkConf, SparkContext
+
+
+def delayed(seconds):
+ def f(x):
+ time.sleep(seconds)
+ return x
+ return f
+
+
+def call_in_background(f, *args):
+ result = Queue.Queue(1)
+ t = threading.Thread(target=lambda: result.put(f(*args)))
+ t.daemon = True
+ t.start()
+ return result
+
+
+def main():
+ conf = SparkConf().set("spark.ui.showConsoleProgress", "false")
+ sc = SparkContext(appName="PythonStatusAPIDemo", conf=conf)
+
+ def run():
+ rdd = sc.parallelize(range(10), 10).map(delayed(2))
+ reduced = rdd.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
+ return reduced.map(delayed(2)).collect()
+
+ result = call_in_background(run)
+ status = sc.statusTracker()
+ while result.empty():
+ ids = status.getJobIdsForGroup()
+ for id in ids:
+ job = status.getJobInfo(id)
+ print "Job", id, "status: ", job.status
+ for sid in job.stageIds:
+ info = status.getStageInfo(sid)
+ if info:
+ print "Stage %d: %d tasks total (%d active, %d complete)" % \
+ (sid, info.numTasks, info.numActiveTasks, info.numCompletedTasks)
+ time.sleep(1)
+
+ print "Job results are:", result.get()
+ sc.stop()
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index 1b53f3edbe92e..4c129dbe2d12d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -29,7 +29,7 @@ object BroadcastTest {
val blockSize = if (args.length > 3) args(3) else "4096"
val sparkConf = new SparkConf().setAppName("Broadcast Test")
- .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroaddcastFactory")
+ .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory")
.set("spark.broadcast.blockSize", blockSize)
val sc = new SparkContext(sparkConf)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index a2893f78e0fec..6c0af20461d3b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -90,7 +90,7 @@ object CrossValidatorExample {
crossval.setNumFolds(2) // Use 3+ in practice
// Run cross-validation, and choose the best set of parameters.
- val cvModel = crossval.fit(training)
+ val cvModel = crossval.fit(training.toDF())
// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(
@@ -100,7 +100,7 @@ object CrossValidatorExample {
Document(7L, "apache hadoop")))
// Make predictions on test documents. cvModel uses the best model found (lrModel).
- cvModel.transform(test)
+ cvModel.transform(test.toDF())
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index aed44238939c7..df26798e41b7b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -58,7 +58,7 @@ object DeveloperApiExample {
lr.setMaxIter(10)
// Learn a LogisticRegression model. This uses the parameters stored in lr.
- val model = lr.fit(training)
+ val model = lr.fit(training.toDF())
// Prepare test data.
val test = sc.parallelize(Seq(
@@ -67,7 +67,7 @@ object DeveloperApiExample {
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
// Make predictions on test data.
- val sumPredictions: Double = model.transform(test)
+ val sumPredictions: Double = model.transform(test.toDF())
.select("features", "label", "prediction")
.collect()
.map { case Row(features: Vector, label: Double, prediction: Double) =>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
index 836ea2e01201e..25f21113bf622 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala
@@ -93,8 +93,8 @@ object MovieLensALS {
| bin/spark-submit --class org.apache.spark.examples.ml.MovieLensALS \
| examples/target/scala-*/spark-examples-*.jar \
| --rank 10 --maxIter 15 --regParam 0.1 \
- | --movies path/to/movielens/movies.dat \
- | --ratings path/to/movielens/ratings.dat
+ | --movies data/mllib/als/sample_movielens_movies.txt \
+ | --ratings data/mllib/als/sample_movielens_ratings.txt
""".stripMargin)
}
@@ -137,9 +137,9 @@ object MovieLensALS {
.setRegParam(params.regParam)
.setNumBlocks(params.numBlocks)
- val model = als.fit(training)
+ val model = als.fit(training.toDF())
- val predictions = model.transform(test).cache()
+ val predictions = model.transform(test.toDF()).cache()
// Evaluate the model.
// TODO: Create an evaluator to compute RMSE.
@@ -157,17 +157,23 @@ object MovieLensALS {
println(s"Test RMSE = $rmse.")
// Inspect false positives.
- predictions.registerTempTable("prediction")
- sc.textFile(params.movies).map(Movie.parseMovie).registerTempTable("movie")
- sqlContext.sql(
- """
- |SELECT userId, prediction.movieId, title, rating, prediction
- | FROM prediction JOIN movie ON prediction.movieId = movie.movieId
- | WHERE rating <= 1 AND prediction >= 4
- | LIMIT 100
- """.stripMargin)
- .collect()
- .foreach(println)
+ // Note: We reference columns in 2 ways:
+ // (1) predictions("movieId") lets us specify the movieId column in the predictions
+ // DataFrame, rather than the movieId column in the movies DataFrame.
+ // (2) $"userId" specifies the userId column in the predictions DataFrame.
+ // We could also write predictions("userId") but do not have to since
+ // the movies DataFrame does not have a column "userId."
+ val movies = sc.textFile(params.movies).map(Movie.parseMovie).toDF()
+ val falsePositives = predictions.join(movies)
+ .where((predictions("movieId") === movies("movieId"))
+ && ($"rating" <= 1) && ($"prediction" >= 4))
+ .select($"userId", predictions("movieId"), $"title", $"rating", $"prediction")
+ val numFalsePositives = falsePositives.count()
+ println(s"Found $numFalsePositives false positives")
+ if (numFalsePositives > 0) {
+ println(s"Example false positives:")
+ falsePositives.limit(100).collect().foreach(println)
+ }
sc.stop()
}
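
The comment added above distinguishes the two ways of referencing DataFrame columns. A standalone sketch of the same idea (the Rating case class and values here are illustrative, not part of this patch):

    import org.apache.spark.sql.SQLContext

    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._  // enables the $"colName" syntax

    case class Rating(userId: Int, movieId: Int, rating: Double)
    val ratings = sc.parallelize(Seq(Rating(1, 10, 0.5), Rating(2, 10, 4.0))).toDF()

    // df("col") pins a column to a particular DataFrame (needed after a join with
    // overlapping column names); $"col" is shorthand when the name is unambiguous.
    ratings.where(ratings("rating") <= 1 && $"userId" === 1).show()
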
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 80c9f5ff5781e..bf805149d0af6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -40,8 +40,8 @@ object SimpleParamsExample {
import sqlContext.implicits._
// Prepare training data.
- // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans
- // into DataFrames, where it uses the bean metadata to infer the schema.
+ // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of case classes
+ // into DataFrames, where it uses the case class metadata to infer the schema.
val training = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
@@ -58,7 +58,7 @@ object SimpleParamsExample {
.setRegParam(0.01)
// Learn a LogisticRegression model. This uses the parameters stored in lr.
- val model1 = lr.fit(training)
+ val model1 = lr.fit(training.toDF())
// Since model1 is a Model (i.e., a Transformer produced by an Estimator),
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
@@ -77,7 +77,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
- val model2 = lr.fit(training, paramMapCombined)
+ val model2 = lr.fit(training.toDF(), paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
// Prepare test data.
@@ -90,11 +90,11 @@ object SimpleParamsExample {
// LogisticRegression.transform will only use the 'features' column.
// Note that model2.transform() outputs a 'myProbability' column instead of the usual
// 'probability' column since we renamed the lr.probabilityCol parameter previously.
- model2.transform(test)
+ model2.transform(test.toDF())
.select("features", "label", "myProbability", "prediction")
.collect()
.foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
- println("($features, $label) -> prob=$prob, prediction=$prediction")
+ println(s"($features, $label) -> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index 968cb292120d8..6772efd2c581c 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -69,7 +69,7 @@ object SimpleTextClassificationPipeline {
.setStages(Array(tokenizer, hashingTF, lr))
// Fit the pipeline to training documents.
- val model = pipeline.fit(training)
+ val model = pipeline.fit(training.toDF())
// Prepare test documents, which are unlabeled.
val test = sc.parallelize(Seq(
@@ -79,11 +79,11 @@ object SimpleTextClassificationPipeline {
Document(7L, "apache hadoop")))
// Make predictions on test documents.
- model.transform(test)
+ model.transform(test.toDF())
.select("id", "text", "probability", "prediction")
.collect()
.foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
- println("($id, $text) --> prob=$prob, prediction=$prediction")
+ println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
index 89b6255991a38..e943d6c889fab 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala
@@ -81,18 +81,18 @@ object DatasetExample {
println(s"Loaded ${origData.count()} instances from file: ${params.input}")
// Convert input data to DataFrame explicitly.
- val df: DataFrame = origData.toDataFrame
+ val df: DataFrame = origData.toDF()
println(s"Inferred schema:\n${df.schema.prettyJson}")
println(s"Converted to DataFrame with ${df.count()} records")
- // Select columns, using implicit conversion to DataFrames.
- val labelsDf: DataFrame = origData.select("label")
+ // Select columns
+ val labelsDf: DataFrame = df.select("label")
val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v }
val numLabels = labels.count()
val meanLabel = labels.fold(0.0)(_ + _) / numLabels
println(s"Selected label column with average value $meanLabel")
- val featuresDf: DataFrame = origData.select("features")
+ val featuresDf: DataFrame = df.select("features")
val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v }
val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())(
(summary, feat) => summary.add(feat),
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
index 11e35598baf50..14cc5cbb679c5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseKMeans.scala
@@ -56,7 +56,7 @@ object DenseKMeans {
.text(s"number of clusters, required")
.action((x, c) => c.copy(k = x))
opt[Int]("numIterations")
- .text(s"number of iterations, default; ${defaultParams.numIterations}")
+ .text(s"number of iterations, default: ${defaultParams.numIterations}")
.action((x, c) => c.copy(numIterations = x))
opt[String]("initMode")
.text(s"initialization mode (${InitializationMode.values.mkString(",")}), " +
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
new file mode 100644
index 0000000000000..13f24a1e59610
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import scopt.OptionParser
+
+import org.apache.spark.mllib.fpm.FPGrowth
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * Example for mining frequent itemsets using FP-growth.
+ * Example usage: ./bin/run-example mllib.FPGrowthExample \
+ * --minSupport 0.8 --numPartition 2 ./data/mllib/sample_fpgrowth.txt
+ */
+object FPGrowthExample {
+
+ case class Params(
+ input: String = null,
+ minSupport: Double = 0.3,
+ numPartition: Int = -1) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("FPGrowthExample") {
+ head("FPGrowth: an example FP-growth app.")
+ opt[Double]("minSupport")
+ .text(s"minimal support level, default: ${defaultParams.minSupport}")
+ .action((x, c) => c.copy(minSupport = x))
+ opt[Int]("numPartition")
+ .text(s"number of partition, default: ${defaultParams.numPartition}")
+ .action((x, c) => c.copy(numPartition = x))
+ arg[String]("")
+ .text("input paths to input data set, whose file format is that each line " +
+ "contains a transaction with each item in String and separated by a space")
+ .required()
+ .action((x, c) => c.copy(input = x))
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf().setAppName(s"FPGrowthExample with $params")
+ val sc = new SparkContext(conf)
+ val transactions = sc.textFile(params.input).map(_.split(" ")).cache()
+
+ println(s"Number of transactions: ${transactions.count()}")
+
+ val model = new FPGrowth()
+ .setMinSupport(params.minSupport)
+ .setNumPartitions(params.numPartition)
+ .run(transactions)
+
+ println(s"Number of frequent itemsets: ${model.freqItemsets.count()}")
+
+ model.freqItemsets.collect().foreach { itemset =>
+ println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
+ }
+
+ sc.stop()
+ }
+}
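
For reference, the expected input is plain text with one transaction per line and items separated by single spaces. A hypothetical two-line input file (contents invented for illustration) would look like:

    r z h k p
    z y x w v u t s

Each space-separated token on a line is treated as one item of that transaction, and frequent itemsets such as [z,r] are reported together with their counts.
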
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index a1f4c1c4a7dab..0e1b27a8bd2ee 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -26,7 +26,7 @@ import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{DistributedLDAModel, LDA}
+import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
@@ -137,7 +137,7 @@ object LDAExample {
sc.setCheckpointDir(params.checkpointDir.get)
}
val startTime = System.nanoTime()
- val ldaModel = lda.run(corpus).asInstanceOf[DistributedLDAModel]
+ val ldaModel = lda.run(corpus)
val elapsed = (System.nanoTime() - startTime) / 1e9
println(s"Finished training LDA model. Summary:")
@@ -159,7 +159,6 @@ object LDAExample {
}
println()
}
- sc.stop()
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
new file mode 100644
index 0000000000000..91c9772744f18
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.log4j.{Level, Logger}
+import scopt.OptionParser
+
+import org.apache.spark.mllib.clustering.PowerIterationClustering
+import org.apache.spark.rdd.RDD
+import org.apache.spark.{SparkConf, SparkContext}
+
+/**
+ * An example Power Iteration Clustering (http://www.icml2010.org/papers/387.pdf) app.
+ * Takes an input of K concentric circles and the number of points in the innermost circle.
+ * The output should be K clusters - each cluster containing precisely the points associated
+ * with each of the input circles.
+ *
+ * Run with
+ * {{{
+ * ./bin/run-example mllib.PowerIterationClusteringExample [options]
+ *
+ * Where options include:
+ * k: Number of circles/clusters
+ * n: Number of sampled points on the innermost circle. There are proportionally more points
+ * within the outer/larger circles
+ * maxIterations: Number of Power Iterations
+ * outerRadius: radius of the outermost of the concentric circles
+ * }}}
+ *
+ * Here is a sample run and output:
+ *
+ * ./bin/run-example mllib.PowerIterationClusteringExample -k 3 --n 30 --maxIterations 15
+ *
+ * Cluster assignments: 1 -> [0,1,2,3,4],2 -> [5,6,7,8,9,10,11,12,13,14],
+ * 0 -> [15,16,17,18,19,20,21,22,23,24,25,26,27,28,29]
+ *
+ *
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object PowerIterationClusteringExample {
+
+ case class Params(
+ input: String = null,
+ k: Int = 3,
+ numPoints: Int = 5,
+ maxIterations: Int = 10,
+ outerRadius: Double = 3.0
+ ) extends AbstractParams[Params]
+
+ def main(args: Array[String]) {
+ val defaultParams = Params()
+
+ val parser = new OptionParser[Params]("PIC Circles") {
+ head("PowerIterationClusteringExample: an example PIC app using concentric circles.")
+ opt[Int]('k', "k")
+ .text(s"number of circles (/clusters), default: ${defaultParams.k}")
+ .action((x, c) => c.copy(k = x))
+ opt[Int]('n', "n")
+ .text(s"number of points in smallest circle, default: ${defaultParams.numPoints}")
+ .action((x, c) => c.copy(numPoints = x))
+ opt[Int]("maxIterations")
+ .text(s"number of iterations, default: ${defaultParams.maxIterations}")
+ .action((x, c) => c.copy(maxIterations = x))
+ opt[Double]('r', "r")
+ .text(s"radius of outermost circle, default: ${defaultParams.outerRadius}")
+ .action((x, c) => c.copy(outerRadius = x))
+ }
+
+ parser.parse(args, defaultParams).map { params =>
+ run(params)
+ }.getOrElse {
+ sys.exit(1)
+ }
+ }
+
+ def run(params: Params) {
+ val conf = new SparkConf()
+ .setMaster("local")
+ .setAppName(s"PowerIterationClustering with $params")
+ val sc = new SparkContext(conf)
+
+ Logger.getRootLogger.setLevel(Level.WARN)
+
+ val circlesRdd = generateCirclesRdd(sc, params.k, params.numPoints, params.outerRadius)
+ val model = new PowerIterationClustering()
+ .setK(params.k)
+ .setMaxIterations(params.maxIterations)
+ .run(circlesRdd)
+
+ val clusters = model.assignments.collect().groupBy(_.cluster).mapValues(_.map(_.id))
+ val assignments = clusters.toList.sortBy { case (k, v) => v.length}
+ val assignmentsStr = assignments
+ .map { case (k, v) =>
+ s"$k -> ${v.sorted.mkString("[", ",", "]")}"
+ }.mkString(",")
+ val sizesStr = assignments.map {
+ _._2.size
+ }.sorted.mkString("(", ",", ")")
+ println(s"Cluster assignments: $assignmentsStr\ncluster sizes: $sizesStr")
+
+ sc.stop()
+ }
+
+ def generateCircle(radius: Double, n: Int) = {
+ Seq.tabulate(n) { i =>
+ val theta = 2.0 * math.Pi * i / n
+ (radius * math.cos(theta), radius * math.sin(theta))
+ }
+ }
+
+ def generateCirclesRdd(sc: SparkContext,
+ nCircles: Int = 3,
+ nPoints: Int = 30,
+ outerRadius: Double): RDD[(Long, Long, Double)] = {
+
+ val radii = Array.tabulate(nCircles) { cx => outerRadius / (nCircles - cx)}
+ val groupSizes = Array.tabulate(nCircles) { cx => (cx + 1) * nPoints}
+ val points = (0 until nCircles).flatMap { cx =>
+ generateCircle(radii(cx), groupSizes(cx))
+ }.zipWithIndex
+ val rdd = sc.parallelize(points)
+ val distancesRdd = rdd.cartesian(rdd).flatMap { case (((x0, y0), i0), ((x1, y1), i1)) =>
+ if (i0 < i1) {
+ Some((i0.toLong, i1.toLong, gaussianSimilarity((x0, y0), (x1, y1), 1.0)))
+ } else {
+ None
+ }
+ }
+ distancesRdd
+ }
+
+ /**
+ * Gaussian Similarity: http://en.wikipedia.org/wiki/Radial_basis_function_kernel
+ */
+ def gaussianSimilarity(p1: (Double, Double), p2: (Double, Double), sigma: Double) = {
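+ // similarity = coeff * exp(-||p1 - p2||^2 / (2 * sigma^2))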
+ val coeff = 1.0 / (math.sqrt(2.0 * math.Pi) * sigma)
+ val expCoeff = -1.0 / (2.0 * math.pow(sigma, 2.0))
+ val ssquares = (p1._1 - p2._1) * (p1._1 - p2._1) + (p1._2 - p2._2) * (p1._2 - p2._2)
+ coeff * math.exp(expCoeff * ssquares)
+ }
+}
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index 1eac3c8d03e39..6331d1c0060f8 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -19,7 +19,7 @@ package org.apache.spark.examples.sql
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
// One method for defining the schema of an RDD is to make a case class with the desired column
// names and types.
@@ -34,10 +34,10 @@ object RDDRelation {
// Importing the SQL context gives access to all the SQL functions and implicit conversions.
import sqlContext.implicits._
- val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
+ val df = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))).toDF()
// Any RDD containing case classes can be registered as a table. The schema of the table is
// automatically inferred using scala reflection.
- rdd.registerTempTable("records")
+ df.registerTempTable("records")
// Once tables have been registered, you can run SQL queries over them.
println("Result of SELECT *:")
@@ -55,10 +55,10 @@ object RDDRelation {
rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println)
// Queries can also be written using a LINQ-like Scala DSL.
- rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
+ df.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println)
// Write out an RDD as a parquet file.
- rdd.saveAsParquetFile("pair.parquet")
+ df.saveAsParquetFile("pair.parquet")
// Read in parquet file. Parquet files are self-describing so the schema is preserved.
val parquetFile = sqlContext.parquetFile("pair.parquet")
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
index 15754cdfcc35e..b7ba60ec28155 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala
@@ -68,7 +68,7 @@ object HiveFromSpark {
// You can also register RDDs as temporary tables within a HiveContext.
val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
- rdd.registerTempTable("records")
+ rdd.toDF().registerTempTable("records")
// Queries can then join RDD data with data stored in Hive.
println("Result of SELECT *:")
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
index 6ff0c47793a25..f40caad322f59 100644
--- a/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/streaming/MQTTWordCount.scala
@@ -17,8 +17,8 @@
package org.apache.spark.examples.streaming
-import org.eclipse.paho.client.mqttv3.{MqttClient, MqttClientPersistence, MqttException, MqttMessage, MqttTopic}
-import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence
+import org.eclipse.paho.client.mqttv3._
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.{Seconds, StreamingContext}
@@ -31,8 +31,6 @@ import org.apache.spark.SparkConf
*/
object MQTTPublisher {
- var client: MqttClient = _
-
def main(args: Array[String]) {
if (args.length < 2) {
System.err.println("Usage: MQTTPublisher ")
@@ -42,25 +40,36 @@ object MQTTPublisher {
StreamingExamples.setStreamingLogLevels()
val Seq(brokerUrl, topic) = args.toSeq
+
+ var client: MqttClient = null
try {
- var peristance:MqttClientPersistence =new MqttDefaultFilePersistence("/tmp")
- client = new MqttClient(brokerUrl, MqttClient.generateClientId(), peristance)
+ val persistence = new MemoryPersistence()
+ client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
+
+ client.connect()
+
+ val msgtopic = client.getTopic(topic)
+ val msgContent = "hello mqtt demo for spark streaming"
+ val message = new MqttMessage(msgContent.getBytes("utf-8"))
+
+ while (true) {
+ try {
+ msgtopic.publish(message)
+ println(s"Published data. topic: {msgtopic.getName()}; Message: {message}")
+ } catch {
+ case e: MqttException if e.getReasonCode == MqttException.REASON_CODE_MAX_INFLIGHT =>
+ Thread.sleep(10)
+ println("Queue is full, wait for to consume data from the message queue")
+ }
+ }
} catch {
case e: MqttException => println("Exception Caught: " + e)
+ } finally {
+ if (client != null) {
+ client.disconnect()
+ }
}
-
- client.connect()
-
- val msgtopic: MqttTopic = client.getTopic(topic)
- val msg: String = "hello mqtt demo for spark streaming"
-
- while (true) {
- val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes("utf-8"))
- msgtopic.publish(message)
- println("Published data. topic: " + msgtopic.getName() + " Message: " + message)
- }
- client.disconnect()
}
}
@@ -96,9 +105,9 @@ object MQTTWordCount {
val sparkConf = new SparkConf().setAppName("MQTTWordCount")
val ssc = new StreamingContext(sparkConf, Seconds(2))
val lines = MQTTUtils.createStream(ssc, brokerUrl, topic, StorageLevel.MEMORY_ONLY_SER_2)
-
- val words = lines.flatMap(x => x.toString.split(" "))
+ val words = lines.flatMap(x => x.split(" "))
val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _)
+
wordCounts.print()
ssc.start()
ssc.awaitTermination()
diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
index 4b732c1592ab2..44dec45c227ca 100644
--- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
+++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala
@@ -19,7 +19,6 @@ package org.apache.spark.streaming.flume
import java.net.InetSocketAddress
-import org.apache.spark.annotation.Experimental
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext}
@@ -121,7 +120,6 @@ object FlumeUtils {
* @param port Port of the host at which the Spark Sink is listening
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
ssc: StreamingContext,
hostname: String,
@@ -138,7 +136,6 @@ object FlumeUtils {
* @param addresses List of InetSocketAddresses representing the hosts to connect to.
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
ssc: StreamingContext,
addresses: Seq[InetSocketAddress],
@@ -159,7 +156,6 @@ object FlumeUtils {
* result in this stream using more threads
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
ssc: StreamingContext,
addresses: Seq[InetSocketAddress],
@@ -178,7 +174,6 @@ object FlumeUtils {
* @param hostname Hostname of the host on which the Spark Sink is running
* @param port Port of the host at which the Spark Sink is listening
*/
- @Experimental
def createPollingStream(
jssc: JavaStreamingContext,
hostname: String,
@@ -195,7 +190,6 @@ object FlumeUtils {
* @param port Port of the host at which the Spark Sink is listening
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
jssc: JavaStreamingContext,
hostname: String,
@@ -212,7 +206,6 @@ object FlumeUtils {
* @param addresses List of InetSocketAddresses on which the Spark Sink is running.
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
jssc: JavaStreamingContext,
addresses: Array[InetSocketAddress],
@@ -233,7 +226,6 @@ object FlumeUtils {
* result in this stream using more threads
* @param storageLevel Storage level to use for storing the received objects
*/
- @Experimental
def createPollingStream(
jssc: JavaStreamingContext,
addresses: Array[InetSocketAddress],
diff --git a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 1e24da7f5f60c..cfedb5a042a35 100644
--- a/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/external/flume/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -31,7 +31,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
index b57a1c71e35b9..e04d4088df7dc 100644
--- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
+++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala
@@ -34,10 +34,9 @@ import org.scalatest.{BeforeAndAfter, FunSuite}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.ReceiverInputDStream
-import org.apache.spark.streaming.util.ManualClock
import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext}
import org.apache.spark.streaming.flume.sink._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ManualClock, Utils}
class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging {
@@ -54,7 +53,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
def beforeFunction() {
logInfo("Using manual clock")
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
}
before(beforeFunction())
@@ -236,7 +235,7 @@ class FlumePollingStreamSuite extends FunSuite with BeforeAndAfter with Logging
tx.commit()
tx.close()
Thread.sleep(500) // Allow some time for the events to reach
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
}
null
}
diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml
index 503fc129dc4f2..8daa7ed608f6a 100644
--- a/external/kafka-assembly/pom.xml
+++ b/external/kafka-assembly/pom.xml
@@ -33,9 +33,6 @@
streaming-kafka-assembly
- scala-${scala.binary.version}
- spark-streaming-kafka-assembly-${project.version}.jar
- ${project.build.directory}/${spark.jar.dir}/${spark.jar.basename}
@@ -61,7 +58,6 @@
maven-shade-pluginfalse
- ${spark.jar}*:*
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala
new file mode 100644
index 0000000000000..5a74febb4bd46
--- /dev/null
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Broker.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Represent the host and port info for a Kafka broker.
+ * Differs from the Kafka project's internal kafka.cluster.Broker, which contains a server ID
+ */
+@Experimental
+final class Broker private(
+ /** Broker's hostname */
+ val host: String,
+ /** Broker's port */
+ val port: Int) extends Serializable {
+ override def equals(obj: Any): Boolean = obj match {
+ case that: Broker =>
+ this.host == that.host &&
+ this.port == that.port
+ case _ => false
+ }
+
+ override def hashCode: Int = {
+ 41 * (41 + host.hashCode) + port
+ }
+
+ override def toString(): String = {
+ s"Broker($host, $port)"
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Companion object that provides methods to create instances of [[Broker]].
+ */
+@Experimental
+object Broker {
+ def create(host: String, port: Int): Broker =
+ new Broker(host, port)
+
+ def apply(host: String, port: Int): Broker =
+ new Broker(host, port)
+
+ def unapply(broker: Broker): Option[(String, Int)] = {
+ if (broker == null) {
+ None
+ } else {
+ Some((broker.host, broker.port))
+ }
+ }
+}
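
As the companion object above suggests, a Broker is just a host/port value. A brief sketch (broker hostnames and topic name invented for illustration) of building the leaders map accepted by the new KafkaUtils.createRDD overload later in this patch:

    import kafka.common.TopicAndPartition
    import org.apache.spark.streaming.kafka.Broker

    // Map each topic/partition to the broker currently leading it, so the RDD can
    // fetch from leaders directly instead of looking them up on the driver.
    val leaders = Map(
      TopicAndPartition("events", 0) -> Broker("kafka-1.example.com", 9092),
      TopicAndPartition("events", 1) -> Broker.create("kafka-2.example.com", 9092))
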
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
index c7bca43eb889d..04e65cb3d708c 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/DirectKafkaInputDStream.scala
@@ -50,14 +50,13 @@ import org.apache.spark.streaming.dstream._
* @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive)
* starting point of the stream
* @param messageHandler function for translating each message into the desired type
- * @param maxRetries maximum number of times in a row to retry getting leaders' offsets
*/
private[streaming]
class DirectKafkaInputDStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
+ U <: Decoder[K]: ClassTag,
+ T <: Decoder[V]: ClassTag,
R: ClassTag](
@transient ssc_ : StreamingContext,
val kafkaParams: Map[String, String],
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
index ccc62bfe8f057..2f7e0ab39fefd 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala
@@ -332,6 +332,9 @@ object KafkaCluster {
extends ConsumerConfig(originalProps) {
val seedBrokers: Array[(String, Int)] = brokers.split(",").map { hp =>
val hpa = hp.split(":")
+ if (hpa.size == 1) {
+ throw new SparkException(s"Broker not in the correct format of <host>:<port> [$brokers]")
+ }
(hpa(0), hpa(1).toInt)
}
}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
index 50bf7cbdb8dbf..d56cc01be9514 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala
@@ -36,14 +36,12 @@ import kafka.utils.VerifiableProperties
* Starting and ending offsets are specified in advance,
* so that you can control exactly-once semantics.
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * @param batch Each KafkaRDDPartition in the batch corresponds to a
- * range of offsets for a given Kafka topic/partition
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD
* @param messageHandler function for translating each message into the desired type
*/
-private[spark]
+private[kafka]
class KafkaRDD[
K: ClassTag,
V: ClassTag,
@@ -183,7 +181,7 @@ class KafkaRDD[
}
}
-private[spark]
+private[kafka]
object KafkaRDD {
import KafkaCluster.LeaderOffset
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
index 36372e08f65f6..a842a6f17766f 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala
@@ -26,7 +26,7 @@ import org.apache.spark.Partition
* @param host preferred kafka host, i.e. the leader at the time the rdd was created
* @param port preferred kafka host's port
*/
-private[spark]
+private[kafka]
class KafkaRDDPartition(
val index: Int,
val topic: String,
@@ -36,24 +36,3 @@ class KafkaRDDPartition(
val host: String,
val port: Int
) extends Partition
-
-private[spark]
-object KafkaRDDPartition {
- def apply(
- index: Int,
- topic: String,
- partition: Int,
- fromOffset: Long,
- untilOffset: Long,
- host: String,
- port: Int
- ): KafkaRDDPartition = new KafkaRDDPartition(
- index,
- topic,
- partition,
- fromOffset,
- untilOffset,
- host,
- port
- )
-}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
index f8aa6c5c6263c..5a9bd4214cf51 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala
@@ -18,27 +18,30 @@
package org.apache.spark.streaming.kafka
import java.lang.{Integer => JInt}
+import java.lang.{Long => JLong}
import java.util.{Map => JMap}
+import java.util.{Set => JSet}
import scala.reflect.ClassTag
import scala.collection.JavaConversions._
import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
-import kafka.serializer.{Decoder, StringDecoder}
-
+import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder}
+import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
-import org.apache.spark.streaming.api.java.{JavaPairReceiverInputDStream, JavaStreamingContext}
+import org.apache.spark.streaming.api.java.{JavaPairInputDStream, JavaInputDStream, JavaPairReceiverInputDStream, JavaStreamingContext}
import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream}
+import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
object KafkaUtils {
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param ssc StreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..)
* @param groupId The group id for this consumer
@@ -62,7 +65,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param ssc StreamingContext object
* @param kafkaParams Map of kafka configuration parameters,
* see http://kafka.apache.org/08/configuration.html
@@ -81,7 +84,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* Storage level of the data will be the default StorageLevel.MEMORY_AND_DISK_SER_2.
* @param jssc JavaStreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..)
@@ -99,7 +102,7 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param jssc JavaStreamingContext object
* @param zkQuorum Zookeeper quorum (hostname:port,hostname:port,..).
* @param groupId The group id for this consumer.
@@ -119,10 +122,10 @@ object KafkaUtils {
}
/**
- * Create an input stream that pulls messages from a Kafka Broker.
+ * Create an input stream that pulls messages from Kafka Brokers.
* @param jssc JavaStreamingContext object
- * @param keyTypeClass Key type of RDD
- * @param valueTypeClass value type of RDD
+ * @param keyTypeClass Key type of DStream
+ * @param valueTypeClass Value type of DStream
* @param keyDecoderClass Type of kafka key decoder
* @param valueDecoderClass Type of kafka value decoder
* @param kafkaParams Map of kafka configuration parameters,
@@ -151,14 +154,27 @@ object KafkaUtils {
jssc.ssc, kafkaParams.toMap, Map(topics.mapValues(_.intValue()).toSeq: _*), storageLevel)
}
- /** A batch-oriented interface for consuming from Kafka.
- * Starting and ending offsets are specified in advance,
- * so that you can control exactly-once semantics.
+ /** get leaders for the given offset ranges, or throw an exception */
+ private def leadersForRanges(
+ kafkaParams: Map[String, String],
+ offsetRanges: Array[OffsetRange]): Map[TopicAndPartition, (String, Int)] = {
+ val kc = new KafkaCluster(kafkaParams)
+ val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet
+ val leaders = kc.findLeaders(topics).fold(
+ errs => throw new SparkException(errs.mkString("\n")),
+ ok => ok
+ )
+ leaders
+ }
+
+ /**
+ * Create an RDD from Kafka using offset ranges for each topic and partition.
+ *
* @param sc SparkContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
*/
@@ -166,134 +182,212 @@ object KafkaUtils {
def createRDD[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag](
sc: SparkContext,
kafkaParams: Map[String, String],
offsetRanges: Array[OffsetRange]
- ): RDD[(K, V)] = {
+ ): RDD[(K, V)] = {
val messageHandler = (mmd: MessageAndMetadata[K, V]) => (mmd.key, mmd.message)
- val kc = new KafkaCluster(kafkaParams)
- val topics = offsetRanges.map(o => TopicAndPartition(o.topic, o.partition)).toSet
- val leaders = kc.findLeaders(topics).fold(
- errs => throw new SparkException(errs.mkString("\n")),
- ok => ok
- )
- new KafkaRDD[K, V, U, T, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
+ val leaders = leadersForRanges(kafkaParams, offsetRanges)
+ new KafkaRDD[K, V, KD, VD, (K, V)](sc, kafkaParams, offsetRanges, leaders, messageHandler)
}
- /** A batch-oriented interface for consuming from Kafka.
- * Starting and ending offsets are specified in advance,
- * so that you can control exactly-once semantics.
+ /**
+ * :: Experimental ::
+ * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you
+ * to specify the Kafka leader to connect to (to optimize fetching) and access the message as well
+ * as the metadata.
+ *
* @param sc SparkContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
* @param offsetRanges Each OffsetRange in the batch corresponds to a
* range of offsets for a given Kafka topic/partition
- * @param leaders Kafka leaders for each offset range in batch
- * @param messageHandler function for translating each message into the desired type
+ * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map,
+ * in which case leaders will be looked up on the driver.
+ * @param messageHandler Function for translating each message and metadata into the desired type
*/
@Experimental
def createRDD[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
- R: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag,
+ R: ClassTag](
sc: SparkContext,
kafkaParams: Map[String, String],
offsetRanges: Array[OffsetRange],
- leaders: Array[Leader],
+ leaders: Map[TopicAndPartition, Broker],
messageHandler: MessageAndMetadata[K, V] => R
- ): RDD[R] = {
-
- val leaderMap = leaders
- .map(l => TopicAndPartition(l.topic, l.partition) -> (l.host, l.port))
- .toMap
- new KafkaRDD[K, V, U, T, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler)
+ ): RDD[R] = {
+ val leaderMap = if (leaders.isEmpty) {
+ leadersForRanges(kafkaParams, offsetRanges)
+ } else {
+ // This could be avoided by refactoring KafkaRDD.leaders and KafkaCluster to use Broker
+ leaders.map {
+ case (tp: TopicAndPartition, Broker(host, port)) => (tp, (host, port))
+ }.toMap
+ }
+ new KafkaRDD[K, V, KD, VD, R](sc, kafkaParams, offsetRanges, leaderMap, messageHandler)
}
+
/**
- * This stream can guarantee that each message from Kafka is included in transformations
- * (as opposed to output actions) exactly once, even in most failure situations.
- *
- * Points to note:
+ * Create an RDD from Kafka using offset ranges for each topic and partition.
*
- * Failure Recovery - You must checkpoint this stream, or save offsets yourself and provide them
- * as the fromOffsets parameter on restart.
- * Kafka must have sufficient log retention to obtain messages after failure.
- *
- * Getting offsets from the stream - see programming guide
+ * @param jsc JavaSparkContext object
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a
+ * range of offsets for a given Kafka topic/partition
+ */
+ @Experimental
+ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V]](
+ jsc: JavaSparkContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ kafkaParams: JMap[String, String],
+ offsetRanges: Array[OffsetRange]
+ ): JavaPairRDD[K, V] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ new JavaPairRDD(createRDD[K, V, KD, VD](
+ jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges))
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an RDD from Kafka using offset ranges for each topic and partition. This allows you
+ * to specify the Kafka leader to connect to (to optimize fetching) and access the message as well
+ * as the metadata.
*
- * Zookeeper - This does not use Zookeeper to store offsets. For interop with Kafka monitors
- * that depend on Zookeeper, you must store offsets in ZK yourself.
+ * @param jsc JavaSparkContext object
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param offsetRanges Each OffsetRange in the batch corresponds to a
+ * range of offsets for a given Kafka topic/partition
+ * @param leaders Kafka brokers for each TopicAndPartition in offsetRanges. May be an empty map,
+ * in which case leaders will be looked up on the driver.
+ * @param messageHandler Function for translating each message and metadata into the desired type
+ */
+ @Experimental
+ def createRDD[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ jsc: JavaSparkContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ recordClass: Class[R],
+ kafkaParams: JMap[String, String],
+ offsetRanges: Array[OffsetRange],
+ leaders: JMap[TopicAndPartition, Broker],
+ messageHandler: JFunction[MessageAndMetadata[K, V], R]
+ ): JavaRDD[R] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ implicit val recordCmt: ClassTag[R] = ClassTag(recordClass)
+ val leaderMap = Map(leaders.toSeq: _*)
+ createRDD[K, V, KD, VD, R](
+ jsc.sc, Map(kafkaParams.toSeq: _*), offsetRanges, leaderMap, messageHandler.call _)
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
*
- * End-to-end semantics - This does not guarantee that any output operation will push each record
- * exactly once. To ensure end-to-end exactly-once semantics (that is, receiving exactly once and
- * outputting exactly once), you have to either ensure that the output operation is
- * idempotent, or transactionally store offsets with the output. See the programming guide for
- * more details.
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that each record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * output exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
*
* @param ssc StreamingContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * @param messageHandler function for translating each message into the desired type
- * @param fromOffsets per-topic/partition Kafka offsets defining the (inclusive)
- * starting point of the stream
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers) specified in
+ * host1:port1,host2:port2 form.
+ * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
+ * starting point of the stream
+ * @param messageHandler Function for translating each message and metadata into the desired type
*/
@Experimental
def createDirectStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag,
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag,
R: ClassTag] (
ssc: StreamingContext,
kafkaParams: Map[String, String],
fromOffsets: Map[TopicAndPartition, Long],
messageHandler: MessageAndMetadata[K, V] => R
): InputDStream[R] = {
- new DirectKafkaInputDStream[K, V, U, T, R](
+ new DirectKafkaInputDStream[K, V, KD, VD, R](
ssc, kafkaParams, fromOffsets, messageHandler)
}
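As an editorial sketch (not part of the patch), driving this overload from offsets the application saved itself; `ssc`, the topic name, and the stored offset are assumptions:

import kafka.common.TopicAndPartition
import kafka.message.MessageAndMetadata
import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.KafkaUtils

val kafkaParams = Map("metadata.broker.list" -> "localhost:9092")
// Resume the hypothetical "events" topic, partition 0, from offset 42 (inclusive).
val fromOffsets = Map(TopicAndPartition("events", 0) -> 42L)

val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String](
  ssc, kafkaParams, fromOffsets,
  (mmd: MessageAndMetadata[String, String]) => mmd.message())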
/**
- * This stream can guarantee that each message from Kafka is included in transformations
- * (as opposed to output actions) exactly once, even in most failure situations.
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
*
* Points to note:
- *
- * Failure Recovery - You must checkpoint this stream.
- * Kafka must have sufficient log retention to obtain messages after failure.
- *
- * Getting offsets from the stream - see programming guide
- *
- * Zookeeper - This does not use Zookeeper to store offsets. For interop with Kafka monitors
- * that depend on Zookeeper, you must store offsets in ZK yourself.
- *
- * End-to-end semantics - This does not guarantee that any output operation will push each record
- * exactly once. To ensure end-to-end exactly-once semantics (that is, receiving exactly once and
- * outputting exactly once), you have to ensure that the output operation is idempotent.
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that each record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * output exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
*
* @param ssc StreamingContext object
* @param kafkaParams Kafka
- * configuration parameters.
- * Requires "metadata.broker.list" or "bootstrap.servers" to be set with Kafka broker(s),
- * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
- * If starting without a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
* to determine where the stream starts (defaults to "largest")
- * @param topics names of the topics to consume
+ * @param topics Names of the topics to consume
*/
@Experimental
def createDirectStream[
K: ClassTag,
V: ClassTag,
- U <: Decoder[_]: ClassTag,
- T <: Decoder[_]: ClassTag] (
+ KD <: Decoder[K]: ClassTag,
+ VD <: Decoder[V]: ClassTag] (
ssc: StreamingContext,
kafkaParams: Map[String, String],
topics: Set[String]
@@ -313,11 +407,155 @@ object KafkaUtils {
val fromOffsets = leaderOffsets.map { case (tp, lo) =>
(tp, lo.offset)
}
- new DirectKafkaInputDStream[K, V, U, T, (K, V)](
+ new DirectKafkaInputDStream[K, V, KD, VD, (K, V)](
ssc, kafkaParams, fromOffsets, messageHandler)
}).fold(
errs => throw new SparkException(errs.mkString("\n")),
ok => ok
)
}
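Again as a sketch rather than patch content, the topic-based overload combined with HasOffsetRanges, which is how the offsets mentioned in the points above can be captured per batch (`ssc` and the topic name are assumptions):

import kafka.serializer.StringDecoder
import org.apache.spark.streaming.kafka.{HasOffsetRanges, KafkaUtils}

val kafkaParams = Map(
  "metadata.broker.list" -> "localhost:9092",
  "auto.offset.reset" -> "smallest")

val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
  ssc, kafkaParams, Set("events"))

stream.foreachRDD { rdd =>
  val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
  // Persist offsetRanges atomically with the results (or make the output idempotent)
  // if end-to-end exactly-once semantics are required.
}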
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
+ *
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that each record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * output exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
+ *
+ * @param jssc JavaStreamingContext object
+ * @param keyClass Class of the keys in the Kafka records
+ * @param valueClass Class of the values in the Kafka records
+ * @param keyDecoderClass Class of the key decoder
+ * @param valueDecoderClass Class of the value decoder
+ * @param recordClass Class of the records in DStream
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * @param fromOffsets Per-topic/partition Kafka offsets defining the (inclusive)
+ * starting point of the stream
+ * @param messageHandler Function for translating each message and metadata into the desired type
+ */
+ @Experimental
+ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V], R](
+ jssc: JavaStreamingContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ recordClass: Class[R],
+ kafkaParams: JMap[String, String],
+ fromOffsets: JMap[TopicAndPartition, JLong],
+ messageHandler: JFunction[MessageAndMetadata[K, V], R]
+ ): JavaInputDStream[R] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ implicit val recordCmt: ClassTag[R] = ClassTag(recordClass)
+ createDirectStream[K, V, KD, VD, R](
+ jssc.ssc,
+ Map(kafkaParams.toSeq: _*),
+ Map(fromOffsets.mapValues { _.longValue() }.toSeq: _*),
+ messageHandler.call _
+ )
+ }
+
+ /**
+ * :: Experimental ::
+ * Create an input stream that directly pulls messages from Kafka Brokers
+ * without using any receiver. This stream can guarantee that each message
+ * from Kafka is included in transformations exactly once (see points below).
+ *
+ * Points to note:
+ * - No receivers: This stream does not use any receiver. It directly queries Kafka
+ * - Offsets: This does not use Zookeeper to store offsets. The consumed offsets are tracked
+ * by the stream itself. For interoperability with Kafka monitoring tools that depend on
+ * Zookeeper, you have to update Kafka/Zookeeper yourself from the streaming application.
+ * You can access the offsets used in each batch from the generated RDDs (see
+ * [[org.apache.spark.streaming.kafka.HasOffsetRanges]]).
+ * - Failure Recovery: To recover from driver failures, you have to enable checkpointing
+ * in the [[StreamingContext]]. The information on consumed offset can be
+ * recovered from the checkpoint. See the programming guide for details (constraints, etc.).
+ * - End-to-end semantics: This stream ensures that each record is effectively received and
+ * transformed exactly once, but gives no guarantees on whether the transformed data are
+ * output exactly once. For end-to-end exactly-once semantics, you have to either ensure
+ * that the output operation is idempotent, or use transactions to output records atomically.
+ * See the programming guide for more details.
+ *
+ * @param jssc JavaStreamingContext object
+ * @param keyClass Class of the keys in the Kafka records
+ * @param valueClass Class of the values in the Kafka records
+ * @param keyDecoderClass Class of the key decoder
+ * @param valueDecoderClass Class of the value decoder
+ * @param kafkaParams Kafka
+ * configuration parameters. Requires "metadata.broker.list" or "bootstrap.servers"
+ * to be set with Kafka broker(s) (NOT zookeeper servers), specified in
+ * host1:port1,host2:port2 form.
+ * If not starting from a checkpoint, "auto.offset.reset" may be set to "largest" or "smallest"
+ * to determine where the stream starts (defaults to "largest")
+ * @param topics Names of the topics to consume
+ */
+ @Experimental
+ def createDirectStream[K, V, KD <: Decoder[K], VD <: Decoder[V]](
+ jssc: JavaStreamingContext,
+ keyClass: Class[K],
+ valueClass: Class[V],
+ keyDecoderClass: Class[KD],
+ valueDecoderClass: Class[VD],
+ kafkaParams: JMap[String, String],
+ topics: JSet[String]
+ ): JavaPairInputDStream[K, V] = {
+ implicit val keyCmt: ClassTag[K] = ClassTag(keyClass)
+ implicit val valueCmt: ClassTag[V] = ClassTag(valueClass)
+ implicit val keyDecoderCmt: ClassTag[KD] = ClassTag(keyDecoderClass)
+ implicit val valueDecoderCmt: ClassTag[VD] = ClassTag(valueDecoderClass)
+ createDirectStream[K, V, KD, VD](
+ jssc.ssc,
+ Map(kafkaParams.toSeq: _*),
+ Set(topics.toSeq: _*)
+ )
+ }
+}
+
+/**
+ * This is a helper class that wraps KafkaUtils.createStream() in a more
+ * Python-friendly class and function so that it can be easily
+ * instantiated and called from Python's KafkaUtils (see SPARK-6027).
+ *
+ * The zero-arg constructor helps instantiate this class from the Class object
+ * classOf[KafkaUtilsPythonHelper].newInstance(), and createStream()
+ * fills in the type-specific parameters so they do not have to be passed from Python
+ */
+private class KafkaUtilsPythonHelper {
+ def createStream(
+ jssc: JavaStreamingContext,
+ kafkaParams: JMap[String, String],
+ topics: JMap[String, JInt],
+ storageLevel: StorageLevel): JavaPairReceiverInputDStream[Array[Byte], Array[Byte]] = {
+ KafkaUtils.createStream[Array[Byte], Array[Byte], DefaultDecoder, DefaultDecoder](
+ jssc,
+ classOf[Array[Byte]],
+ classOf[Array[Byte]],
+ classOf[DefaultDecoder],
+ classOf[DefaultDecoder],
+ kafkaParams,
+ topics,
+ storageLevel)
+ }
}
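A purely illustrative sketch (not in the patch) of the reflective instantiation described in the comment above; the Python side relies only on the fully qualified class name and the zero-arg constructor:

// Hypothetical: instantiate the helper by name, then invoke createStream(...) on the instance.
val helperClass = Class.forName("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
val helper = helperClass.newInstance()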
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala
deleted file mode 100644
index 3454d92e72b47..0000000000000
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/Leader.scala
+++ /dev/null
@@ -1,46 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka
-
-import kafka.common.TopicAndPartition
-
-/** Host info for the leader of a Kafka TopicAndPartition */
-final class Leader private(
- /** kafka topic name */
- val topic: String,
- /** kafka partition id */
- val partition: Int,
- /** kafka hostname */
- val host: String,
- /** kafka host's port */
- val port: Int) extends Serializable
-
-object Leader {
- def create(topic: String, partition: Int, host: String, port: Int): Leader =
- new Leader(topic, partition, host, port)
-
- def create(topicAndPartition: TopicAndPartition, host: String, port: Int): Leader =
- new Leader(topicAndPartition.topic, topicAndPartition.partition, host, port)
-
- def apply(topic: String, partition: Int, host: String, port: Int): Leader =
- new Leader(topic, partition, host, port)
-
- def apply(topicAndPartition: TopicAndPartition, host: String, port: Int): Leader =
- new Leader(topicAndPartition.topic, topicAndPartition.partition, host, port)
-
-}
diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
index 334c12e4627b4..9c3dfeb8f5928 100644
--- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
+++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/OffsetRange.scala
@@ -19,16 +19,35 @@ package org.apache.spark.streaming.kafka
import kafka.common.TopicAndPartition
-/** Something that has a collection of OffsetRanges */
+import org.apache.spark.annotation.Experimental
+
+/**
+ * :: Experimental ::
+ * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the
+ * offset ranges in RDDs generated by the direct Kafka DStream (see
+ * [[KafkaUtils.createDirectStream()]]).
+ * {{{
+ * KafkaUtils.createDirectStream(...).foreachRDD { rdd =>
+ * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ * ...
+ * }
+ * }}}
+ */
+@Experimental
trait HasOffsetRanges {
def offsetRanges: Array[OffsetRange]
}
-/** Represents a range of offsets from a single Kafka TopicAndPartition */
+/**
+ * :: Experimental ::
+ * Represents a range of offsets from a single Kafka TopicAndPartition. Instances of this class
+ * can be created with `OffsetRange.create()`.
+ */
+@Experimental
final class OffsetRange private(
- /** kafka topic name */
+ /** Kafka topic name */
val topic: String,
- /** kafka partition id */
+ /** Kafka partition id */
val partition: Int,
/** inclusive starting offset */
val fromOffset: Long,
@@ -36,11 +55,33 @@ final class OffsetRange private(
val untilOffset: Long) extends Serializable {
import OffsetRange.OffsetRangeTuple
+ override def equals(obj: Any): Boolean = obj match {
+ case that: OffsetRange =>
+ this.topic == that.topic &&
+ this.partition == that.partition &&
+ this.fromOffset == that.fromOffset &&
+ this.untilOffset == that.untilOffset
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ toTuple.hashCode()
+ }
+
+ override def toString(): String = {
+ s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset]"
+ }
+
/** this is to avoid ClassNotFoundException during checkpoint restore */
private[streaming]
def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset)
}
+/**
+ * :: Experimental ::
+ * Companion object that provides methods to create instances of [[OffsetRange]].
+ */
+@Experimental
object OffsetRange {
def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange =
new OffsetRange(topic, partition, fromOffset, untilOffset)
@@ -61,10 +102,10 @@ object OffsetRange {
new OffsetRange(topicAndPartition.topic, topicAndPartition.partition, fromOffset, untilOffset)
/** this is to avoid ClassNotFoundException during checkpoint restore */
- private[spark]
+ private[kafka]
type OffsetRangeTuple = (String, Int, Long, Long)
- private[streaming]
+ private[kafka]
def apply(t: OffsetRangeTuple) =
new OffsetRange(t._1, t._2, t._3, t._4)
}
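A small editorial sketch of the OffsetRange API as extended above, using a hypothetical topic; it leans on the value-based equals()/hashCode() just added:

import org.apache.spark.streaming.kafka.OffsetRange

val a = OffsetRange.create("events", 0, 0, 100)
val b = OffsetRange("events", 0, 0, 100)
assert(a == b)                   // equal by topic, partition and both offsets
assert(a.hashCode == b.hashCode) // consistent with equals, via toTuple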
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
new file mode 100644
index 0000000000000..1334cc8fd1b57
--- /dev/null
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Random;
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+
+import scala.Tuple2;
+
+import junit.framework.Assert;
+
+import kafka.common.TopicAndPartition;
+import kafka.message.MessageAndMetadata;
+import kafka.serializer.StringDecoder;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+import org.junit.Test;
+import org.junit.After;
+import org.junit.Before;
+
+public class JavaDirectKafkaStreamSuite implements Serializable {
+ private transient JavaStreamingContext ssc = null;
+ private transient Random random = new Random();
+ private transient KafkaStreamSuiteBase suiteBase = null;
+
+ @Before
+ public void setUp() {
+ suiteBase = new KafkaStreamSuiteBase() { };
+ suiteBase.setupKafka();
+ System.clearProperty("spark.driver.port");
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
+ }
+
+ @After
+ public void tearDown() {
+ ssc.stop();
+ ssc = null;
+ System.clearProperty("spark.driver.port");
+ suiteBase.tearDownKafka();
+ }
+
+ @Test
+ public void testKafkaStream() throws InterruptedException {
+ String topic1 = "topic1";
+ String topic2 = "topic2";
+
+ String[] topic1data = createTopicAndSendData(topic1);
+ String[] topic2data = createTopicAndSendData(topic2);
+
+ HashSet<String> sent = new HashSet<String>();
+ sent.addAll(Arrays.asList(topic1data));
+ sent.addAll(Arrays.asList(topic2data));
+
+ HashMap<String, String> kafkaParams = new HashMap<String, String>();
+ kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress());
+ kafkaParams.put("auto.offset.reset", "smallest");
+
+ JavaDStream<String> stream1 = KafkaUtils.createDirectStream(
+ ssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ topicToSet(topic1)
+ ).map(
+ new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(scala.Tuple2<String, String> kv) throws Exception {
+ return kv._2();
+ }
+ }
+ );
+
+ JavaDStream<String> stream2 = KafkaUtils.createDirectStream(
+ ssc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ String.class,
+ kafkaParams,
+ topicOffsetToMap(topic2, (long) 0),
+ new Function<MessageAndMetadata<String, String>, String>() {
+ @Override
+ public String call(MessageAndMetadata<String, String> msgAndMd) throws Exception {
+ return msgAndMd.message();
+ }
+ }
+ );
+ JavaDStream<String> unifiedStream = stream1.union(stream2);
+
+ final HashSet<String> result = new HashSet<String>();
+ unifiedStream.foreachRDD(
+ new Function<JavaRDD<String>, Void>() {
+ @Override
+ public Void call(org.apache.spark.api.java.JavaRDD<String> rdd) throws Exception {
+ result.addAll(rdd.collect());
+ return null;
+ }
+ }
+ );
+ ssc.start();
+ long startTime = System.currentTimeMillis();
+ boolean matches = false;
+ while (!matches && System.currentTimeMillis() - startTime < 20000) {
+ matches = sent.size() == result.size();
+ Thread.sleep(50);
+ }
+ Assert.assertEquals(sent, result);
+ ssc.stop();
+ }
+
+ private HashSet<String> topicToSet(String topic) {
+ HashSet<String> topicSet = new HashSet<String>();
+ topicSet.add(topic);
+ return topicSet;
+ }
+
+ private HashMap<TopicAndPartition, Long> topicOffsetToMap(String topic, Long offsetToStart) {
+ HashMap<TopicAndPartition, Long> topicMap = new HashMap<TopicAndPartition, Long>();
+ topicMap.put(new TopicAndPartition(topic, 0), offsetToStart);
+ return topicMap;
+ }
+
+ private String[] createTopicAndSendData(String topic) {
+ String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+ suiteBase.createTopic(topic);
+ suiteBase.sendMessages(topic, data);
+ return data;
+ }
+}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
new file mode 100644
index 0000000000000..9d2e1705c6c73
--- /dev/null
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaRDDSuite.java
@@ -0,0 +1,156 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Arrays;
+
+import org.apache.spark.SparkConf;
+
+import scala.Tuple2;
+
+import junit.framework.Assert;
+
+import kafka.common.TopicAndPartition;
+import kafka.message.MessageAndMetadata;
+import kafka.serializer.StringDecoder;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+import org.junit.Test;
+import org.junit.After;
+import org.junit.Before;
+
+public class JavaKafkaRDDSuite implements Serializable {
+ private transient JavaSparkContext sc = null;
+ private transient KafkaStreamSuiteBase suiteBase = null;
+
+ @Before
+ public void setUp() {
+ suiteBase = new KafkaStreamSuiteBase() { };
+ suiteBase.setupKafka();
+ System.clearProperty("spark.driver.port");
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ sc = new JavaSparkContext(sparkConf);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ suiteBase.tearDownKafka();
+ }
+
+ @Test
+ public void testKafkaRDD() throws InterruptedException {
+ String topic1 = "topic1";
+ String topic2 = "topic2";
+
+ String[] topic1data = createTopicAndSendData(topic1);
+ String[] topic2data = createTopicAndSendData(topic2);
+
+ HashMap<String, String> kafkaParams = new HashMap<String, String>();
+ kafkaParams.put("metadata.broker.list", suiteBase.brokerAddress());
+
+ OffsetRange[] offsetRanges = {
+ OffsetRange.create(topic1, 0, 0, 1),
+ OffsetRange.create(topic2, 0, 0, 1)
+ };
+
+ HashMap<TopicAndPartition, Broker> emptyLeaders = new HashMap<TopicAndPartition, Broker>();
+ HashMap<TopicAndPartition, Broker> leaders = new HashMap<TopicAndPartition, Broker>();
+ String[] hostAndPort = suiteBase.brokerAddress().split(":");
+ Broker broker = Broker.create(hostAndPort[0], Integer.parseInt(hostAndPort[1]));
+ leaders.put(new TopicAndPartition(topic1, 0), broker);
+ leaders.put(new TopicAndPartition(topic2, 0), broker);
+
+ JavaRDD<String> rdd1 = KafkaUtils.createRDD(
+ sc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ kafkaParams,
+ offsetRanges
+ ).map(
+ new Function<Tuple2<String, String>, String>() {
+ @Override
+ public String call(scala.Tuple2<String, String> kv) throws Exception {
+ return kv._2();
+ }
+ }
+ );
+
+ JavaRDD<String> rdd2 = KafkaUtils.createRDD(
+ sc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ String.class,
+ kafkaParams,
+ offsetRanges,
+ emptyLeaders,
+ new Function<MessageAndMetadata<String, String>, String>() {
+ @Override
+ public String call(MessageAndMetadata<String, String> msgAndMd) throws Exception {
+ return msgAndMd.message();
+ }
+ }
+ );
+
+ JavaRDD<String> rdd3 = KafkaUtils.createRDD(
+ sc,
+ String.class,
+ String.class,
+ StringDecoder.class,
+ StringDecoder.class,
+ String.class,
+ kafkaParams,
+ offsetRanges,
+ leaders,
+ new Function<MessageAndMetadata<String, String>, String>() {
+ @Override
+ public String call(MessageAndMetadata<String, String> msgAndMd) throws Exception {
+ return msgAndMd.message();
+ }
+ }
+ );
+
+ // just making sure the java user apis work; the scala tests handle logic corner cases
+ long count1 = rdd1.count();
+ long count2 = rdd2.count();
+ long count3 = rdd3.count();
+ Assert.assertTrue(count1 > 0);
+ Assert.assertEquals(count1, count2);
+ Assert.assertEquals(count1, count3);
+ }
+
+ private String[] createTopicAndSendData(String topic) {
+ String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+ suiteBase.createTopic(topic);
+ suiteBase.sendMessages(topic, data);
+ return data;
+ }
+}
diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
index 6e1abf3f385ee..208cc51b29876 100644
--- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
+++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java
@@ -79,9 +79,10 @@ public void testKafkaStream() throws InterruptedException {
suiteBase.createTopic(topic);
HashMap<String, Object> tmp = new HashMap<String, Object>(sent);
- suiteBase.produceAndSendMessage(topic,
+ suiteBase.sendMessages(topic,
JavaConverters.mapAsScalaMapConverter(tmp).asScala().toMap(
- Predef.<Tuple2<String, Object>>conforms()));
+ Predef.<Tuple2<String, Object>>conforms())
+ );
HashMap<String, String> kafkaParams = new HashMap<String, String>();
kafkaParams.put("zookeeper.connect", suiteBase.zkAddress());
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
new file mode 100644
index 0000000000000..17ca9d145d665
--- /dev/null
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -0,0 +1,306 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka
+
+import java.io.File
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import kafka.common.TopicAndPartition
+import kafka.message.MessageAndMetadata
+import kafka.serializer.StringDecoder
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.util.Utils
+
+class DirectKafkaStreamSuite extends KafkaStreamSuiteBase
+ with BeforeAndAfter with BeforeAndAfterAll with Eventually {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(this.getClass.getSimpleName)
+
+ var sc: SparkContext = _
+ var ssc: StreamingContext = _
+ var testDir: File = _
+
+ override def beforeAll {
+ setupKafka()
+ }
+
+ override def afterAll {
+ tearDownKafka()
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ sc = null
+ }
+ if (sc != null) {
+ sc.stop()
+ }
+ if (testDir != null) {
+ Utils.deleteRecursively(testDir)
+ }
+ }
+
+
+ test("basic stream receiving with multiple topics and smallest starting offset") {
+ val topics = Set("basic1", "basic2", "basic3")
+ val data = Map("a" -> 7, "b" -> 9)
+ topics.foreach { t =>
+ createTopic(t)
+ sendMessages(t, data)
+ }
+ val totalSent = data.values.sum * topics.size
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "smallest"
+ )
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, topics)
+ }
+
+ val allReceived = new ArrayBuffer[(String, String)]
+
+ stream.foreachRDD { rdd =>
+ // Get the offset ranges in the RDD
+ val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
+ // For each partition, get size of the range in the partition,
+ // and the number of items in the partition
+ val off = offsets(i)
+ val all = iter.toSeq
+ val partSize = all.size
+ val rangeSize = off.untilOffset - off.fromOffset
+ Iterator((partSize, rangeSize))
+ }.collect
+
+ // Verify whether number of elements in each partition
+ // matches with the corresponding offset range
+ collected.foreach { case (partSize, rangeSize) =>
+ assert(partSize === rangeSize, "offset ranges are wrong")
+ }
+ }
+ stream.foreachRDD { rdd => allReceived ++= rdd.collect() }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(allReceived.size === totalSent,
+ "didn't get expected number of messages, messages:\n" + allReceived.mkString("\n"))
+ }
+ ssc.stop()
+ }
+
+ test("receiving from largest starting offset") {
+ val topic = "largest"
+ val topicPartition = TopicAndPartition(topic, 0)
+ val data = Map("a" -> 10)
+ createTopic(topic)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "largest"
+ )
+ val kc = new KafkaCluster(kafkaParams)
+ def getLatestOffset(): Long = {
+ kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset
+ }
+
+ // Send some initial messages before starting context
+ sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() > 3)
+ }
+ val offsetBeforeStart = getLatestOffset()
+
+ // Setup context and kafka stream with largest offset
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Set(topic))
+ }
+ assert(
+ stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]]
+ .fromOffsets(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+
+ val collectedData = new mutable.ArrayBuffer[String]()
+ stream.map { _._2 }.foreachRDD { rdd => collectedData ++= rdd.collect() }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+
+ test("creating stream by offset") {
+ val topic = "offset"
+ val topicPartition = TopicAndPartition(topic, 0)
+ val data = Map("a" -> 10)
+ createTopic(topic)
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "largest"
+ )
+ val kc = new KafkaCluster(kafkaParams)
+ def getLatestOffset(): Long = {
+ kc.getLatestLeaderOffsets(Set(topicPartition)).right.get(topicPartition).offset
+ }
+
+ // Send some initial messages before starting context
+ sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() >= 10)
+ }
+ val offsetBeforeStart = getLatestOffset()
+
+ // Setup context and kafka stream with largest offset
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder, String](
+ ssc, kafkaParams, Map(topicPartition -> 11L),
+ (m: MessageAndMetadata[String, String]) => m.message())
+ }
+ assert(
+ stream.asInstanceOf[DirectKafkaInputDStream[_, _, _, _, _]]
+ .fromOffsets(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+
+ val collectedData = new mutable.ArrayBuffer[String]()
+ stream.foreachRDD { rdd => collectedData ++= rdd.collect() }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+ // Test to verify the offset ranges can be recovered from the checkpoints
+ test("offset recovery") {
+ val topic = "recovery"
+ createTopic(topic)
+ testDir = Utils.createTempDir()
+
+ val kafkaParams = Map(
+ "metadata.broker.list" -> s"$brokerAddress",
+ "auto.offset.reset" -> "smallest"
+ )
+
+ // Send data to Kafka and wait for it to be received
+ def sendDataAndWaitForReceive(data: Seq[Int]) {
+ val strings = data.map { _.toString}
+ sendMessages(topic, strings.map { _ -> 1}.toMap)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains })
+ }
+ }
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(100))
+ val kafkaStream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
+ ssc, kafkaParams, Set(topic))
+ }
+ val keyedStream = kafkaStream.map { v => "key" -> v._2.toInt }
+ val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) =>
+ Some(values.sum + state.getOrElse(0))
+ }
+ ssc.checkpoint(testDir.getAbsolutePath)
+
+ // This is to collect the raw data received from Kafka
+ kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+ val data = rdd.map { _._2 }.collect()
+ DirectKafkaStreamSuite.collectedData.appendAll(data)
+ }
+
+ // This is to ensure that all the data is eventually received only once
+ stateStream.foreachRDD { (rdd: RDD[(String, Int)]) =>
+ rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 }
+ }
+ ssc.start()
+
+ // Send some data and wait for them to be received
+ for (i <- (1 to 10).grouped(4)) {
+ sendDataAndWaitForReceive(i)
+ }
+
+ // Verify that offset ranges were generated
+ val offsetRangesBeforeStop = getOffsetRanges(kafkaStream)
+ assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated")
+ assert(
+ offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 },
+ "starting offset not zero"
+ )
+ ssc.stop()
+ logInfo("====== RESTARTING ========")
+
+ // Recover context from checkpoints
+ ssc = new StreamingContext(testDir.getAbsolutePath)
+ val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]]
+
+ // Verify offset ranges have been recovered
+ val recoveredOffsetRanges = getOffsetRanges(recoveredStream)
+ assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
+ val earlierOffsetRangesAsSets = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) }
+ assert(
+ recoveredOffsetRanges.forall { or =>
+ earlierOffsetRangesAsSets.contains((or._1, or._2.toSet))
+ },
+ "Recovered ranges are not the same as the ones generated"
+ )
+
+ // Restart context, give more data and verify the total at the end
+ // If the total is right, that means each record has been received only once
+ ssc.start()
+ sendDataAndWaitForReceive(11 to 20)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(DirectKafkaStreamSuite.total === (1 to 20).sum)
+ }
+ ssc.stop()
+ }
+
+ /** Get the generated offset ranges from the DirectKafkaStream */
+ private def getOffsetRanges[K, V](
+ kafkaStream: DStream[(K, V)]): Seq[(Time, Array[OffsetRange])] = {
+ kafkaStream.generatedRDDs.mapValues { rdd =>
+ rdd.asInstanceOf[KafkaRDD[K, V, _, _, (K, V)]].offsetRanges
+ }.toSeq.sortBy { _._1 }
+ }
+}
+
+object DirectKafkaStreamSuite {
+ val collectedData = new mutable.ArrayBuffer[String]()
+ var total = -1L
+}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
index e57c8f6987fdc..fc9275b7207be 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaClusterSuite.scala
@@ -19,33 +19,29 @@ package org.apache.spark.streaming.kafka
import scala.util.Random
-import org.scalatest.BeforeAndAfter
import kafka.common.TopicAndPartition
+import org.scalatest.BeforeAndAfterAll
-class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
- val brokerHost = "localhost"
-
- val kafkaParams = Map("metadata.broker.list" -> s"$brokerHost:$brokerPort")
-
- val kc = new KafkaCluster(kafkaParams)
-
+class KafkaClusterSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
val topic = "kcsuitetopic" + Random.nextInt(10000)
-
val topicAndPartition = TopicAndPartition(topic, 0)
+ var kc: KafkaCluster = null
- before {
+ override def beforeAll() {
setupKafka()
createTopic(topic)
- produceAndSendMessage(topic, Map("a" -> 1))
+ sendMessages(topic, Map("a" -> 1))
+ kc = new KafkaCluster(Map("metadata.broker.list" -> s"$brokerAddress"))
}
- after {
+ override def afterAll() {
tearDownKafka()
}
test("metadata apis") {
- val leader = kc.findLeaders(Set(topicAndPartition)).right.get
- assert(leader(topicAndPartition) === (brokerHost, brokerPort), "didn't get leader")
+ val leader = kc.findLeaders(Set(topicAndPartition)).right.get(topicAndPartition)
+ val leaderAddress = s"${leader._1}:${leader._2}"
+ assert(leaderAddress === brokerAddress, "didn't get leader")
val parts = kc.getPartitions(Set(topic)).right.get
assert(parts(topicAndPartition), "didn't get partitions")
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala
deleted file mode 100644
index 0891ce344f16a..0000000000000
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaDirectStreamSuite.scala
+++ /dev/null
@@ -1,92 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.kafka
-
-import scala.util.Random
-import scala.concurrent.duration._
-
-import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.Eventually
-
-import kafka.serializer.StringDecoder
-
-import org.apache.spark.SparkConf
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Milliseconds, StreamingContext}
-
-class KafkaDirectStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter with Eventually {
- val sparkConf = new SparkConf()
- .setMaster("local[4]")
- .setAppName(this.getClass.getSimpleName)
-
- val brokerHost = "localhost"
-
- val kafkaParams = Map(
- "metadata.broker.list" -> s"$brokerHost:$brokerPort",
- "auto.offset.reset" -> "smallest"
- )
-
- var ssc: StreamingContext = _
-
- before {
- setupKafka()
-
- ssc = new StreamingContext(sparkConf, Milliseconds(500))
- }
-
- after {
- if (ssc != null) {
- ssc.stop()
- }
- tearDownKafka()
- }
-
- test("multi topic stream") {
- val topics = Set("newA", "newB")
- val data = Map("a" -> 7, "b" -> 9)
- topics.foreach { t =>
- createTopic(t)
- produceAndSendMessage(t, data)
- }
- val stream = KafkaUtils.createDirectStream[String, String, StringDecoder, StringDecoder](
- ssc, kafkaParams, topics)
- var total = 0L;
-
- stream.foreachRDD { rdd =>
- val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
- val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
- val off = offsets(i)
- val all = iter.toSeq
- val partSize = all.size
- val rangeSize = off.untilOffset - off.fromOffset
- all.map { _ =>
- (partSize, rangeSize)
- }.toIterator
- }.collect
- collected.foreach { case (partSize, rangeSize) =>
- assert(partSize === rangeSize, "offset ranges are wrong")
- }
- total += collected.size
- }
- ssc.start()
- eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
- assert(total === data.values.sum * topics.size, "didn't get all messages")
- }
- ssc.stop()
- }
-}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 9b9e3f5fce8bd..a223da70b043f 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -21,18 +21,22 @@ import scala.util.Random
import kafka.serializer.StringDecoder
import kafka.common.TopicAndPartition
-import org.scalatest.BeforeAndAfter
+import kafka.message.MessageAndMetadata
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark._
import org.apache.spark.SparkContext._
-class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
+class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfterAll {
+ val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
var sc: SparkContext = _
- before {
+ override def beforeAll {
+ sc = new SparkContext(sparkConf)
+
setupKafka()
}
- after {
+ override def afterAll {
if (sc != null) {
sc.stop
sc = null
@@ -40,60 +44,94 @@ class KafkaRDDSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
tearDownKafka()
}
- test("Kafka RDD") {
- val sparkConf = new SparkConf().setMaster("local[4]").setAppName(this.getClass.getSimpleName)
- sc = new SparkContext(sparkConf)
+ test("basic usage") {
+ val topic = "topicbasic"
+ createTopic(topic)
+ val messages = Set("the", "quick", "brown", "fox")
+ sendMessages(topic, messages.toArray)
+
+
+ val kafkaParams = Map("metadata.broker.list" -> brokerAddress,
+ "group.id" -> s"test-consumer-${Random.nextInt(10000)}")
+
+ val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
+
+ val rdd = KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder](
+ sc, kafkaParams, offsetRanges)
+
+ val received = rdd.map(_._2).collect.toSet
+ assert(received === messages)
+ }
+
+ test("iterator boundary conditions") {
+ // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
createTopic(topic)
- produceAndSendMessage(topic, sent)
- val kafkaParams = Map("metadata.broker.list" -> s"localhost:$brokerPort",
+ val kafkaParams = Map("metadata.broker.list" -> brokerAddress,
"group.id" -> s"test-consumer-${Random.nextInt(10000)}")
val kc = new KafkaCluster(kafkaParams)
- val rdd = getRdd(kc, Set(topic))
// this is the "lots of messages" case
- // make sure we get all of them
+ sendMessages(topic, sent)
+ // rdd defined from leaders after sending messages, should get the number sent
+ val rdd = getRdd(kc, Set(topic))
+
assert(rdd.isDefined)
- assert(rdd.get.count === sent.values.sum)
+ assert(rdd.get.count === sent.values.sum, "didn't get all sent messages")
- kc.setConsumerOffsets(
- kafkaParams("group.id"),
- rdd.get.offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap)
+ val ranges = rdd.get.asInstanceOf[HasOffsetRanges]
+ .offsetRanges.map(o => TopicAndPartition(o.topic, o.partition) -> o.untilOffset).toMap
+
+ kc.setConsumerOffsets(kafkaParams("group.id"), ranges)
- val rdd2 = getRdd(kc, Set(topic))
- val sent2 = Map("d" -> 1)
- produceAndSendMessage(topic, sent2)
// this is the "0 messages" case
- // make sure we dont get anything, since messages were sent after rdd was defined
+ val rdd2 = getRdd(kc, Set(topic))
+ // shouldn't get anything, since message is sent after rdd was defined
+ val sentOnlyOne = Map("d" -> 1)
+
+ sendMessages(topic, sentOnlyOne)
assert(rdd2.isDefined)
- assert(rdd2.get.count === 0)
+ assert(rdd2.get.count === 0, "got messages when there shouldn't be any")
+ // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above
val rdd3 = getRdd(kc, Set(topic))
- produceAndSendMessage(topic, Map("extra" -> 22))
- // this is the "exactly 1 message" case
- // make sure we get exactly one message, despite there being lots more available
+ // send lots of messages after rdd was defined, they shouldn't show up
+ sendMessages(topic, Map("extra" -> 22))
+
assert(rdd3.isDefined)
- assert(rdd3.get.count === sent2.values.sum)
+ assert(rdd3.get.count === sentOnlyOne.values.sum, "didn't get exactly one message")
}
// get an rdd from the committed consumer offsets until the latest leader offsets,
private def getRdd(kc: KafkaCluster, topics: Set[String]) = {
val groupId = kc.kafkaParams("group.id")
- for {
- topicPartitions <- kc.getPartitions(topics).right.toOption
- from <- kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse(
+ def consumerOffsets(topicPartitions: Set[TopicAndPartition]) = {
+ kc.getConsumerOffsets(groupId, topicPartitions).right.toOption.orElse(
kc.getEarliestLeaderOffsets(topicPartitions).right.toOption.map { offs =>
offs.map(kv => kv._1 -> kv._2.offset)
}
)
- until <- kc.getLatestLeaderOffsets(topicPartitions).right.toOption
- } yield {
- KafkaRDD[String, String, StringDecoder, StringDecoder, String](
- sc, kc.kafkaParams, from, until, mmd => s"${mmd.offset} ${mmd.message}")
+ }
+ kc.getPartitions(topics).right.toOption.flatMap { topicPartitions =>
+ consumerOffsets(topicPartitions).flatMap { from =>
+ kc.getLatestLeaderOffsets(topicPartitions).right.toOption.map { until =>
+ val offsetRanges = from.map { case (tp: TopicAndPartition, fromOffset: Long) =>
+ OffsetRange(tp.topic, tp.partition, fromOffset, until(tp).offset)
+ }.toArray
+
+ val leaders = until.map { case (tp: TopicAndPartition, lo: KafkaCluster.LeaderOffset) =>
+ tp -> Broker(lo.host, lo.port)
+ }.toMap
+
+ KafkaUtils.createRDD[String, String, StringDecoder, StringDecoder, String](
+ sc, kc.kafkaParams, offsetRanges, leaders,
+ (mmd: MessageAndMetadata[String, String]) => s"${mmd.offset} ${mmd.message}")
+ }
+ }
}
}
}
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
index f207dc6d4fa04..e4966eebb9b34 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaStreamSuite.scala
@@ -48,30 +48,41 @@ import org.apache.spark.util.Utils
*/
abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Logging {
- var zkAddress: String = _
- var zkClient: ZkClient = _
-
private val zkHost = "localhost"
+ private var zkPort: Int = 0
private val zkConnectionTimeout = 6000
private val zkSessionTimeout = 6000
private var zookeeper: EmbeddedZookeeper = _
- private var zkPort: Int = 0
- protected var brokerPort = 9092
+ private val brokerHost = "localhost"
+ private var brokerPort = 9092
private var brokerConf: KafkaConfig = _
private var server: KafkaServer = _
private var producer: Producer[String, String] = _
+ private var zkReady = false
+ private var brokerReady = false
+
+ protected var zkClient: ZkClient = _
+
+ def zkAddress: String = {
+ assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address")
+ s"$zkHost:$zkPort"
+ }
+
+ def brokerAddress: String = {
+ assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address")
+ s"$brokerHost:$brokerPort"
+ }
def setupKafka() {
// Zookeeper server startup
zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
// Get the actual zookeeper binding port
zkPort = zookeeper.actualPort
- zkAddress = s"$zkHost:$zkPort"
- logInfo("==================== 0 ====================")
+ zkReady = true
+ logInfo("==================== Zookeeper Started ====================")
- zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout,
- ZKStringSerializer)
- logInfo("==================== 1 ====================")
+ zkClient = new ZkClient(zkAddress, zkSessionTimeout, zkConnectionTimeout, ZKStringSerializer)
+ logInfo("==================== Zookeeper Client Created ====================")
// Kafka broker startup
var bindSuccess: Boolean = false
@@ -80,9 +91,8 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
val brokerProps = getBrokerConfig()
brokerConf = new KafkaConfig(brokerProps)
server = new KafkaServer(brokerConf)
- logInfo("==================== 2 ====================")
server.startup()
- logInfo("==================== 3 ====================")
+ logInfo("==================== Kafka Broker Started ====================")
bindSuccess = true
} catch {
case e: KafkaException =>
@@ -94,10 +104,13 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
}
Thread.sleep(2000)
- logInfo("==================== 4 ====================")
+ logInfo("==================== Kafka + Zookeeper Ready ====================")
+ brokerReady = true
}
def tearDownKafka() {
+ brokerReady = false
+ zkReady = false
if (producer != null) {
producer.close()
producer = null
@@ -121,26 +134,23 @@ abstract class KafkaStreamSuiteBase extends FunSuite with Eventually with Loggin
}
}
- private def createTestMessage(topic: String, sent: Map[String, Int])
- : Seq[KeyedMessage[String, String]] = {
- val messages = for ((s, freq) <- sent; i <- 0 until freq) yield {
- new KeyedMessage[String, String](topic, s)
- }
- messages.toSeq
- }
-
def createTopic(topic: String) {
AdminUtils.createTopic(zkClient, topic, 1, 1)
- logInfo("==================== 5 ====================")
// wait until metadata is propagated
waitUntilMetadataIsPropagated(topic, 0)
+ logInfo(s"==================== Topic $topic Created ====================")
}
- def produceAndSendMessage(topic: String, sent: Map[String, Int]) {
+ def sendMessages(topic: String, messageToFreq: Map[String, Int]) {
+ val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
+ sendMessages(topic, messages)
+ }
+
+ def sendMessages(topic: String, messages: Array[String]) {
producer = new Producer[String, String](new ProducerConfig(getProducerConfig()))
- producer.send(createTestMessage(topic, sent): _*)
+ producer.send(messages.map { new KeyedMessage[String, String](topic, _) }: _*)
producer.close()
- logInfo("==================== 6 ====================")
+ logInfo(s"==================== Sent Messages: ${messages.mkString(", ")} ====================")
}
private def getBrokerConfig(): Properties = {
@@ -218,7 +228,7 @@ class KafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter {
val topic = "topic1"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
createTopic(topic)
- produceAndSendMessage(topic, sent)
+ sendMessages(topic, sent)
val kafkaParams = Map("zookeeper.connect" -> zkAddress,
"group.id" -> s"test-consumer-${Random.nextInt(10000)}",
diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
index 64ccc92c81fa9..fc53c23abda85 100644
--- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
+++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/ReliableKafkaStreamSuite.scala
@@ -79,7 +79,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
test("Reliable Kafka input stream with single topic") {
var topic = "test-topic"
createTopic(topic)
- produceAndSendMessage(topic, data)
+ sendMessages(topic, data)
// Verify whether the offset of this group/topic/partition is 0 before starting.
assert(getCommitOffset(groupId, topic, 0) === None)
@@ -111,7 +111,7 @@ class ReliableKafkaStreamSuite extends KafkaStreamSuiteBase with BeforeAndAfter
val topics = Map("topic1" -> 1, "topic2" -> 1, "topic3" -> 1)
topics.foreach { case (t, _) =>
createTopic(t)
- produceAndSendMessage(t, data)
+ sendMessages(t, data)
}
// Before started, verify all the group/topic/partition offsets are 0.
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
index 1ef91dd49284f..3c0ef94cb0fab 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTInputDStream.scala
@@ -17,23 +17,23 @@
package org.apache.spark.streaming.mqtt
+import java.io.IOException
+import java.util.concurrent.Executors
+import java.util.Properties
+
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.HashMap
-import scala.collection.JavaConversions._
import scala.reflect.ClassTag
-import java.util.Properties
-import java.util.concurrent.Executors
-import java.io.IOException
-
+import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
import org.eclipse.paho.client.mqttv3.MqttCallback
import org.eclipse.paho.client.mqttv3.MqttClient
import org.eclipse.paho.client.mqttv3.MqttClientPersistence
-import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
-import org.eclipse.paho.client.mqttv3.IMqttDeliveryToken
import org.eclipse.paho.client.mqttv3.MqttException
import org.eclipse.paho.client.mqttv3.MqttMessage
import org.eclipse.paho.client.mqttv3.MqttTopic
+import org.eclipse.paho.client.mqttv3.persist.MemoryPersistence
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
@@ -82,18 +82,18 @@ class MQTTReceiver(
val client = new MqttClient(brokerUrl, MqttClient.generateClientId(), persistence)
// Callback automatically triggers as and when new message arrives on specified topic
- val callback: MqttCallback = new MqttCallback() {
+ val callback = new MqttCallback() {
// Handles Mqtt message
- override def messageArrived(arg0: String, arg1: MqttMessage) {
- store(new String(arg1.getPayload(),"utf-8"))
+ override def messageArrived(topic: String, message: MqttMessage) {
+ store(new String(message.getPayload(),"utf-8"))
}
- override def deliveryComplete(arg0: IMqttDeliveryToken) {
+ override def deliveryComplete(token: IMqttDeliveryToken) {
}
- override def connectionLost(arg0: Throwable) {
- restart("Connection lost ", arg0)
+ override def connectionLost(cause: Throwable) {
+ restart("Connection lost ", cause)
}
}
diff --git a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
index c5ffe51f9986c..1142d0f56ba34 100644
--- a/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
+++ b/external/mqtt/src/main/scala/org/apache/spark/streaming/mqtt/MQTTUtils.scala
@@ -17,10 +17,11 @@
package org.apache.spark.streaming.mqtt
+import scala.reflect.ClassTag
+
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaStreamingContext, JavaDStream}
-import scala.reflect.ClassTag
import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream}
object MQTTUtils {
diff --git a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 1e24da7f5f60c..cfedb5a042a35 100644
--- a/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/external/mqtt/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -31,7 +31,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
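
The three `LocalJavaStreamingContext` fixtures touched by this patch change only the value of `spark.streaming.clock`, since `ManualClock` moved from `org.apache.spark.streaming.util` to `org.apache.spark.util`. For reference, a hedged Scala equivalent of the same test configuration (batch interval and checkpoint directory copied from the Java fixture):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Duration, StreamingContext}

    // Point the streaming scheduler at the relocated ManualClock so tests can
    // advance batch time manually instead of waiting on the wall clock.
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("test")
      .set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
    val ssc = new StreamingContext(conf, new Duration(1000))
    ssc.checkpoint("checkpoint")
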
diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
index e84adc088a680..0f3298af6234a 100644
--- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
+++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTStreamSuite.scala
@@ -42,8 +42,8 @@ import org.apache.spark.util.Utils
class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
private val batchDuration = Milliseconds(500)
- private val master: String = "local[2]"
- private val framework: String = this.getClass.getSimpleName
+ private val master = "local[2]"
+ private val framework = this.getClass.getSimpleName
private val freePort = findFreePort()
private val brokerUri = "//localhost:" + freePort
private val topic = "def"
@@ -69,7 +69,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
test("mqtt input stream") {
val sendMessage = "MQTT demo for spark streaming"
- val receiveStream: ReceiverInputDStream[String] =
+ val receiveStream =
MQTTUtils.createStream(ssc, "tcp:" + brokerUri, topic, StorageLevel.MEMORY_ONLY)
@volatile var receiveMessage: List[String] = List()
receiveStream.foreachRDD { rdd =>
@@ -93,6 +93,7 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
private def setupMQTT() {
broker = new BrokerService()
+ broker.setDataDirectoryFile(Utils.createTempDir())
connector = new TransportConnector()
connector.setName("mqtt")
connector.setUri(new URI("mqtt:" + brokerUri))
@@ -122,12 +123,12 @@ class MQTTStreamSuite extends FunSuite with Eventually with BeforeAndAfter {
def publishData(data: String): Unit = {
var client: MqttClient = null
try {
- val persistence: MqttClientPersistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
+ val persistence = new MqttDefaultFilePersistence(persistenceDir.getAbsolutePath)
client = new MqttClient("tcp:" + brokerUri, MqttClient.generateClientId(), persistence)
client.connect()
if (client.isConnected) {
- val msgTopic: MqttTopic = client.getTopic(topic)
- val message: MqttMessage = new MqttMessage(data.getBytes("utf-8"))
+ val msgTopic = client.getTopic(topic)
+ val message = new MqttMessage(data.getBytes("utf-8"))
message.setQos(1)
message.setRetained(true)
diff --git a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 1e24da7f5f60c..cfedb5a042a35 100644
--- a/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/external/twitter/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -31,7 +31,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 1e24da7f5f60c..cfedb5a042a35 100644
--- a/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/external/zeromq/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -31,7 +31,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml
index c815eda52bda7..216661b8bc73a 100644
--- a/extras/kinesis-asl/pom.xml
+++ b/extras/kinesis-asl/pom.xml
@@ -67,11 +67,6 @@
       <artifactId>scalacheck_${scala.binary.version}</artifactId>
       <scope>test</scope>
     </dependency>
-    <dependency>
-      <groupId>org.easymock</groupId>
-      <artifactId>easymockclassextension</artifactId>
-      <scope>test</scope>
-    </dependency>
     <dependency>
       <groupId>com.novocode</groupId>
       <artifactId>junit-interface</artifactId>
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
index 0b80b611cdce7..588e86a1887ec 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisCheckpointState.scala
@@ -18,9 +18,7 @@ package org.apache.spark.streaming.kinesis
import org.apache.spark.Logging
import org.apache.spark.streaming.Duration
-import org.apache.spark.streaming.util.Clock
-import org.apache.spark.streaming.util.ManualClock
-import org.apache.spark.streaming.util.SystemClock
+import org.apache.spark.util.{Clock, ManualClock, SystemClock}
/**
* This is a helper class for managing checkpoint clocks.
@@ -35,7 +33,7 @@ private[kinesis] class KinesisCheckpointState(
/* Initialize the checkpoint clock using the given currentClock + checkpointInterval millis */
val checkpointClock = new ManualClock()
- checkpointClock.setTime(currentClock.currentTime() + checkpointInterval.milliseconds)
+ checkpointClock.setTime(currentClock.getTimeMillis() + checkpointInterval.milliseconds)
/**
* Check if it's time to checkpoint based on the current time and the derived time
@@ -44,13 +42,13 @@ private[kinesis] class KinesisCheckpointState(
* @return true if it's time to checkpoint
*/
def shouldCheckpoint(): Boolean = {
- new SystemClock().currentTime() > checkpointClock.currentTime()
+ new SystemClock().getTimeMillis() > checkpointClock.getTimeMillis()
}
/**
* Advance the checkpoint clock by the checkpoint interval.
*/
def advanceCheckpoint() = {
- checkpointClock.addToTime(checkpointInterval.milliseconds)
+ checkpointClock.advance(checkpointInterval.milliseconds)
}
}
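
The hunk above follows the consolidated clock API in `org.apache.spark.util`: `currentTime()` becomes `getTimeMillis()` and `addToTime()` becomes `advance()`. A small hedged sketch of the renamed calls as `KinesisCheckpointState` uses them; note that `ManualClock` is package-private to Spark, so this only compiles from inside an `org.apache.spark` package:

    import org.apache.spark.util.ManualClock

    val clock = new ManualClock()
    clock.setTime(1000L)                    // seed the clock, as the constructor above does
    clock.advance(500L)                     // formerly addToTime()
    assert(clock.getTimeMillis() == 1500L)  // formerly currentTime()
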
diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
index 8ecc2d90160b1..af8cd875b4541 100644
--- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
+++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala
@@ -104,7 +104,7 @@ private[kinesis] class KinesisRecordProcessor(
logDebug(s"Checkpoint: WorkerId $workerId completed checkpoint of ${batch.size}" +
s" records for shardId $shardId")
logDebug(s"Checkpoint: Next checkpoint is at " +
- s" ${checkpointState.checkpointClock.currentTime()} for shardId $shardId")
+ s" ${checkpointState.checkpointClock.getTimeMillis()} for shardId $shardId")
}
} catch {
case e: Throwable => {
diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
index 41dbd64c2b1fa..255fe65819608 100644
--- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
+++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisReceiverSuite.scala
@@ -20,17 +20,17 @@ import java.nio.ByteBuffer
import scala.collection.JavaConversions.seqAsJavaList
-import org.apache.spark.annotation.Experimental
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.Milliseconds
import org.apache.spark.streaming.Seconds
import org.apache.spark.streaming.StreamingContext
import org.apache.spark.streaming.TestSuiteBase
-import org.apache.spark.streaming.util.Clock
-import org.apache.spark.streaming.util.ManualClock
+import org.apache.spark.util.{ManualClock, Clock}
+
+import org.mockito.Mockito._
import org.scalatest.BeforeAndAfter
import org.scalatest.Matchers
-import org.scalatest.mock.EasyMockSugar
+import org.scalatest.mock.MockitoSugar
import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException
import com.amazonaws.services.kinesis.clientlibrary.exceptions.KinesisClientLibDependencyException
@@ -42,10 +42,10 @@ import com.amazonaws.services.kinesis.clientlibrary.types.ShutdownReason
import com.amazonaws.services.kinesis.model.Record
/**
- * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor
+ * Suite of Kinesis streaming receiver tests focusing mostly on the KinesisRecordProcessor
*/
class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAfter
- with EasyMockSugar {
+ with MockitoSugar {
val app = "TestKinesisReceiver"
val stream = "mySparkStream"
@@ -73,6 +73,14 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
currentClockMock = mock[Clock]
}
+ override def afterFunction(): Unit = {
+ super.afterFunction()
+ // Since this suite was originally written using EasyMock, add this to preserve the old
+ // mocking semantics (see SPARK-5735 for more details)
+ verifyNoMoreInteractions(receiverMock, checkpointerMock, checkpointClockMock,
+ checkpointStateMock, currentClockMock)
+ }
+
test("kinesis utils api") {
val ssc = new StreamingContext(master, framework, batchDuration)
// Tests the API, does not actually test data receiving
@@ -83,193 +91,175 @@ class KinesisReceiverSuite extends TestSuiteBase with Matchers with BeforeAndAft
}
test("process records including store and checkpoint") {
- val expectedCheckpointIntervalMillis = 10
- expecting {
- receiverMock.isStopped().andReturn(false).once()
- receiverMock.store(record1.getData().array()).once()
- receiverMock.store(record2.getData().array()).once()
- checkpointStateMock.shouldCheckpoint().andReturn(true).once()
- checkpointerMock.checkpoint().once()
- checkpointStateMock.advanceCheckpoint().once()
- }
- whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId,
- checkpointStateMock)
- recordProcessor.processRecords(batch, checkpointerMock)
- }
+ when(receiverMock.isStopped()).thenReturn(false)
+ when(checkpointStateMock.shouldCheckpoint()).thenReturn(true)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.processRecords(batch, checkpointerMock)
+
+ verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, times(1)).store(record1.getData().array())
+ verify(receiverMock, times(1)).store(record2.getData().array())
+ verify(checkpointStateMock, times(1)).shouldCheckpoint()
+ verify(checkpointerMock, times(1)).checkpoint()
+ verify(checkpointStateMock, times(1)).advanceCheckpoint()
}
test("shouldn't store and checkpoint when receiver is stopped") {
- expecting {
- receiverMock.isStopped().andReturn(true).once()
- }
- whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId,
- checkpointStateMock)
- recordProcessor.processRecords(batch, checkpointerMock)
- }
+ when(receiverMock.isStopped()).thenReturn(true)
+
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.processRecords(batch, checkpointerMock)
+
+ verify(receiverMock, times(1)).isStopped()
}
test("shouldn't checkpoint when exception occurs during store") {
- expecting {
- receiverMock.isStopped().andReturn(false).once()
- receiverMock.store(record1.getData().array()).andThrow(new RuntimeException()).once()
- }
- whenExecuting(receiverMock, checkpointerMock, checkpointStateMock) {
- intercept[RuntimeException] {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId,
- checkpointStateMock)
- recordProcessor.processRecords(batch, checkpointerMock)
- }
+ when(receiverMock.isStopped()).thenReturn(false)
+ when(receiverMock.store(record1.getData().array())).thenThrow(new RuntimeException())
+
+ intercept[RuntimeException] {
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.processRecords(batch, checkpointerMock)
}
+
+ verify(receiverMock, times(1)).isStopped()
+ verify(receiverMock, times(1)).store(record1.getData().array())
}
test("should set checkpoint time to currentTime + checkpoint interval upon instantiation") {
- expecting {
- currentClockMock.currentTime().andReturn(0).once()
- }
- whenExecuting(currentClockMock) {
+ when(currentClockMock.getTimeMillis()).thenReturn(0)
+
val checkpointIntervalMillis = 10
- val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
- assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis)
- }
+ val checkpointState =
+ new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
+ assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
+
+ verify(currentClockMock, times(1)).getTimeMillis()
}
test("should checkpoint if we have exceeded the checkpoint interval") {
- expecting {
- currentClockMock.currentTime().andReturn(0).once()
- }
- whenExecuting(currentClockMock) {
- val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock)
- assert(checkpointState.shouldCheckpoint())
- }
+ when(currentClockMock.getTimeMillis()).thenReturn(0)
+
+ val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MinValue), currentClockMock)
+ assert(checkpointState.shouldCheckpoint())
+
+ verify(currentClockMock, times(1)).getTimeMillis()
}
test("shouldn't checkpoint if we have not exceeded the checkpoint interval") {
- expecting {
- currentClockMock.currentTime().andReturn(0).once()
- }
- whenExecuting(currentClockMock) {
- val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock)
- assert(!checkpointState.shouldCheckpoint())
- }
+ when(currentClockMock.getTimeMillis()).thenReturn(0)
+
+ val checkpointState = new KinesisCheckpointState(Milliseconds(Long.MaxValue), currentClockMock)
+ assert(!checkpointState.shouldCheckpoint())
+
+ verify(currentClockMock, times(1)).getTimeMillis()
}
test("should add to time when advancing checkpoint") {
- expecting {
- currentClockMock.currentTime().andReturn(0).once()
- }
- whenExecuting(currentClockMock) {
- val checkpointIntervalMillis = 10
- val checkpointState = new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
- assert(checkpointState.checkpointClock.currentTime() == checkpointIntervalMillis)
- checkpointState.advanceCheckpoint()
- assert(checkpointState.checkpointClock.currentTime() == (2 * checkpointIntervalMillis))
- }
+ when(currentClockMock.getTimeMillis()).thenReturn(0)
+
+ val checkpointIntervalMillis = 10
+ val checkpointState =
+ new KinesisCheckpointState(Milliseconds(checkpointIntervalMillis), currentClockMock)
+ assert(checkpointState.checkpointClock.getTimeMillis() == checkpointIntervalMillis)
+ checkpointState.advanceCheckpoint()
+ assert(checkpointState.checkpointClock.getTimeMillis() == (2 * checkpointIntervalMillis))
+
+ verify(currentClockMock, times(1)).getTimeMillis()
}
test("shutdown should checkpoint if the reason is TERMINATE") {
- expecting {
- checkpointerMock.checkpoint().once()
- }
- whenExecuting(checkpointerMock, checkpointStateMock) {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId,
- checkpointStateMock)
- val reason = ShutdownReason.TERMINATE
- recordProcessor.shutdown(checkpointerMock, reason)
- }
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ val reason = ShutdownReason.TERMINATE
+ recordProcessor.shutdown(checkpointerMock, reason)
+
+ verify(checkpointerMock, times(1)).checkpoint()
}
test("shutdown should not checkpoint if the reason is something other than TERMINATE") {
- expecting {
- }
- whenExecuting(checkpointerMock, checkpointStateMock) {
- val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId,
- checkpointStateMock)
- recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
- recordProcessor.shutdown(checkpointerMock, null)
- }
+ val recordProcessor = new KinesisRecordProcessor(receiverMock, workerId, checkpointStateMock)
+ recordProcessor.shutdown(checkpointerMock, ShutdownReason.ZOMBIE)
+ recordProcessor.shutdown(checkpointerMock, null)
+
+ verify(checkpointerMock, never()).checkpoint()
}
test("retry success on first attempt") {
val expectedIsStopped = false
- expecting {
- receiverMock.isStopped().andReturn(expectedIsStopped).once()
- }
- whenExecuting(receiverMock) {
- val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
- assert(actualVal == expectedIsStopped)
- }
+ when(receiverMock.isStopped()).thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(1)).isStopped()
}
test("retry success on second attempt after a Kinesis throttling exception") {
val expectedIsStopped = false
- expecting {
- receiverMock.isStopped().andThrow(new ThrottlingException("error message"))
- .andReturn(expectedIsStopped).once()
- }
- whenExecuting(receiverMock) {
- val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
- assert(actualVal == expectedIsStopped)
- }
+ when(receiverMock.isStopped())
+ .thenThrow(new ThrottlingException("error message"))
+ .thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(2)).isStopped()
}
test("retry success on second attempt after a Kinesis dependency exception") {
val expectedIsStopped = false
- expecting {
- receiverMock.isStopped().andThrow(new KinesisClientLibDependencyException("error message"))
- .andReturn(expectedIsStopped).once()
- }
- whenExecuting(receiverMock) {
- val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
- assert(actualVal == expectedIsStopped)
- }
+ when(receiverMock.isStopped())
+ .thenThrow(new KinesisClientLibDependencyException("error message"))
+ .thenReturn(expectedIsStopped)
+
+ val actualVal = KinesisRecordProcessor.retryRandom(receiverMock.isStopped(), 2, 100)
+ assert(actualVal == expectedIsStopped)
+
+ verify(receiverMock, times(2)).isStopped()
}
test("retry failed after a shutdown exception") {
- expecting {
- checkpointerMock.checkpoint().andThrow(new ShutdownException("error message")).once()
- }
- whenExecuting(checkpointerMock) {
- intercept[ShutdownException] {
- KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
- }
+ when(checkpointerMock.checkpoint()).thenThrow(new ShutdownException("error message"))
+
+ intercept[ShutdownException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
}
+
+ verify(checkpointerMock, times(1)).checkpoint()
}
test("retry failed after an invalid state exception") {
- expecting {
- checkpointerMock.checkpoint().andThrow(new InvalidStateException("error message")).once()
- }
- whenExecuting(checkpointerMock) {
- intercept[InvalidStateException] {
- KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
- }
+ when(checkpointerMock.checkpoint()).thenThrow(new InvalidStateException("error message"))
+
+ intercept[InvalidStateException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
}
+
+ verify(checkpointerMock, times(1)).checkpoint()
}
test("retry failed after unexpected exception") {
- expecting {
- checkpointerMock.checkpoint().andThrow(new RuntimeException("error message")).once()
- }
- whenExecuting(checkpointerMock) {
- intercept[RuntimeException] {
- KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
- }
+ when(checkpointerMock.checkpoint()).thenThrow(new RuntimeException("error message"))
+
+ intercept[RuntimeException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
}
+
+ verify(checkpointerMock, times(1)).checkpoint()
}
test("retry failed after exhausing all retries") {
val expectedErrorMessage = "final try error message"
- expecting {
- checkpointerMock.checkpoint().andThrow(new ThrottlingException("error message"))
- .andThrow(new ThrottlingException(expectedErrorMessage)).once()
- }
- whenExecuting(checkpointerMock) {
- val exception = intercept[RuntimeException] {
- KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
- }
- exception.getMessage().shouldBe(expectedErrorMessage)
+ when(checkpointerMock.checkpoint())
+ .thenThrow(new ThrottlingException("error message"))
+ .thenThrow(new ThrottlingException(expectedErrorMessage))
+
+ val exception = intercept[RuntimeException] {
+ KinesisRecordProcessor.retryRandom(checkpointerMock.checkpoint(), 2, 100)
}
+ exception.getMessage().shouldBe(expectedErrorMessage)
+
+ verify(checkpointerMock, times(2)).checkpoint()
}
}
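
The suite above moves from EasyMock's record/replay style to Mockito's stub-then-verify style, with a blanket `verifyNoMoreInteractions` in `afterFunction()` preserving the old strict-mock semantics. A minimal hedged template of that pattern against a hypothetical collaborator, independent of the Kinesis classes:

    import org.mockito.Mockito._
    import org.scalatest.FunSuite
    import org.scalatest.mock.MockitoSugar

    trait Worker { def isStopped(): Boolean }  // hypothetical collaborator, not from the patch

    class MockitoStyleSuite extends FunSuite with MockitoSugar {
      test("stub, exercise, then verify") {
        val worker = mock[Worker]
        when(worker.isStopped()).thenReturn(false)  // stubbing replaces EasyMock's expecting { ... }
        assert(!worker.isStopped())
        verify(worker, times(1)).isStopped()        // verification replaces whenExecuting { ... }
        verifyNoMoreInteractions(worker)            // keeps the old strict-mock behaviour
      }
    }
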
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
index 6dad167fa7411..904be213147dc 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala
@@ -104,8 +104,14 @@ class VertexRDDImpl[VD] private[graphx] (
this.mapVertexPartitions(_.map(f))
override def diff(other: VertexRDD[VD]): VertexRDD[VD] = {
+ val otherPartition = other match {
+ case other: VertexRDD[_] if this.partitioner == other.partitioner =>
+ other.partitionsRDD
+ case _ =>
+ VertexRDD(other.partitionBy(this.partitioner.get)).partitionsRDD
+ }
val newPartitionsRDD = partitionsRDD.zipPartitions(
- other.partitionsRDD, preservesPartitioning = true
+ otherPartition, preservesPartitioning = true
) { (thisIter, otherIter) =>
val thisPart = thisIter.next()
val otherPart = otherIter.next()
@@ -133,7 +139,7 @@ class VertexRDDImpl[VD] private[graphx] (
// Test if the other vertex is a VertexRDD to choose the optimal join strategy.
// If the other set is a VertexRDD then we use the much more efficient leftZipJoin
other match {
- case other: VertexRDD[_] =>
+ case other: VertexRDD[_] if this.partitioner == other.partitioner =>
leftZipJoin(other)(f)
case _ =>
this.withPartitionsRDD[VD3](
@@ -162,7 +168,7 @@ class VertexRDDImpl[VD] private[graphx] (
// Test if the other vertex is a VertexRDD to choose the optimal join strategy.
// If the other set is a VertexRDD then we use the much more efficient innerZipJoin
other match {
- case other: VertexRDD[_] =>
+ case other: VertexRDD[_] if this.partitioner == other.partitioner =>
innerZipJoin(other)(f)
case _ =>
this.withPartitionsRDD(
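
With the guards above, `diff`, `leftJoin`, and `innerJoin` take the zip-join fast path only when both sides share a partitioner; otherwise `diff` repartitions the other `VertexRDD` and the joins fall back to the generic path. A hedged sketch exercising `diff` with deliberately mismatched partition counts (the vertex data and counts are made up):

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx.VertexRDD

    def example(sc: SparkContext): Unit = {
      // Different numbers of partitions mean different hash partitioners.
      val setA = VertexRDD(sc.parallelize((0L until 100L).map(id => (id, 1)), 2))
      val setB = VertexRDD(sc.parallelize((50L until 150L).map(id => (id, 2)), 4))
      // Previously this zipped partitions of RDDs with different partitioners;
      // now setB is repartitioned to setA's partitioner first.
      val changed = setA.diff(setB)
      println(changed.count())
    }
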
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
index f58587e10a820..3e4157a63fd1c 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala
@@ -37,6 +37,17 @@ object SVDPlusPlus {
var gamma7: Double)
extends Serializable
+ /**
+ * This method is now replaced by the updated version of `run()` and returns exactly
+ * the same result.
+ */
+ @deprecated("Call run()", "1.4.0")
+ def runSVDPlusPlus(edges: RDD[Edge[Double]], conf: Conf)
+ : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
+ {
+ run(edges, conf)
+ }
+
/**
* Implement SVD++ based on "Factorization Meets the Neighborhood:
* a Multifaceted Collaborative Filtering Model",
@@ -52,7 +63,7 @@ object SVDPlusPlus {
* @return a graph with vertex attributes containing the trained model
*/
def run(edges: RDD[Edge[Double]], conf: Conf)
- : (Graph[(DoubleMatrix, DoubleMatrix, Double, Double), Double], Double) =
+ : (Graph[(Array[Double], Array[Double], Double, Double), Double], Double) =
{
// Generate default vertex attribute
def defaultF(rank: Int): (DoubleMatrix, DoubleMatrix, Double, Double) = {
@@ -72,17 +83,22 @@ object SVDPlusPlus {
// construct graph
var g = Graph.fromEdges(edges, defaultF(conf.rank)).cache()
+ materialize(g)
+ edges.unpersist()
// Calculate initial bias and norm
val t0 = g.aggregateMessages[(Long, Double)](
ctx => { ctx.sendToSrc((1L, ctx.attr)); ctx.sendToDst((1L, ctx.attr)) },
(g1, g2) => (g1._1 + g2._1, g1._2 + g2._2))
- g = g.outerJoinVertices(t0) {
+ val gJoinT0 = g.outerJoinVertices(t0) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(Long, Double)]) =>
(vd._1, vd._2, msg.get._2 / msg.get._1, 1.0 / scala.math.sqrt(msg.get._1))
- }
+ }.cache()
+ materialize(gJoinT0)
+ g.unpersist()
+ g = gJoinT0
def sendMsgTrainF(conf: Conf, u: Double)
(ctx: EdgeContext[
@@ -114,12 +130,15 @@ object SVDPlusPlus {
val t1 = g.aggregateMessages[DoubleMatrix](
ctx => ctx.sendToSrc(ctx.dstAttr._2),
(g1, g2) => g1.addColumnVector(g2))
- g = g.outerJoinVertices(t1) {
+ val gJoinT1 = g.outerJoinVertices(t1) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[DoubleMatrix]) =>
if (msg.isDefined) (vd._1, vd._1
.addColumnVector(msg.get.mul(vd._4)), vd._3, vd._4) else vd
- }
+ }.cache()
+ materialize(gJoinT1)
+ g.unpersist()
+ g = gJoinT1
// Phase 2, update p for user nodes and q, y for item nodes
g.cache()
@@ -127,13 +146,16 @@ object SVDPlusPlus {
sendMsgTrainF(conf, u),
(g1: (DoubleMatrix, DoubleMatrix, Double), g2: (DoubleMatrix, DoubleMatrix, Double)) =>
(g1._1.addColumnVector(g2._1), g1._2.addColumnVector(g2._2), g1._3 + g2._3))
- g = g.outerJoinVertices(t2) {
+ val gJoinT2 = g.outerJoinVertices(t2) {
(vid: VertexId,
vd: (DoubleMatrix, DoubleMatrix, Double, Double),
msg: Option[(DoubleMatrix, DoubleMatrix, Double)]) =>
(vd._1.addColumnVector(msg.get._1), vd._2.addColumnVector(msg.get._2),
vd._3 + msg.get._3, vd._4)
- }
+ }.cache()
+ materialize(gJoinT2)
+ g.unpersist()
+ g = gJoinT2
}
// calculate error on training set
@@ -147,13 +169,28 @@ object SVDPlusPlus {
val err = (ctx.attr - pred) * (ctx.attr - pred)
ctx.sendToDst(err)
}
+
g.cache()
val t3 = g.aggregateMessages[Double](sendMsgTestF(conf, u), _ + _)
- g = g.outerJoinVertices(t3) {
+ val gJoinT3 = g.outerJoinVertices(t3) {
(vid: VertexId, vd: (DoubleMatrix, DoubleMatrix, Double, Double), msg: Option[Double]) =>
if (msg.isDefined) (vd._1, vd._2, vd._3, msg.get) else vd
- }
+ }.cache()
+ materialize(gJoinT3)
+ g.unpersist()
+ g = gJoinT3
- (g, u)
+ // Convert DoubleMatrix to Array[Double]:
+ val newVertices = g.vertices.mapValues(v => (v._1.toArray, v._2.toArray, v._3, v._4))
+ (Graph(newVertices, g.edges), u)
}
+
+ /**
+ * Forces materialization of a Graph by count()ing its RDDs.
+ */
+ private def materialize(g: Graph[_,_]): Unit = {
+ g.vertices.count()
+ g.edges.count()
+ }
+
}
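
After this change `run()` returns vertex attributes as plain `(Array[Double], Array[Double], Double, Double)` instead of exposing jblas `DoubleMatrix`, and the old behaviour stays reachable through the deprecated `runSVDPlusPlus()`. A hedged usage sketch; the edge data is made up and the `Conf` values mirror the test suite below:

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx.Edge
    import org.apache.spark.graphx.lib.SVDPlusPlus

    def train(sc: SparkContext): Unit = {
      // Bipartite ratings: even ids as users, odd ids as items (illustrative only).
      val edges = sc.parallelize(Seq(Edge(0L, 1L, 5.0), Edge(2L, 1L, 3.0), Edge(2L, 3L, 4.0)))
      val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015)
      val (model, u) = SVDPlusPlus.run(edges, conf)
      val firstFactor = model.vertices.first()._2._1  // now an Array[Double]
      println(s"u = $u, rank = ${firstFactor.length}")
    }
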
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
index 590f0474957dd..179f2843818e0 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala
@@ -61,8 +61,8 @@ object ShortestPaths {
}
def sendMessage(edge: EdgeTriplet[SPMap, _]): Iterator[(VertexId, SPMap)] = {
- val newAttr = incrementMap(edge.srcAttr)
- if (edge.dstAttr != addMaps(newAttr, edge.dstAttr)) Iterator((edge.dstId, newAttr))
+ val newAttr = incrementMap(edge.dstAttr)
+ if (edge.srcAttr != addMaps(newAttr, edge.srcAttr)) Iterator((edge.srcId, newAttr))
else Iterator.empty
}
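
The two swapped lines reverse the message direction: the distance map is now incremented at the destination and sent to the source, so each vertex ends up with the hop count to every landmark reachable by following edge direction. A hedged sanity check on a two-edge chain; the vertex ids and expected distances are my reading of the new direction, not taken from the patch:

    import org.apache.spark.SparkContext
    import org.apache.spark.graphx.{Edge, Graph}
    import org.apache.spark.graphx.lib.ShortestPaths

    def check(sc: SparkContext): Unit = {
      // Chain 1 -> 2 -> 3 with landmark 3.
      val graph = Graph.fromEdges(sc.parallelize(Seq(Edge(1L, 2L, ()), Edge(2L, 3L, ()))), ())
      val spmaps = ShortestPaths.run(graph, Seq(3L)).vertices.collect().toMap
      assert(spmaps(1L).get(3L) == Some(2))  // two hops from vertex 1 to the landmark
      assert(spmaps(2L).get(3L) == Some(1))  // one hop from vertex 2 to the landmark
    }
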
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
index e01df56e94de9..9987a4b1a3c25 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/lib/SVDPlusPlusSuite.scala
@@ -32,7 +32,7 @@ class SVDPlusPlusSuite extends FunSuite with LocalSparkContext {
Edge(fields(0).toLong * 2, fields(1).toLong * 2 + 1, fields(2).toDouble)
}
val conf = new SVDPlusPlus.Conf(10, 2, 0.0, 5.0, 0.007, 0.007, 0.005, 0.015) // 2 iterations
- var (graph, u) = SVDPlusPlus.run(edges, conf)
+ var (graph, u) = SVDPlusPlus.runSVDPlusPlus(edges, conf)
graph.cache()
val err = graph.vertices.collect().map{ case (vid, vd) =>
if (vid % 2 == 1) vd._4 else 0.0
diff --git a/make-distribution.sh b/make-distribution.sh
index 051c87c0894ae..dd990d4b96e46 100755
--- a/make-distribution.sh
+++ b/make-distribution.sh
@@ -98,9 +98,9 @@ done
if [ -z "$JAVA_HOME" ]; then
# Fall back on JAVA_HOME from rpm, if found
if [ $(command -v rpm) ]; then
- RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null)
+ RPM_JAVA_HOME="$(rpm -E %java_home 2>/dev/null)"
if [ "$RPM_JAVA_HOME" != "%java_home" ]; then
- JAVA_HOME=$RPM_JAVA_HOME
+ JAVA_HOME="$RPM_JAVA_HOME"
echo "No JAVA_HOME set, proceeding with '$JAVA_HOME' learned from rpm"
fi
fi
@@ -113,24 +113,24 @@ fi
if [ $(command -v git) ]; then
GITREV=$(git rev-parse --short HEAD 2>/dev/null || :)
- if [ ! -z $GITREV ]; then
+ if [ ! -z "$GITREV" ]; then
GITREVSTRING=" (git revision $GITREV)"
fi
unset GITREV
fi
-if [ ! $(command -v $MVN) ] ; then
+if [ ! $(command -v "$MVN") ] ; then
echo -e "Could not locate Maven command: '$MVN'."
echo -e "Specify the Maven command with the --mvn flag"
exit -1;
fi
-VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1)
-SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\
+VERSION=$("$MVN" help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1)
+SPARK_HADOOP_VERSION=$("$MVN" help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\
| grep -v "INFO"\
| tail -n 1)
-SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\
+SPARK_HIVE=$("$MVN" help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\
| grep -v "INFO"\
| fgrep --count "hive";\
# Reset exit status to 0, otherwise the script stops here if the last grep finds nothing\
@@ -147,7 +147,7 @@ if [[ ! "$JAVA_VERSION" =~ "1.6" && -z "$SKIP_JAVA_TEST" ]]; then
echo "Output from 'java -version' was:"
echo "$JAVA_VERSION"
read -p "Would you like to continue anyways? [y,n]: " -r
- if [[ ! $REPLY =~ ^[Yy]$ ]]; then
+ if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then
echo "Okay, exiting."
exit 1
fi
@@ -232,7 +232,7 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR"
if [ "$SPARK_TACHYON" == "true" ]; then
TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'`
- pushd $TMPD > /dev/null
+ pushd "$TMPD" > /dev/null
echo "Fetching tachyon tgz"
TACHYON_DL="${TACHYON_TGZ}.part"
@@ -259,7 +259,7 @@ if [ "$SPARK_TACHYON" == "true" ]; then
fi
popd > /dev/null
- rm -rf $TMPD
+ rm -rf "$TMPD"
fi
if [ "$MAKE_TGZ" == "true" ]; then
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index bb291e6e1fd7d..5bbcd2e080e07 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
abstract class PipelineStage extends Serializable with Logging {
/**
+ * :: DeveloperApi ::
+ *
* Derives the output schema from the input schema and parameters.
+ * The schema describes the columns and types of the data.
+ *
+ * @param schema Input schema to this stage
+ * @param paramMap Parameters passed to this stage
+ * @return Output schema from this stage
*/
- private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+ @DeveloperApi
+ def transformSchema(schema: StructType, paramMap: ParamMap): StructType
/**
* Derives the output schema from the input schema and parameters, optionally with logging.
@@ -114,7 +122,9 @@ class Pipeline extends Estimator[PipelineModel] {
throw new IllegalArgumentException(
s"Do not support stage $stage of type ${stage.getClass}")
}
- curDataset = transformer.transform(curDataset, paramMap)
+ if (index < indexOfLastEstimator) {
+ curDataset = transformer.transform(curDataset, paramMap)
+ }
transformers += transformer
} else {
transformers += stage.asInstanceOf[Transformer]
@@ -124,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
new PipelineModel(this, map, transformers.toArray)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
@@ -169,7 +179,7 @@ class PipelineModel private[ml] (
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index cd95c16aa768d..9a5848684b179 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
@@ -62,7 +62,10 @@ abstract class Transformer extends PipelineStage with Params {
private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
+ /** @group setParam */
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
+
+ /** @group setParam */
def setOutputCol(value: String): T = set(outputCol, value).asInstanceOf[T]
/**
@@ -97,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
- dataset.select($"*", callUDF(
- this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol),
+ callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
}
}
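
The same two substitutions recur throughout the ml classes below: `org.apache.spark.sql.Dsl._` becomes `org.apache.spark.sql.functions._`, and appending a derived column via `select($"*", expr.as(name))` becomes `withColumn(name, expr)`. A hedged standalone illustration of the new idiom; the DataFrame, column names, and prediction function are made up:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions._  // replaces the old Dsl._ import
    import org.apache.spark.sql.types.DoubleType

    // Appends one column instead of re-selecting every existing column plus an alias.
    def addPrediction(df: DataFrame, predict: Double => Double): DataFrame =
      df.withColumn("prediction", callUDF(predict, DoubleType, col("features")))
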
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 1bf8eb4640d11..c5fc89f935432 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -66,6 +66,7 @@ private[spark] abstract class Classifier[
extends Predictor[FeaturesType, E, M]
with ClassifierParams {
+ /** @group setParam */
def setRawPredictionCol(value: String): E =
set(rawPredictionCol, value).asInstanceOf[E]
@@ -87,6 +88,7 @@ private[spark]
abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
extends PredictionModel[FeaturesType, M] with ClassifierParams {
+ /** @group setParam */
def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
/** Number of classes (values which the label can take). */
@@ -180,24 +182,22 @@ private[ml] object ClassificationModel {
if (map(model.rawPredictionCol) != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT,
- col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(model.rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(model.featuresCol))))
numColsOutput += 1
if (map(model.predictionCol) != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
- tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
- col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol))))
numColsOutput += 1
}
} else if (map(model.predictionCol) != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.select($"*",
- callUDF(features2pred, DoubleType,
- col(map(model.featuresCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(features2pred, DoubleType, col(map(model.featuresCol))))
numColsOutput += 1
}
(numColsOutput, tmpData)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c146fe244c66e..21f61d80dd95a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
@@ -49,8 +49,13 @@ class LogisticRegression
setMaxIter(100)
setThreshold(0.5)
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
@@ -93,6 +98,7 @@ class LogisticRegressionModel private[ml] (
setThreshold(0.5)
+ /** @group setParam */
def setThreshold(value: Double): this.type = set(threshold, value)
private val margin: Vector => Double = (features) => {
@@ -124,44 +130,39 @@ class LogisticRegressionModel private[ml] (
var numColsOutput = 0
if (map(rawPredictionCol) != "") {
val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))))
numColsOutput += 1
}
if (map(probabilityCol) != "") {
if (map(rawPredictionCol) != "") {
- val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+ val raw2prob = udf { (rawPreds: Vector) =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
- Vectors.dense(1.0 - prob1, prob1)
+ Vectors.dense(1.0 - prob1, prob1): Vector
}
- tmpData = tmpData.select($"*",
- callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol))))
} else {
- val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
- tmpData = tmpData.select($"*",
- callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
+ tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol))))
}
numColsOutput += 1
}
if (map(predictionCol) != "") {
val t = map(threshold)
if (map(probabilityCol) != "") {
- val predict: Vector => Double = { probs: Vector =>
+ val predict = udf { probs: Vector =>
if (probs(1) > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol))))
} else if (map(rawPredictionCol) != "") {
- val predict: Vector => Double = { rawPreds: Vector =>
+ val predict = udf { rawPreds: Vector =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol))))
} else {
- val predict: Vector => Double = (features: Vector) => this.predict(features)
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ val predict = udf { features: Vector => this.predict(features) }
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol))))
}
numColsOutput += 1
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 1202528ca654e..bd8caac855981 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -61,6 +61,7 @@ private[spark] abstract class ProbabilisticClassifier[
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {
+ /** @group setParam */
def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
}
@@ -82,6 +83,7 @@ private[spark] abstract class ProbabilisticClassificationModel[
M <: ProbabilisticClassificationModel[FeaturesType, M]]
extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
+ /** @group setParam */
def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
/**
@@ -120,8 +122,8 @@ private[spark] abstract class ProbabilisticClassificationModel[
val features2probs: FeaturesType => Vector = (features) => {
tmpModel.predictProbabilities(features)
}
- outputData.select($"*",
- callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ outputData.withColumn(map(probabilityCol),
+ callUDF(features2probs, new VectorUDT, col(map(featuresCol))))
} else {
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index f21a30627e540..2360f4479f1c2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -35,13 +35,23 @@ import org.apache.spark.sql.types.DoubleType
class BinaryClassificationEvaluator extends Evaluator with Params
with HasRawPredictionCol with HasLabelCol {
- /** param for metric name in evaluation */
+ /**
+ * param for metric name in evaluation
+ * @group param
+ */
val metricName: Param[String] = new Param(this, "metricName",
"metric name in evaluation (areaUnderROC|areaUnderPR)", Some("areaUnderROC"))
+
+ /** @group getParam */
def getMetricName: String = get(metricName)
+
+ /** @group setParam */
def setMetricName(value: String): this.type = set(metricName, value)
+ /** @group setParam */
def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
+
+ /** @group setParam */
def setLabelCol(value: String): this.type = set(labelCol, value)
override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 0956062643f23..6131ba8832691 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -31,11 +31,18 @@ import org.apache.spark.sql.types.DataType
@AlphaComponent
class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
- /** number of features */
+ /**
+ * number of features
+ * @group param
+ */
val numFeatures = new IntParam(this, "numFeatures", "number of features", Some(1 << 18))
- def setNumFeatures(value: Int) = set(numFeatures, value)
+
+ /** @group getParam */
def getNumFeatures: Int = get(numFeatures)
+ /** @group setParam */
+ def setNumFeatures(value: Int) = set(numFeatures, value)
+
override protected def createTransformFunc(paramMap: ParamMap): Iterable[_] => Vector = {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
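
With the getter and setter reordered under the new `@group` tags, `HashingTF` exposes `setNumFeatures`/`getNumFeatures` alongside the `setInputCol`/`setOutputCol` it inherits from `UnaryTransformer`. A hedged configuration sketch; the column names are illustrative:

    import org.apache.spark.ml.feature.HashingTF

    val hashingTF = new HashingTF()
      .setInputCol("words")      // inherited setters, see Transformer.scala above
      .setOutputCol("features")
    hashingTF.setNumFeatures(1 << 10)  // overrides the 1 << 18 default declared on the param
    assert(hashingTF.getNumFeatures == (1 << 10))
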
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 4745a7ae95679..1142aa4f8e73d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -39,7 +39,10 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
@AlphaComponent
class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerParams {
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
@@ -52,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
model
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
@@ -75,17 +78,20 @@ class StandardScalerModel private[ml] (
scaler: feature.StandardScalerModel)
extends Model[StandardScalerModel] with StandardScalerParams {
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
- dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
val inputType = schema(map(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 89b53f3890ea3..dfb89cc8d4af3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -85,8 +85,13 @@ private[spark] abstract class Predictor[
M <: PredictionModel[FeaturesType, M]]
extends Estimator[M] with PredictorParams {
+ /** @group setParam */
def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+
+ /** @group setParam */
def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
@@ -127,7 +132,7 @@ private[spark] abstract class Predictor[
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
}
@@ -160,8 +165,10 @@ private[spark] abstract class Predictor[
private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
extends Model[M] with PredictorParams {
+ /** @group setParam */
def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+ /** @group setParam */
def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
/**
@@ -177,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
@DeveloperApi
protected def featuresDataType: DataType = new VectorUDT
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
}
@@ -209,7 +216,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
val pred: FeaturesType => Double = (features) => {
tmpModel.predict(features)
}
- dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/package.scala b/mllib/src/main/scala/org/apache/spark/ml/package.scala
index 51cd48c90432a..b45bd1499b72e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/package.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/package.scala
@@ -20,5 +20,19 @@ package org.apache.spark
/**
* Spark ML is an ALPHA component that adds a new set of machine learning APIs to let users quickly
* assemble and configure practical machine learning pipelines.
+ *
+ * @groupname param Parameters
+ * @groupdesc param A list of (hyper-)parameter keys this algorithm can take. Users can set and get
+ * the parameter values through setters and getters, respectively.
+ * @groupprio param -5
+ *
+ * @groupname setParam Parameter setters
+ * @groupprio setParam 5
+ *
+ * @groupname getParam Parameter getters
+ * @groupprio getParam 6
+ *
+ * @groupname Ungrouped Members
+ * @groupprio Ungrouped 0
*/
package object ml
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index 32fc74462ef4a..1a70322b4cace 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -24,67 +24,117 @@ package org.apache.spark.ml.param
*/
private[ml] trait HasRegParam extends Params {
- /** param for regularization parameter */
+ /**
+ * param for regularization parameter
+ * @group param
+ */
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
+
+ /** @group getParam */
def getRegParam: Double = get(regParam)
}
private[ml] trait HasMaxIter extends Params {
- /** param for max number of iterations */
+ /**
+ * param for max number of iterations
+ * @group param
+ */
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+
+ /** @group getParam */
def getMaxIter: Int = get(maxIter)
}
private[ml] trait HasFeaturesCol extends Params {
- /** param for features column name */
+ /**
+ * param for features column name
+ * @group param
+ */
val featuresCol: Param[String] =
new Param(this, "featuresCol", "features column name", Some("features"))
+
+ /** @group getParam */
def getFeaturesCol: String = get(featuresCol)
}
private[ml] trait HasLabelCol extends Params {
- /** param for label column name */
+ /**
+ * param for label column name
+ * @group param
+ */
val labelCol: Param[String] = new Param(this, "labelCol", "label column name", Some("label"))
+
+ /** @group getParam */
def getLabelCol: String = get(labelCol)
}
private[ml] trait HasPredictionCol extends Params {
- /** param for prediction column name */
+ /**
+ * param for prediction column name
+ * @group param
+ */
val predictionCol: Param[String] =
new Param(this, "predictionCol", "prediction column name", Some("prediction"))
+
+ /** @group getParam */
def getPredictionCol: String = get(predictionCol)
}
private[ml] trait HasRawPredictionCol extends Params {
- /** param for raw prediction column name */
+ /**
+ * param for raw prediction column name
+ * @group param
+ */
val rawPredictionCol: Param[String] =
new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
Some("rawPrediction"))
+
+ /** @group getParam */
def getRawPredictionCol: String = get(rawPredictionCol)
}
private[ml] trait HasProbabilityCol extends Params {
- /** param for predicted class conditional probabilities column name */
+ /**
+ * param for predicted class conditional probabilities column name
+ * @group param
+ */
val probabilityCol: Param[String] =
new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
Some("probability"))
+
+ /** @group getParam */
def getProbabilityCol: String = get(probabilityCol)
}
private[ml] trait HasThreshold extends Params {
- /** param for threshold in (binary) prediction */
+ /**
+ * param for threshold in (binary) prediction
+ * @group param
+ */
val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
+
+ /** @group getParam */
def getThreshold: Double = get(threshold)
}
private[ml] trait HasInputCol extends Params {
- /** param for input column name */
+ /**
+ * param for input column name
+ * @group param
+ */
val inputCol: Param[String] = new Param(this, "inputCol", "input column name")
+
+ /** @group getParam */
def getInputCol: String = get(inputCol)
}
private[ml] trait HasOutputCol extends Params {
- /** param for output column name */
+ /**
+ * param for output column name
+ * @group param
+ */
val outputCol: Param[String] = new Param(this, "outputCol", "output column name")
+
+ /** @group getParam */
def getOutputCol: String = get(outputCol)
}
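
For reference, a new shared parameter added to this file would follow the same three-part convention: `@group param` on the `val`, `@group getParam` on the getter. The trait below is a hypothetical sketch, not part of this patch; `HasTol`/`tol` are illustrative names.

```scala
package org.apache.spark.ml.param

// Hypothetical example following the convention above; not part of this change.
private[ml] trait HasTol extends Params {

  /**
   * param for convergence tolerance
   * @group param
   */
  val tol: DoubleParam = new DoubleParam(this, "tol", "convergence tolerance", Some(1e-4))

  /** @group getParam */
  def getTol: Double = get(tol)
}
```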
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index bf5737177ceee..7bb69df65362b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -29,14 +29,14 @@ import com.github.fommil.netlib.LAPACK.{getInstance => lapack}
import org.jblas.DoubleMatrix
import org.netlib.util.intW
-import org.apache.spark.{HashPartitioner, Logging, Partitioner}
+import org.apache.spark.{Logging, Partitioner}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -49,43 +49,89 @@ import org.apache.spark.util.random.XORShiftRandom
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
with HasPredictionCol {
- /** Param for rank of the matrix factorization. */
+ /**
+ * Param for rank of the matrix factorization.
+ * @group param
+ */
val rank = new IntParam(this, "rank", "rank of the factorization", Some(10))
+
+ /** @group getParam */
def getRank: Int = get(rank)
- /** Param for number of user blocks. */
+ /**
+ * Param for number of user blocks.
+ * @group param
+ */
val numUserBlocks = new IntParam(this, "numUserBlocks", "number of user blocks", Some(10))
+
+ /** @group getParam */
def getNumUserBlocks: Int = get(numUserBlocks)
- /** Param for number of item blocks. */
+ /**
+ * Param for number of item blocks.
+ * @group param
+ */
val numItemBlocks =
new IntParam(this, "numItemBlocks", "number of item blocks", Some(10))
+
+ /** @group getParam */
def getNumItemBlocks: Int = get(numItemBlocks)
- /** Param to decide whether to use implicit preference. */
+ /**
+ * Param to decide whether to use implicit preference.
+ * @group param
+ */
val implicitPrefs =
new BooleanParam(this, "implicitPrefs", "whether to use implicit preference", Some(false))
+
+ /** @group getParam */
def getImplicitPrefs: Boolean = get(implicitPrefs)
- /** Param for the alpha parameter in the implicit preference formulation. */
+ /**
+ * Param for the alpha parameter in the implicit preference formulation.
+ * @group param
+ */
val alpha = new DoubleParam(this, "alpha", "alpha for implicit preference", Some(1.0))
+
+ /** @group getParam */
def getAlpha: Double = get(alpha)
- /** Param for the column name for user ids. */
+ /**
+ * Param for the column name for user ids.
+ * @group param
+ */
val userCol = new Param[String](this, "userCol", "column name for user ids", Some("user"))
+
+ /** @group getParam */
def getUserCol: String = get(userCol)
- /** Param for the column name for item ids. */
+ /**
+ * Param for the column name for item ids.
+ * @group param
+ */
val itemCol =
new Param[String](this, "itemCol", "column name for item ids", Some("item"))
+
+ /** @group getParam */
def getItemCol: String = get(itemCol)
- /** Param for the column name for ratings. */
+ /**
+ * Param for the column name for ratings.
+ * @group param
+ */
val ratingCol = new Param[String](this, "ratingCol", "column name for ratings", Some("rating"))
+
+ /** @group getParam */
def getRatingCol: String = get(ratingCol)
+ /**
+ * Param for whether to apply nonnegativity constraints.
+ * @group param
+ */
val nonnegative = new BooleanParam(
this, "nonnegative", "whether to use nonnegative constraint for least squares", Some(false))
+
+ /** @group getParam */
+ def getNonnegative: Boolean = get(nonnegative)
/**
@@ -124,8 +170,8 @@ class ALSModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext.implicits._
val map = this.paramMap ++ paramMap
- val users = userFactors.toDataFrame("id", "features")
- val items = itemFactors.toDataFrame("id", "features")
+ val users = userFactors.toDF("id", "features")
+ val items = itemFactors.toDF("id", "features")
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
@@ -142,7 +188,7 @@ class ALSModel private[ml] (
.select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -181,20 +227,46 @@ class ALS extends Estimator[ALSModel] with ALSParams {
import org.apache.spark.ml.recommendation.ALS.Rating
+ /** @group setParam */
def setRank(value: Int): this.type = set(rank, value)
+
+ /** @group setParam */
def setNumUserBlocks(value: Int): this.type = set(numUserBlocks, value)
+
+ /** @group setParam */
def setNumItemBlocks(value: Int): this.type = set(numItemBlocks, value)
+
+ /** @group setParam */
def setImplicitPrefs(value: Boolean): this.type = set(implicitPrefs, value)
+
+ /** @group setParam */
def setAlpha(value: Double): this.type = set(alpha, value)
+
+ /** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
+
+ /** @group setParam */
def setItemCol(value: String): this.type = set(itemCol, value)
+
+ /** @group setParam */
def setRatingCol(value: String): this.type = set(ratingCol, value)
+
+ /** @group setParam */
def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setNonnegative(value: Boolean): this.type = set(nonnegative, value)
- /** Sets both numUserBlocks and numItemBlocks to the specific value. */
+ /**
+ * Sets both numUserBlocks and numItemBlocks to the given value.
+ * @group setParam
+ */
def setNumBlocks(value: Int): this.type = {
setNumUserBlocks(value)
setNumItemBlocks(value)
@@ -220,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
model
}
- override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
validateAndTransformSchema(schema, paramMap)
}
}
@@ -429,8 +501,8 @@ object ALS extends Logging {
require(intermediateRDDStorageLevel != StorageLevel.NONE,
"ALS is not designed to run without persisting intermediate RDDs.")
val sc = ratings.sparkContext
- val userPart = new HashPartitioner(numUserBlocks)
- val itemPart = new HashPartitioner(numItemBlocks)
+ val userPart = new ALSPartitioner(numUserBlocks)
+ val itemPart = new ALSPartitioner(numItemBlocks)
val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions)
val itemLocalIndexEncoder = new LocalIndexEncoder(itemPart.numPartitions)
val solver = if (nonnegative) new NNLSSolver else new CholeskySolver
@@ -478,13 +550,23 @@ object ALS extends Logging {
val userIdAndFactors = userInBlocks
.mapValues(_.srcIds)
.join(userFactors)
- .values
+ .mapPartitions({ items =>
+ items.flatMap { case (_, (ids, factors)) =>
+ ids.view.zip(factors)
+ }
+ // Preserve the partitioning because IDs are consistent with the partitioners in userInBlocks
+ // and userFactors.
+ }, preservesPartitioning = true)
.setName("userFactors")
.persist(finalRDDStorageLevel)
val itemIdAndFactors = itemInBlocks
.mapValues(_.srcIds)
.join(itemFactors)
- .values
+ .mapPartitions({ items =>
+ items.flatMap { case (_, (ids, factors)) =>
+ ids.view.zip(factors)
+ }
+ }, preservesPartitioning = true)
.setName("itemFactors")
.persist(finalRDDStorageLevel)
if (finalRDDStorageLevel != StorageLevel.NONE) {
@@ -497,13 +579,7 @@ object ALS extends Logging {
itemOutBlocks.unpersist()
blockRatings.unpersist()
}
- val userOutput = userIdAndFactors.flatMap { case (ids, factors) =>
- ids.view.zip(factors)
- }
- val itemOutput = itemIdAndFactors.flatMap { case (ids, factors) =>
- ids.view.zip(factors)
- }
- (userOutput, itemOutput)
+ (userIdAndFactors, itemIdAndFactors)
}
/**
@@ -923,15 +999,15 @@ object ALS extends Logging {
"Converting to local indices took " + (System.nanoTime() - start) / 1e9 + " seconds.")
val dstLocalIndices = dstIds.map(dstIdToLocalIndex.apply)
(srcBlockId, (dstBlockId, srcIds, dstLocalIndices, ratings))
- }.groupByKey(new HashPartitioner(srcPart.numPartitions))
- .mapValues { iter =>
- val builder =
- new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
- iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
- builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
- }
- builder.build().compress()
- }.setName(prefix + "InBlocks")
+ }.groupByKey(new ALSPartitioner(srcPart.numPartitions))
+ .mapValues { iter =>
+ val builder =
+ new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions))
+ iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) =>
+ builder.add(dstBlockId, srcIds, dstLocalIndices, ratings)
+ }
+ builder.build().compress()
+ }.setName(prefix + "InBlocks")
.persist(storageLevel)
val outBlocks = inBlocks.mapValues { case InBlock(srcIds, dstPtrs, dstEncodedIndices, _) =>
val encoder = new LocalIndexEncoder(dstPart.numPartitions)
@@ -992,7 +1068,7 @@ object ALS extends Logging {
(dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx))))
}
}
- val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length))
+ val merged = srcOut.groupByKey(new ALSPartitioner(dstInBlocks.partitions.length))
dstInBlocks.join(merged).mapValues {
case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) =>
val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks)
@@ -1077,4 +1153,11 @@ object ALS extends Logging {
encoded & localIndexMask
}
}
+
+ /**
+ * Partitioner used by ALS. We require that getPartition is a projection. That is, for any key k,
+ * we have getPartition(getPartition(k)) = getPartition(k). Since the default HashPartitioner
+ * satisfies this requirement, we simply use a type alias here.
+ */
+ private[recommendation] type ALSPartitioner = org.apache.spark.HashPartitioner
}
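
The `ALSPartitioner` alias relies on `HashPartitioner.getPartition` being idempotent on its own output (a partition id i in [0, numPartitions) hashes back to i for integer keys). A small standalone check of that property:

```scala
import org.apache.spark.HashPartitioner

object ALSPartitionerCheck {
  def main(args: Array[String]): Unit = {
    val part = new HashPartitioner(10)
    val isProjection = (0 until 1000).forall { k =>
      part.getPartition(part.getPartition(k)) == part.getPartition(k)
    }
    println(s"getPartition is a projection on 0 until 1000: $isProjection")  // true
  }
}
```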
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index d5a7bdafcb623..65f6627a0c351 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -44,7 +44,10 @@ class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegress
setRegParam(0.1)
setMaxIter(100)
+ /** @group setParam */
def setRegParam(value: Double): this.type = set(regParam, value)
+
+ /** @group setParam */
def setMaxIter(value: Int): this.type = set(maxIter, value)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5d51c51346665..2eb1dac56f1e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -31,22 +31,42 @@ import org.apache.spark.sql.types.StructType
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
private[ml] trait CrossValidatorParams extends Params {
- /** param for the estimator to be cross-validated */
+ /**
+ * param for the estimator to be cross-validated
+ * @group param
+ */
val estimator: Param[Estimator[_]] = new Param(this, "estimator", "estimator for selection")
+
+ /** @group getParam */
def getEstimator: Estimator[_] = get(estimator)
- /** param for estimator param maps */
+ /**
+ * param for estimator param maps
+ * @group param
+ */
val estimatorParamMaps: Param[Array[ParamMap]] =
new Param(this, "estimatorParamMaps", "param maps for the estimator")
+
+ /** @group getParam */
def getEstimatorParamMaps: Array[ParamMap] = get(estimatorParamMaps)
- /** param for the evaluator for selection */
+ /**
+ * param for the evaluator for selection
+ * @group param
+ */
val evaluator: Param[Evaluator] = new Param(this, "evaluator", "evaluator for selection")
+
+ /** @group getParam */
def getEvaluator: Evaluator = get(evaluator)
- /** param for number of folds for cross validation */
+ /**
+ * param for number of folds for cross validation
+ * @group param
+ */
val numFolds: IntParam =
new IntParam(this, "numFolds", "number of folds for cross validation", Some(3))
+
+ /** @group getParam */
def getNumFolds: Int = get(numFolds)
}
@@ -59,9 +79,16 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
private val f2jBLAS = new F2jBLAS
+ /** @group setParam */
def setEstimator(value: Estimator[_]): this.type = set(estimator, value)
+
+ /** @group setParam */
def setEstimatorParamMaps(value: Array[ParamMap]): this.type = set(estimatorParamMaps, value)
+
+ /** @group setParam */
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
+
+ /** @group setParam */
def setNumFolds(value: Int): this.type = set(numFolds, value)
override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
@@ -76,11 +103,12 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val metrics = new Array[Double](epm.size)
val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
- val trainingDataset = sqlCtx.applySchema(training, schema).cache()
- val validationDataset = sqlCtx.applySchema(validation, schema).cache()
+ val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
+ val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ trainingDataset.unpersist()
var i = 0
while (i < numModels) {
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
@@ -88,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
metrics(i) += metric
i += 1
}
+ validationDataset.unpersist()
}
f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
@@ -100,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
cvModel
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
val map = this.paramMap ++ paramMap
map(estimator).transformSchema(schema, paramMap)
}
@@ -121,7 +150,7 @@ class CrossValidatorModel private[ml] (
bestModel.transform(dataset, paramMap)
}
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
bestModel.transformSchema(schema, paramMap)
}
}
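
The CrossValidator change above caches each fold's training and validation data and unpersists them as soon as they are no longer needed. Here is a hedged sketch of the same pattern applied directly to an RDD via `MLUtils.kFold`; `evaluate` is an illustrative placeholder for fitting and scoring a model.

```scala
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD

def crossValidate(data: RDD[Double], numFolds: Int)
    (evaluate: (RDD[Double], RDD[Double]) => Double): Double = {
  val metrics = MLUtils.kFold(data, numFolds, 0).map { case (training, validation) =>
    val train = training.cache()
    val valid = validation.cache()
    val metric = evaluate(train, valid)  // fit on train, score on valid
    train.unpersist()                    // release the fold as soon as fitting is done
    valid.unpersist()
    metric
  }
  metrics.sum / numFolds
}
```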
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
index 348c1e8760a66..35a0db76f3a8c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/ClassificationModel.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.classification
+import org.json4s.{DefaultFormats, JValue}
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
/**
* :: Experimental ::
@@ -60,16 +60,10 @@ private[mllib] object ClassificationModel {
/**
* Helper method for loading GLM classification model metadata.
- *
- * @param modelClass String name for model class (used for error messages)
* @return (numFeatures, numClasses)
*/
- def getNumFeaturesClasses(metadata: DataFrame, modelClass: String, path: String): (Int, Int) = {
- metadata.select("numFeatures", "numClasses").take(1)(0) match {
- case Row(nFeatures: Int, nClasses: Int) => (nFeatures, nClasses)
- case _ => throw new Exception(s"$modelClass unable to load" +
- s" numFeatures, numClasses from metadata: ${Loader.metadataPath(path)}")
- }
+ def getNumFeaturesClasses(metadata: JValue): (Int, Int) = {
+ implicit val formats = DefaultFormats
+ ((metadata \ "numFeatures").extract[Int], (metadata \ "numClasses").extract[Int])
}
-
}
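
`getNumFeaturesClasses` now works on the parsed json4s metadata instead of a DataFrame row. A standalone json4s sketch of the same extraction; the JSON literal is illustrative.

```scala
import org.json4s.{DefaultFormats, JValue}
import org.json4s.jackson.JsonMethods.parse

object MetadataExtractSketch {
  def main(args: Array[String]): Unit = {
    implicit val formats = DefaultFormats
    val metadata: JValue =
      parse("""{"class":"...","version":"1.0","numFeatures":692,"numClasses":2}""")
    val numFeatures = (metadata \ "numFeatures").extract[Int]  // 692
    val numClasses = (metadata \ "numClasses").extract[Int]    // 2
    println((numFeatures, numClasses))
  }
}
```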
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 9a391bfff76a3..b787667b018e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -173,8 +173,7 @@ object LogisticRegressionModel extends Loader[LogisticRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.classification.LogisticRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
// numFeatures, numClasses, weights are checked in model initialization
val model =
@@ -356,6 +355,10 @@ class LogisticRegressionWithLBFGS
}
override protected def createModel(weights: Vector, intercept: Double) = {
- new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ if (numOfLinearPredictor == 1) {
+ new LogisticRegressionModel(weights, intercept)
+ } else {
+ new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
+ }
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d9ce2822dd391..b11fd4f128c56 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -18,15 +18,16 @@
package org.apache.spark.mllib.classification
import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
-import org.apache.spark.{SparkContext, SparkException, Logging}
+import org.apache.spark.{Logging, SparkContext, SparkException}
import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, SQLContext}
-
/**
* Model for Naive Bayes Classifiers.
*
@@ -78,7 +79,7 @@ class NaiveBayesModel private[mllib] (
object NaiveBayesModel extends Loader[NaiveBayesModel] {
- import Loader._
+ import org.apache.spark.mllib.util.Loader._
private object SaveLoadV1_0 {
@@ -95,13 +96,13 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((thisClassName, thisFormatVersion, data.theta(0).size, data.pi.size)), 1)
- .toDataFrame("class", "version", "numFeatures", "numClasses")
- metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> data.theta(0).length) ~ ("numClasses" -> data.pi.length)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
dataRDD.saveAsParquetFile(dataPath(path))
}
@@ -126,8 +127,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val model = SaveLoadV1_0.load(sc, path)
assert(model.pi.size == numClasses,
s"NaiveBayesModel.load expected $numClasses classes," +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 24d31e62ba500..cfc7f868a02f0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -23,10 +23,9 @@ import org.apache.spark.mllib.classification.impl.GLMClassificationModel
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader}
+import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
import org.apache.spark.rdd.RDD
-
/**
* Model for Support Vector Machines (SVMs).
*
@@ -97,8 +96,7 @@ object SVMModel extends Loader[SVMModel] {
val classNameV1_0 = "org.apache.spark.mllib.classification.SVMModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val (numFeatures, numClasses) =
- ClassificationModel.getNumFeaturesClasses(metadata, classNameV1_0, path)
+ val (numFeatures, numClasses) = ClassificationModel.getNumFeaturesClasses(metadata)
val data = GLMClassificationModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0)
val model = new SVMModel(data.weights, data.intercept)
assert(model.weights.size == numFeatures, s"SVMModel.load with numFeatures=$numFeatures" +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 8d600572ed7f3..8956189ff1158 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -17,10 +17,13 @@
package org.apache.spark.mllib.classification.impl
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{Row, SQLContext}
/**
* Helper class for import/export of GLM classification models.
@@ -52,16 +55,14 @@ private[classification] object GLMClassificationModel {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((modelClass, thisFormatVersion, numFeatures, numClasses)), 1)
- .toDataFrame("class", "version", "numFeatures", "numClasses")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> numFeatures) ~ ("numClasses" -> numClasses)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val data = Data(weights, intercept, threshold)
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
- // TODO: repartition with 1 partition after SPARK-5532 gets fixed
- dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ sc.parallelize(Seq(data), 1).toDF().saveAsParquetFile(Loader.dataPath(path))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
index 0be3014de862e..568b65305649f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.clustering
import scala.collection.mutable.IndexedSeq
-import breeze.linalg.{DenseMatrix => BreezeMatrix, DenseVector => BreezeVector, Transpose, diag}
+import breeze.linalg.{diag, DenseMatrix => BreezeMatrix, DenseVector => BDV, Vector => BV}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, Matrices, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, Matrices, Vector, Vectors}
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -41,7 +41,11 @@ import org.apache.spark.util.Utils
* less than convergenceTol, or until it has reached the max number of iterations.
* While this process is generally guaranteed to converge, it is not guaranteed
* to find a global optimum.
- *
+ *
+ * Note: For high-dimensional data (with many features), this algorithm may perform poorly.
+ * This is due to high-dimensional data (a) making it difficult to cluster at all (based on
+ * statistical/theoretical arguments) and (b) causing numerical issues with Gaussian distributions.
+ *
* @param k The number of independent Gaussians in the mixture model
* @param convergenceTol The maximum change in log-likelihood at which convergence
* is considered to have occurred.
@@ -130,7 +134,7 @@ class GaussianMixture private (
val sc = data.sparkContext
// we will operate on the data as breeze data
- val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()
+ val breezeData = data.map(_.toBreeze).cache()
// Get length of the input vectors
val d = breezeData.first().length
@@ -148,7 +152,7 @@ class GaussianMixture private (
(Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
val slice = samples.view(i * nSamples, (i + 1) * nSamples)
new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
- })
+ })
}
}
@@ -169,7 +173,7 @@ class GaussianMixture private (
var i = 0
while (i < k) {
val mu = sums.means(i) / sums.weights(i)
- BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu).asInstanceOf[DenseVector],
+ BLAS.syr(-sums.weights(i), Vectors.fromBreeze(mu),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
weights(i) = sums.weights(i) / sumWeights
gaussians(i) = new MultivariateGaussian(mu, sums.sigmas(i) / sums.weights(i))
@@ -185,8 +189,8 @@ class GaussianMixture private (
}
/** Average of dense breeze vectors */
- private def vectorMean(x: IndexedSeq[BreezeVector[Double]]): BreezeVector[Double] = {
- val v = BreezeVector.zeros[Double](x(0).length)
+ private def vectorMean(x: IndexedSeq[BV[Double]]): BDV[Double] = {
+ val v = BDV.zeros[Double](x(0).length)
x.foreach(xi => v += xi)
v / x.length.toDouble
}
@@ -195,10 +199,10 @@ class GaussianMixture private (
* Construct matrix where diagonal entries are element-wise
* variance of input vectors (computes biased variance)
*/
- private def initCovariance(x: IndexedSeq[BreezeVector[Double]]): BreezeMatrix[Double] = {
+ private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = {
val mu = vectorMean(x)
- val ss = BreezeVector.zeros[Double](x(0).length)
- x.map(xi => (xi - mu) :^ 2.0).foreach(u => ss += u)
+ val ss = BDV.zeros[Double](x(0).length)
+ x.foreach(xi => ss += (xi - mu) :^ 2.0)
diag(ss / x.length.toDouble)
}
}
@@ -207,7 +211,7 @@ class GaussianMixture private (
private object ExpectationSum {
def zero(k: Int, d: Int): ExpectationSum = {
new ExpectationSum(0.0, Array.fill(k)(0.0),
- Array.fill(k)(BreezeVector.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
+ Array.fill(k)(BDV.zeros(d)), Array.fill(k)(BreezeMatrix.zeros(d,d)))
}
// compute cluster contributions for each input point
@@ -215,19 +219,18 @@ private object ExpectationSum {
def add(
weights: Array[Double],
dists: Array[MultivariateGaussian])
- (sums: ExpectationSum, x: BreezeVector[Double]): ExpectationSum = {
+ (sums: ExpectationSum, x: BV[Double]): ExpectationSum = {
val p = weights.zip(dists).map {
case (weight, dist) => MLUtils.EPSILON + weight * dist.pdf(x)
}
val pSum = p.sum
sums.logLikelihood += math.log(pSum)
- val xxt = x * new Transpose(x)
var i = 0
while (i < sums.k) {
p(i) /= pSum
sums.weights(i) += p(i)
sums.means(i) += x * p(i)
- BLAS.syr(p(i), Vectors.fromBreeze(x).asInstanceOf[DenseVector],
+ BLAS.syr(p(i), Vectors.fromBreeze(x),
Matrices.fromBreeze(sums.sigmas(i)).asInstanceOf[DenseMatrix])
i = i + 1
}
@@ -239,7 +242,7 @@ private object ExpectationSum {
private class ExpectationSum(
var logLikelihood: Double,
val weights: Array[Double],
- val means: Array[BreezeVector[Double]],
+ val means: Array[BDV[Double]],
val sigmas: Array[BreezeMatrix[Double]]) extends Serializable {
val k = weights.length
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 716cfd9e103c8..1453e4dac768e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -560,34 +560,23 @@ private[clustering] object LDA {
// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
- val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
- edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
- val random = new Random(partIndex + randomSeed)
- partEdges.map { edge =>
- // Create a random gamma_{wjk}
- (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
+ def createVertices(): RDD[(VertexId, TopicCounts)] = {
+ val verticesTMP: RDD[(VertexId, TopicCounts)] =
+ edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+ val random = new Random(partIndex + randomSeed)
+ partEdges.flatMap { edge =>
+ val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+ val sum = gamma * edge.attr
+ Seq((edge.srcId, sum), (edge.dstId, sum))
+ }
}
- }
- def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
- val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
- edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
- (sendToWhere(edge), (edge.attr, gamma))
- }
- verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
- (sum, t) => {
- brzAxpy(t._1, t._2, sum)
- sum
- },
- (sum0, sum1) => {
- sum0 += sum1
- }
- )
+ verticesTMP.reduceByKey(_ + _)
}
- val docVertices = createVertices(_.srcId)
- val termVertices = createVertices(_.dstId)
+
+ val docTermVertices = createVertices()
// Partition such that edges are grouped by document
- val graph = Graph(docVertices ++ termVertices, edges)
+ val graph = Graph(docTermVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)
new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)
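
The rewritten `createVertices` sends each edge's contribution `gamma * tokenCount` to both endpoints and sums per vertex with `reduceByKey`, replacing the earlier two-pass `aggregateByKey` over separate doc and term vertex RDDs. A plain-RDD sketch of that aggregation; the function name and the `(docId, termId, tokenCount)` triple encoding are illustrative.

```scala
import scala.util.Random

import breeze.linalg.{normalize, DenseVector => BDV}

import org.apache.spark.rdd.RDD

// Sketch: random soft topic assignments per edge, summed into per-vertex topic counts.
def initTopicCounts(
    edges: RDD[(Long, Long, Double)],
    k: Int,
    randomSeed: Long): RDD[(Long, BDV[Double])] = {
  edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
    val random = new Random(partIndex + randomSeed)
    partEdges.flatMap { case (docId, termId, tokenCount) =>
      val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
      val contribution = gamma * tokenCount
      Seq((docId, contribution), (termId, contribution))
    }
  }.reduceByKey(_ + _)
}
```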
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
index b0e991d2f2344..0a3f21ecee0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala
@@ -130,7 +130,7 @@ abstract class LDAModel private[clustering] {
/* TODO
* Compute the estimated topic distribution for each document.
- * This is often called “theta” in the literature.
+ * This is often called 'theta' in the literature.
*
* @param documents RDD of documents, which are term (word) count vectors paired with IDs.
* The term count vectors are "bags of words" with a fixed-size vocabulary
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index 3b1caf0c679ef..180023922a9b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.clustering
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.graphx._
import org.apache.spark.graphx.impl.GraphImpl
import org.apache.spark.mllib.linalg.Vectors
@@ -32,12 +33,12 @@ import org.apache.spark.util.random.XORShiftRandom
* Model produced by [[PowerIterationClustering]].
*
* @param k number of clusters
- * @param assignments an RDD of (vertexID, clusterID) pairs
+ * @param assignments an RDD of clustering [[PowerIterationClustering#Assignment]]s
*/
@Experimental
class PowerIterationClusteringModel(
val k: Int,
- val assignments: RDD[(Long, Int)]) extends Serializable
+ val assignments: RDD[PowerIterationClustering.Assignment]) extends Serializable
/**
* :: Experimental ::
@@ -115,6 +116,14 @@ class PowerIterationClustering private[clustering] (
pic(w0)
}
+ /**
+ * A Java-friendly version of [[PowerIterationClustering.run]].
+ */
+ def run(similarities: JavaRDD[(java.lang.Long, java.lang.Long, java.lang.Double)])
+ : PowerIterationClusteringModel = {
+ run(similarities.rdd.asInstanceOf[RDD[(Long, Long, Double)]])
+ }
+
/**
* Runs the PIC algorithm.
*
@@ -124,16 +133,33 @@ class PowerIterationClustering private[clustering] (
*/
private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = {
val v = powerIter(w, maxIterations)
- val assignments = kMeans(v, k)
+ val assignments = kMeans(v, k).mapPartitions({ iter =>
+ iter.map { case (id, cluster) =>
+ new Assignment(id, cluster)
+ }
+ }, preservesPartitioning = true)
new PowerIterationClusteringModel(k, assignments)
}
}
-private[clustering] object PowerIterationClustering extends Logging {
+@Experimental
+object PowerIterationClustering extends Logging {
+
+ /**
+ * :: Experimental ::
+ * Cluster assignment.
+ * @param id node id
+ * @param cluster assigned cluster id
+ */
+ @Experimental
+ class Assignment(val id: Long, val cluster: Int) extends Serializable
+
/**
* Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
*/
- def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = {
+ private[clustering]
+ def normalize(similarities: RDD[(Long, Long, Double)])
+ : Graph[Double, Double] = {
val edges = similarities.flatMap { case (i, j, s) =>
if (s < 0.0) {
throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.")
@@ -164,6 +190,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @return a graph with edges representing W and vertices representing a random vector
* with unit 1-norm
*/
+ private[clustering]
def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = {
val r = g.vertices.mapPartitionsWithIndex(
(part, iter) => {
@@ -185,6 +212,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param g a graph representing the normalized affinity matrix (W)
* @return a graph with edges representing W and vertices representing the degree vector
*/
+ private[clustering]
def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = {
val sum = g.vertices.values.sum()
val v0 = g.vertices.mapValues(_ / sum)
@@ -198,6 +226,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param maxIterations maximum number of iterations
* @return a [[VertexRDD]] representing the pseudo-eigenvector
*/
+ private[clustering]
def powerIter(
g: Graph[Double, Double],
maxIterations: Int): VertexRDD[Double] = {
@@ -237,6 +266,7 @@ private[clustering] object PowerIterationClustering extends Logging {
* @param k number of clusters
* @return a [[VertexRDD]] representing the clustering assignments
*/
+ private[clustering]
def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = {
val points = v.mapValues(x => Vectors.dense(x)).cache()
val model = new KMeans()
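
With `PowerIterationClusteringModel.assignments` now an RDD of `Assignment` objects, callers read `id` and `cluster` fields instead of destructuring tuples. A usage sketch against the 1.3-era API shown here; the similarity triplets are made up.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.clustering.PowerIterationClustering

def runPicExample(sc: SparkContext): Unit = {
  val similarities = sc.parallelize(Seq(
    (0L, 1L, 1.0), (0L, 2L, 1.0), (1L, 2L, 1.0),  // one tight group
    (3L, 4L, 1.0), (3L, 5L, 1.0), (4L, 5L, 1.0))) // another tight group
  val model = new PowerIterationClustering()
    .setK(2)
    .setMaxIterations(20)
    .run(similarities)
  model.assignments.collect().foreach { a =>
    println(s"vertex ${a.id} -> cluster ${a.cluster}")
  }
}
```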
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a3e40200bc063..59a79e5c6a4ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -21,7 +21,7 @@ import java.lang.{Iterable => JavaIterable}
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.ArrayBuilder
import com.github.fommil.netlib.BLAS.{getInstance => blas}
@@ -272,7 +272,7 @@ class Word2Vec extends Serializable with Logging {
def hasNext: Boolean = iter.hasNext
def next(): Array[Int] = {
- var sentence = new ArrayBuffer[Int]
+ val sentence = ArrayBuilder.make[Int]
var sentenceLength = 0
while (iter.hasNext && sentenceLength < MAX_SENTENCE_LENGTH) {
val word = bcVocabHash.value.get(iter.next())
@@ -283,7 +283,7 @@ class Word2Vec extends Serializable with Logging {
case None =>
}
}
- sentence.toArray
+ sentence.result()
}
}
}
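
`ArrayBuilder` accumulates primitive `Int`s without boxing and yields the final array via `result()`, which is why it replaces `ArrayBuffer` here. A tiny standalone illustration:

```scala
import scala.collection.mutable.ArrayBuilder

object ArrayBuilderSketch {
  def main(args: Array[String]): Unit = {
    val sentence = ArrayBuilder.make[Int]
    sentence += 3
    sentence += 1
    sentence += 4
    val words: Array[Int] = sentence.result()
    println(words.mkString(", "))  // 3, 1, 4
  }
}
```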
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 3168d608c9556..efa8459d3cdba 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -26,8 +26,9 @@ import scala.reflect.ClassTag
import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Experimental
-import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -35,18 +36,11 @@ import org.apache.spark.storage.StorageLevel
* :: Experimental ::
*
* Model trained by [[FPGrowth]], which holds frequent itemsets.
- * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @param freqItemsets frequent itemset, which is an RDD of [[FreqItemset]]
* @tparam Item item type
*/
@Experimental
-class FPGrowthModel[Item: ClassTag](
- val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
-
- /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
- def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
- JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
- }
-}
+class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
/**
* :: Experimental ::
@@ -151,7 +145,7 @@ class FPGrowth private (
data: RDD[Array[Item]],
minCount: Long,
freqItems: Array[Item],
- partitioner: Partitioner): RDD[(Array[Item], Long)] = {
+ partitioner: Partitioner): RDD[FreqItemset[Item]] = {
val itemToRank = freqItems.zipWithIndex.toMap
data.flatMap { transaction =>
genCondTransactions(transaction, itemToRank, partitioner)
@@ -161,7 +155,7 @@ class FPGrowth private (
.flatMap { case (part, tree) =>
tree.extract(minCount, x => partitioner.getPartition(x) == part)
}.map { case (ranks, count) =>
- (ranks.map(i => freqItems(i)).toArray, count)
+ new FreqItemset(ranks.map(i => freqItems(i)).toArray, count)
}
}
@@ -193,3 +187,26 @@ class FPGrowth private (
output
}
}
+
+/**
+ * :: Experimental ::
+ */
+@Experimental
+object FPGrowth {
+
+ /**
+ * Frequent itemset.
+ * @param items items in this itemset. Java users should call [[FreqItemset#javaItems]] instead.
+ * @param freq frequency
+ * @tparam Item item type
+ */
+ class FreqItemset[Item](val items: Array[Item], val freq: Long) extends Serializable {
+
+ /**
+ * Returns items in a Java List.
+ */
+ def javaItems: java.util.List[Item] = {
+ items.toList.asJava
+ }
+ }
+}
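
A usage sketch of the new `FreqItemset` wrapper returned by `FPGrowthModel.freqItemsets`; the transactions and thresholds are illustrative.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.FPGrowth

def printFreqItemsets(sc: SparkContext): Unit = {
  val transactions = sc.parallelize(Seq(
    Array("a", "b", "c"),
    Array("a", "b"),
    Array("a", "c")))
  val model = new FPGrowth()
    .setMinSupport(0.6)
    .setNumPartitions(2)
    .run(transactions)
  model.freqItemsets.collect().foreach { itemset =>
    // e.g. {a,b}: 2
    println(itemset.items.mkString("{", ",", "}") + ": " + itemset.freq)
  }
}
```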
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 079f7ca564a92..87052e1ba8539 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -235,12 +235,24 @@ private[spark] object BLAS extends Serializable with Logging {
* @param x the vector x that contains the n elements.
* @param A the symmetric matrix A. Size of n x n.
*/
- def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ def syr(alpha: Double, x: Vector, A: DenseMatrix) {
val mA = A.numRows
val nA = A.numCols
- require(mA == nA, s"A is not a symmetric matrix. A: $mA x $nA")
+ require(mA == nA, s"A is not a square matrix (and hence is not symmetric). A: $mA x $nA")
require(mA == x.size, s"The size of x doesn't match the rank of A. A: $mA x $nA, x: ${x.size}")
+ x match {
+ case dv: DenseVector => syr(alpha, dv, A)
+ case sv: SparseVector => syr(alpha, sv, A)
+ case _ =>
+ throw new IllegalArgumentException(s"syr doesn't support vector type ${x.getClass}.")
+ }
+ }
+
+ private def syr(alpha: Double, x: DenseVector, A: DenseMatrix) {
+ val nA = A.numRows
+ val mA = A.numCols
+
nativeBLAS.dsyr("U", x.size, alpha, x.values, 1, A.values, nA)
// Fill lower triangular part of A
@@ -255,6 +267,26 @@ private[spark] object BLAS extends Serializable with Logging {
}
}
+ private def syr(alpha: Double, x: SparseVector, A: DenseMatrix) {
+ val mA = A.numCols
+ val xIndices = x.indices
+ val xValues = x.values
+ val nnz = xValues.length
+ val Avalues = A.values
+
+ var i = 0
+ while (i < nnz) {
+ val multiplier = alpha * xValues(i)
+ val offset = xIndices(i) * mA
+ var j = 0
+ while (j < nnz) {
+ Avalues(xIndices(j) + offset) += multiplier * xValues(j)
+ j += 1
+ }
+ i += 1
+ }
+ }
+
/**
* C := alpha * A * B + beta * C
* @param alpha a scalar to scale the multiplication A * B.
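
`syr` performs the rank-1 symmetric update A := alpha * x * x^T + A; the new sparse path runs the same two nested loops only over the nonzero indices of x. Below is a plain-Scala sketch of the dense update on a column-major array, with a small worked example: alpha = 1 and x = (1, 2) turn a zero 2 x 2 matrix into ((1, 2), (2, 4)).

```scala
object SyrSketch {
  // A := alpha * x * x^T + A for a column-major n x n array (both triangles updated).
  def denseSyr(alpha: Double, x: Array[Double], a: Array[Double]): Unit = {
    val n = x.length
    var i = 0
    while (i < n) {
      val multiplier = alpha * x(i)
      val offset = i * n  // start of column i in column-major storage
      var j = 0
      while (j < n) {
        a(j + offset) += multiplier * x(j)
        j += 1
      }
      i += 1
    }
  }

  def main(args: Array[String]): Unit = {
    val a = Array.fill(4)(0.0)         // 2 x 2 zero matrix
    denseSyr(1.0, Array(1.0, 2.0), a)
    println(a.mkString(", "))          // 1.0, 2.0, 2.0, 4.0
  }
}
```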
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
index 9d6f97528148e..866936aa4f118 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala
@@ -117,7 +117,7 @@ private[mllib] object EigenValueDecomposition {
info.`val` match {
case 1 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
" Maximum number of iterations taken. (Refer ARPACK user guide for details)")
- case 2 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
+ case 3 => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
" No shifts could be applied. Try to increase NCV. " +
"(Refer ARPACK user guide for details)")
case _ => throw new IllegalStateException("ARPACK returns non-zero info = " + info.`val` +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 89b38679b7494..0e4a4d0085895 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -706,7 +706,7 @@ object Matrices {
}
/**
- * Generate a `DenseMatrix` consisting of zeros.
+ * Generate a `Matrix` consisting of zeros.
* @param numRows number of rows of the matrix
* @param numCols number of columns of the matrix
* @return `Matrix` with size `numRows` x `numCols` and values of zeros
@@ -778,8 +778,8 @@ object Matrices {
SparseMatrix.sprandn(numRows, numCols, density, rng)
/**
- * Generate a diagonal matrix in `DenseMatrix` format from the supplied values.
- * @param vector a `Vector` tat will form the values on the diagonal of the matrix
+ * Generate a diagonal matrix in `Matrix` format from the supplied values.
+ * @param vector a `Vector` that will form the values on the diagonal of the matrix
* @return Square `Matrix` with size `values.length` x `values.length` and `values`
* on the diagonal
*/
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 480bbfb5fe94a..4bdcb283da09c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -247,7 +247,7 @@ object Vectors {
}
/**
- * Creates a dense vector of all zeros.
+ * Creates a vector of all zeros.
*
* @param size vector size
* @return a zero vector
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 0acdab797e8f3..8bfa0d2b64995 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -63,10 +63,12 @@ abstract class Gradient extends Serializable {
* http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of
* multinomial logistic regression model. A simple calculation shows that
*
+ * {{{
* P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i))
* P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i))
* ...
* P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i))
+ * }}}
*
* for K classes multiclass classification problem.
*
@@ -75,9 +77,11 @@ abstract class Gradient extends Serializable {
* will be (K-1) * N.
*
* As a result, the loss of objective function for a single instance of data can be written as
+ * {{{
* l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w)
* = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1}
* = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1}
+ * }}}
*
* where \alpha(i) = 1 if i != 0, and
* \alpha(i) = 0 if i == 0,
@@ -86,14 +90,16 @@ abstract class Gradient extends Serializable {
* For optimization, we have to calculate the first derivative of the loss function, and
* a simple calculation shows that
*
+ * {{{
* \frac{\partial l(w, x)}{\partial w_{ij}}
* = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y)\delta_{y, i+1})) * x_j
* = multiplier_i * x_j
+ * }}}
*
* where \delta_{i, j} = 1 if i == j,
* \delta_{i, j} = 0 if i != j, and
- * multiplier
- * = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1})
+ * multiplier =
+ * \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1})
*
* If any of margins is larger than 709.78, the numerical computation of multiplier and loss
* function will be suffered from arithmetic overflow. This issue occurs when there are outliers
@@ -103,10 +109,12 @@ abstract class Gradient extends Serializable {
* Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be
* easily rewritten into the following equivalent numerically stable formula.
*
+ * {{{
* l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1}
* = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin
* - (1-\alpha(y)) margins_{y-1}
* = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1}
+ * }}}
*
* where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1.
*
@@ -115,8 +123,10 @@ abstract class Gradient extends Serializable {
*
* For multiplier, similar trick can be applied as the following,
*
+ * {{{
* multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_i)) - (1-\alpha(y)\delta_{y, i+1})
* = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y)\delta_{y, i+1})
+ * }}}
*
* where each term in \exp is also smaller than zero, so overflow is not a concern.
*
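
To make the stabilized formulas above concrete, here is a standalone sketch of the loss for one instance with K classes, labels in 0..K-1, and margins(i) = x * w_{i+1} (class 0 is the pivot with margin 0). It illustrates the math only and is not the code in Gradient.scala.

```scala
object MultinomialLossSketch {
  def multinomialLogLoss(margins: Array[Double], label: Int): Double = {
    val maxMargin = math.max(0.0, margins.max)
    // sum = exp(-maxMargin) + \sum_i exp(margins_i - maxMargin) - 1
    val sum = math.exp(-maxMargin) + margins.map(m => math.exp(m - maxMargin)).sum - 1.0
    // The label-class margin is subtracted only when the label is not the pivot class 0.
    val marginOfLabel = if (label == 0) 0.0 else margins(label - 1)
    math.log1p(sum) + maxMargin - marginOfLabel
  }

  def main(args: Array[String]): Unit = {
    // K = 3 classes, margins for classes 1 and 2, true label = 1.
    println(multinomialLogLoss(Array(1.0, -0.5), 1))  // ~ 0.4644
  }
}
```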
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 9ff06ac362a31..c399496568bfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -22,6 +22,9 @@ import java.lang.{Integer => JavaInteger}
import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
@@ -153,7 +156,7 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
import org.apache.spark.mllib.util.Loader._
override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
- val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+ val (loadedClassName, formatVersion, _) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
(loadedClassName, formatVersion) match {
case (className, "1.0") if className == classNameV1_0 =>
@@ -180,20 +183,21 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
def save(model: MatrixFactorizationModel, path: String): Unit = {
val sc = model.userFeatures.sparkContext
val sqlContext = new SQLContext(sc)
- import sqlContext.implicits.createDataFrame
- val metadata = (thisClassName, thisFormatVersion, model.rank)
- val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
- metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
- model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
- model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+ import sqlContext.implicits._
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
+ model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path))
+ model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path))
}
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ implicit val formats = DefaultFormats
val sqlContext = new SQLContext(sc)
val (className, formatVersion, metadata) = loadMetadata(sc, path)
assert(className == thisClassName)
assert(formatVersion == thisFormatVersion)
- val rank = metadata.select("rank").first().getInt(0)
+ val rank = (metadata \ "rank").extract[Int]
val userFeatures = sqlContext.parquetFile(userPath(path))
.map { case Row(id: Int, features: Seq[Double]) =>
(id, features.toArray)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 17de215b97f9d..7c66e8cdebdbe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -126,7 +126,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
/**
* The dimension of training features.
*/
- protected var numFeatures: Int = 0
+ protected var numFeatures: Int = -1
/**
* Set if the algorithm should use feature scaling to improve the convergence during optimization.
@@ -163,7 +163,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* RDD of LabeledPoint entries.
*/
def run(input: RDD[LabeledPoint]): M = {
- numFeatures = input.first().features.size
+ if (numFeatures < 0) {
+ numFeatures = input.map(_.features.size).first()
+ }
/**
* When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights,
@@ -193,7 +195,6 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* of LabeledPoint entries starting from the initial weights provided.
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
- numFeatures = input.first().features.size
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -205,7 +206,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
throw new SparkException("Input validation failed.")
}
- /**
+ /*
* Scaling columns to unit variance as a heuristic to reduce the condition number:
*
* During the optimization process, the convergence (rate) depends on the condition number of
@@ -225,26 +226,27 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* Currently, it's only enabled in LogisticRegressionWithLBFGS
*/
val scaler = if (useFeatureScaling) {
- (new StandardScaler(withStd = true, withMean = false)).fit(input.map(x => x.features))
+ new StandardScaler(withStd = true, withMean = false).fit(input.map(_.features))
} else {
null
}
// Prepend an extra variable consisting of all 1.0's for the intercept.
- val data = if (addIntercept) {
- if (useFeatureScaling) {
- input.map(labeledPoint =>
- (labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
- } else {
- input.map(labeledPoint => (labeledPoint.label, appendBias(labeledPoint.features)))
- }
- } else {
- if (useFeatureScaling) {
- input.map(labeledPoint => (labeledPoint.label, scaler.transform(labeledPoint.features)))
+ // TODO: Apply feature scaling to the weight vector instead of input data.
+ val data =
+ if (addIntercept) {
+ if (useFeatureScaling) {
+ input.map(lp => (lp.label, appendBias(scaler.transform(lp.features)))).cache()
+ } else {
+ input.map(lp => (lp.label, appendBias(lp.features))).cache()
+ }
} else {
- input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
+ if (useFeatureScaling) {
+ input.map(lp => (lp.label, scaler.transform(lp.features))).cache()
+ } else {
+ input.map(lp => (lp.label, lp.features))
+ }
}
- }
/**
* TODO: For better convergence, in logistic regression, the intercepts should be computed
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index 1159e59fff5f6..e8b03816573cf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -58,7 +58,7 @@ object LassoModel extends Loader[LassoModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.LassoModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LassoModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
index 0136dcfdceaef..6fa7ad52a5b33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala
@@ -58,7 +58,7 @@ object LinearRegressionModel extends Loader[LinearRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.LinearRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new LinearRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
index 843e59bdfbdd2..214ac4d0ed7dd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RegressionModel.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.regression
+import org.json4s.{DefaultFormats, JValue}
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
-import org.apache.spark.mllib.util.Loader
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
@Experimental
trait RegressionModel extends Serializable {
@@ -55,16 +55,10 @@ private[mllib] object RegressionModel {
/**
* Helper method for loading GLM regression model metadata.
- *
- * @param modelClass String name for model class (used for error messages)
* @return numFeatures
*/
- def getNumFeatures(metadata: DataFrame, modelClass: String, path: String): Int = {
- metadata.select("numFeatures").take(1)(0) match {
- case Row(nFeatures: Int) => nFeatures
- case _ => throw new Exception(s"$modelClass unable to load" +
- s" numFeatures from metadata: ${Loader.metadataPath(path)}")
- }
+ def getNumFeatures(metadata: JValue): Int = {
+ implicit val formats = DefaultFormats
+ (metadata \ "numFeatures").extract[Int]
}
-
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index f2a5f1db1ece6..8838ca8c14718 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -59,7 +59,7 @@ object RidgeRegressionModel extends Loader[RidgeRegressionModel] {
val classNameV1_0 = "org.apache.spark.mllib.regression.RidgeRegressionModel"
(loadedClassName, version) match {
case (className, "1.0") if className == classNameV1_0 =>
- val numFeatures = RegressionModel.getNumFeatures(metadata, classNameV1_0, path)
+ val numFeatures = RegressionModel.getNumFeatures(metadata)
val data = GLMRegressionModel.SaveLoadV1_0.loadData(sc, path, classNameV1_0, numFeatures)
new RidgeRegressionModel(data.weights, data.intercept)
case _ => throw new Exception(
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index 838100e949ec2..bd7e340ca2d8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -17,6 +17,9 @@
package org.apache.spark.mllib.regression.impl
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.SparkContext
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.Loader
@@ -48,14 +51,14 @@ private[regression] object GLMRegressionModel {
import sqlContext.implicits._
// Create JSON metadata.
- val metadataRDD =
- sc.parallelize(Seq((modelClass, thisFormatVersion, weights.size)), 1)
- .toDataFrame("class", "version", "numFeatures")
- metadataRDD.toJSON.saveAsTextFile(Loader.metadataPath(path))
+ val metadata = compact(render(
+ ("class" -> modelClass) ~ ("version" -> thisFormatVersion) ~
+ ("numFeatures" -> weights.size)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
// Create Parquet data.
val data = Data(weights, intercept)
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF()
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index fd186b5ee6f72..cd6add9d60b0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -17,7 +17,7 @@
package org.apache.spark.mllib.stat.distribution
-import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym}
+import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix}
@@ -62,21 +62,21 @@ class MultivariateGaussian (
/** Returns density of this multivariate Gaussian at given point, x */
def pdf(x: Vector): Double = {
- pdf(x.toBreeze.toDenseVector)
+ pdf(x.toBreeze)
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
def logpdf(x: Vector): Double = {
- logpdf(x.toBreeze.toDenseVector)
+ logpdf(x.toBreeze)
}
/** Returns density of this multivariate Gaussian at given point, x */
- private[mllib] def pdf(x: DBV[Double]): Double = {
+ private[mllib] def pdf(x: BV[Double]): Double = {
math.exp(logpdf(x))
}
/** Returns the log-density of this multivariate Gaussian at given point, x */
- private[mllib] def logpdf(x: DBV[Double]): Double = {
+ private[mllib] def logpdf(x: BV[Double]): Double = {
val delta = x - breezeMu
val v = rootSigmaInv * delta
u + v.t * v * -0.5
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index b3e8ed9af8c51..b9d0c56dd1ea3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,12 +19,11 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuffer
-
+import scala.collection.mutable.ArrayBuilder
+import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.RandomForest.NodeIndexInfo
import org.apache.spark.mllib.tree.configuration.Strategy
@@ -32,13 +31,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.impl._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
import org.apache.spark.util.random.XORShiftRandom
-import org.apache.spark.SparkContext._
-
/**
* :: Experimental ::
@@ -331,14 +327,14 @@ object DecisionTree extends Serializable with Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- bins: Array[Array[Bin]],
+ splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
@@ -366,7 +362,7 @@ object DecisionTree extends Serializable with Logging {
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
- if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+ if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
@@ -510,8 +506,8 @@ object DecisionTree extends Serializable with Logging {
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
}
}
@@ -1028,35 +1024,15 @@ object DecisionTree extends Serializable with Logging {
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
- // TODO: The second half of the bins are unused. Actually, we could just use
- // splits and not build bins for unordered features. That should be part of
- // a later PR since it will require changing other code (using splits instead
- // of bins in a few places).
// Unordered features
- // 2^(maxFeatureValue - 1) - 1 combinations
+ // 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
- bins(featureIndex) = new Array[Bin](numBins)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
- bins(featureIndex)(splitIndex) = {
- if (splitIndex == 0) {
- new Bin(
- new DummyCategoricalSplit(featureIndex, Categorical),
- splits(featureIndex)(0),
- Categorical,
- Double.MinValue)
- } else {
- new Bin(
- splits(featureIndex)(splitIndex - 1),
- splits(featureIndex)(splitIndex),
- Categorical,
- Double.MinValue)
- }
- }
splitIndex += 1
}
} else {
@@ -1064,8 +1040,11 @@ object DecisionTree extends Serializable with Logging {
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
- bins(featureIndex) = new Array[Bin](0)
}
+ // For ordered features, bins correspond to feature values.
+ // For unordered categorical features, there is no need to construct bins,
+ // since there is a one-to-one correspondence between the splits and the bins.
+ bins(featureIndex) = new Array[Bin](0)
}
featureIndex += 1
}
@@ -1140,7 +1119,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("stride = " + stride)
// iterate `valueCount` to find splits
- val splits = new ArrayBuffer[Double]
+ val splitsBuilder = ArrayBuilder.make[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCounts(0)._2
@@ -1158,13 +1137,13 @@ object DecisionTree extends Serializable with Logging {
// makes the gap between currentCount and targetCount smaller,
// previous value is a split threshold.
if (previousGap < currentGap) {
- splits.append(valueCounts(index - 1)._1)
+ splitsBuilder += valueCounts(index - 1)._1
targetCount += stride
}
index += 1
}
- splits.toArray
+ splitsBuilder.result()
}
}
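As context for the ArrayBuffer-to-ArrayBuilder change above, a small stand-alone sketch of the pattern (the values are illustrative): ArrayBuilder.make[Double] accumulates unboxed doubles and yields an Array[Double] directly, avoiding the per-element boxing that ArrayBuffer[Double] incurs.

    import scala.collection.mutable.ArrayBuilder

    object ArrayBuilderSketch {
      def main(args: Array[String]): Unit = {
        val builder = ArrayBuilder.make[Double]
        var i = 0
        while (i < 5) {
          builder += i * 0.5  // += appends without boxing
          i += 1
        }
        val thresholds: Array[Double] = builder.result()
        println(thresholds.mkString(", "))  // 0.0, 0.5, 1.0, 1.5, 2.0
      }
    }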
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index 61f6b1313f82e..a9c93e181e3ce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -60,11 +60,12 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: RDD[LabeledPoint]): GradientBoostedTreesModel = {
val algo = boostingStrategy.treeStrategy.algo
algo match {
- case Regression => GradientBoostedTrees.boost(input, boostingStrategy)
+ case Regression => GradientBoostedTrees.boost(input, input, boostingStrategy, validate=false)
case Classification =>
// Map labels to -1, +1 so binary classification can be treated as regression.
val remappedInput = input.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- GradientBoostedTrees.boost(remappedInput, boostingStrategy)
+ GradientBoostedTrees.boost(remappedInput,
+ remappedInput, boostingStrategy, validate=false)
case _ =>
throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
}
@@ -76,8 +77,46 @@ class GradientBoostedTrees(private val boostingStrategy: BoostingStrategy)
def run(input: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
run(input.rdd)
}
-}
+ /**
+ * Method to train a gradient boosting model with validation-based early stopping.
+ * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ * @param validationInput Validation dataset:
+ *        RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+ *        Should be different from and follow the same distribution as input.
+ *        E.g., these two datasets could be created from an original dataset
+ *        by using [[org.apache.spark.rdd.RDD.randomSplit()]].
+ * @return a gradient boosted trees model that can be used for prediction
+ */
+ def runWithValidation(
+ input: RDD[LabeledPoint],
+ validationInput: RDD[LabeledPoint]): GradientBoostedTreesModel = {
+ val algo = boostingStrategy.treeStrategy.algo
+ algo match {
+ case Regression => GradientBoostedTrees.boost(
+ input, validationInput, boostingStrategy, validate=true)
+ case Classification =>
+ // Map labels to -1, +1 so binary classification can be treated as regression.
+ val remappedInput = input.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val remappedValidationInput = validationInput.map(
+ x => new LabeledPoint((x.label * 2) - 1, x.features))
+ GradientBoostedTrees.boost(remappedInput, remappedValidationInput, boostingStrategy,
+ validate=true)
+ case _ =>
+ throw new IllegalArgumentException(s"$algo is not supported by the gradient boosting.")
+ }
+ }
+
+ /**
+ * Java-friendly API for [[org.apache.spark.mllib.tree.GradientBoostedTrees!#runWithValidation]].
+ */
+ def runWithValidation(
+ input: JavaRDD[LabeledPoint],
+ validationInput: JavaRDD[LabeledPoint]): GradientBoostedTreesModel = {
+ runWithValidation(input.rdd, validationInput.rdd)
+ }
+}
object GradientBoostedTrees extends Logging {
@@ -108,12 +147,16 @@ object GradientBoostedTrees extends Logging {
/**
* Internal method for performing regression using trees as base learners.
* @param input training dataset
+ * @param validationInput validation dataset, ignored if validate is set to false.
* @param boostingStrategy boosting parameters
+ * @param validate whether or not to use the validation dataset.
* @return a gradient boosted trees model that can be used for prediction
*/
private def boost(
input: RDD[LabeledPoint],
- boostingStrategy: BoostingStrategy): GradientBoostedTreesModel = {
+ validationInput: RDD[LabeledPoint],
+ boostingStrategy: BoostingStrategy,
+ validate: Boolean): GradientBoostedTreesModel = {
val timer = new TimeTracker()
timer.start("total")
@@ -129,6 +172,7 @@ object GradientBoostedTrees extends Logging {
val learningRate = boostingStrategy.learningRate
// Prepare strategy for individual trees, which use regression with variance impurity.
val treeStrategy = boostingStrategy.treeStrategy.copy
+ val validationTol = boostingStrategy.validationTol
treeStrategy.algo = Regression
treeStrategy.impurity = Variance
treeStrategy.assertValid()
@@ -152,13 +196,16 @@ object GradientBoostedTrees extends Logging {
baseLearnerWeights(0) = 1.0
val startingModel = new GradientBoostedTreesModel(Regression, Array(firstTreeModel), Array(1.0))
logDebug("error of gbt = " + loss.computeError(startingModel, input))
+
// Note: A model of type regression is used since we require raw prediction
timer.stop("building tree 0")
+ var bestValidateError = if (validate) loss.computeError(startingModel, validationInput) else 0.0
+ var bestM = 1
+
// pseudo-residual for the second iteration
data = input.map(point => LabeledPoint(loss.gradient(startingModel, point),
point.features))
-
var m = 1
while (m < numIterations) {
timer.start(s"building tree $m")
@@ -177,6 +224,23 @@ object GradientBoostedTrees extends Logging {
val partialModel = new GradientBoostedTreesModel(
Regression, baseLearners.slice(0, m + 1), baseLearnerWeights.slice(0, m + 1))
logDebug("error of gbt = " + loss.computeError(partialModel, input))
+
+ if (validate) {
+ // Stop training early if
+ // 1. the reduction in error is less than validationTol, or
+ // 2. the error increases, i.e., the model is overfitting.
+ // The returned model corresponds to the iteration with the best validation error.
+ val currentValidateError = loss.computeError(partialModel, validationInput)
+ if (bestValidateError - currentValidateError < validationTol) {
+ return new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo,
+ baseLearners.slice(0, bestM),
+ baseLearnerWeights.slice(0, bestM))
+ } else if (currentValidateError < bestValidateError) {
+ bestValidateError = currentValidateError
+ bestM = m + 1
+ }
+ }
// Update data with pseudo-residuals
data = input.map(point => LabeledPoint(-loss.gradient(partialModel, point),
point.features))
@@ -187,8 +251,15 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
-
- new GradientBoostedTreesModel(
- boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ if (validate) {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo,
+ baseLearners.slice(0, bestM),
+ baseLearnerWeights.slice(0, bestM))
+ } else {
+ new GradientBoostedTreesModel(
+ boostingStrategy.treeStrategy.algo, baseLearners, baseLearnerWeights)
+ }
}
+
}
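A minimal usage sketch of the new runWithValidation path, assuming an existing RDD[LabeledPoint] named data and the standard BoostingStrategy.defaultParams helper; the split ratio and parameter values are illustrative.

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.GradientBoostedTrees
    import org.apache.spark.mllib.tree.configuration.BoostingStrategy
    import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
    import org.apache.spark.rdd.RDD

    object GbtValidationSketch {
      def trainWithEarlyStopping(data: RDD[LabeledPoint]): GradientBoostedTreesModel = {
        // Hold out part of the data for validation, as the scaladoc suggests.
        val Array(train, validation) = data.randomSplit(Array(0.8, 0.2), seed = 42L)
        val boostingStrategy = BoostingStrategy.defaultParams("Regression")
        boostingStrategy.numIterations = 100
        boostingStrategy.validationTol = 1e-5
        // Stops early once the validation error stops improving by at least validationTol.
        new GradientBoostedTrees(boostingStrategy).runWithValidation(train, validation)
      }
    }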
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index ed8e6a796f8c4..664c8df019233 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -34,6 +34,9 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* weak hypotheses used in the final model.
* @param learningRate Learning rate for shrinking the contribution of each estimator. The
* learning rate should be in the interval (0, 1].
+ * @param validationTol Used when runWithValidation is called. Training stops early when the
+ *                      reduction in error on the validation input between two iterations is
+ *                      less than validationTol. Ignored when [[run]] is used.
*/
@Experimental
case class BoostingStrategy(
@@ -42,7 +45,8 @@ case class BoostingStrategy(
@BeanProperty var loss: Loss,
// Optional boosting parameters
@BeanProperty var numIterations: Int = 100,
- @BeanProperty var learningRate: Double = 0.1) extends Serializable {
+ @BeanProperty var learningRate: Double = 0.1,
+ @BeanProperty var validationTol: Double = 1e-5) extends Serializable {
/**
* Check validity of parameters.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 35e361ae309cc..50b292e71b067 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -55,17 +55,15 @@ private[tree] object TreePoint {
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
- // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+ // Construct an array for featureArity for efficiency in the inner loop.
val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
var featureIndex = 0
while (featureIndex < metadata.numFeatures) {
featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
featureIndex += 1
}
input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
+ TreePoint.labeledPointToTreePoint(x, bins, featureArity)
}
}
@@ -74,19 +72,17 @@ private[tree] object TreePoint {
* @param bins Bins for features, of size (numFeatures, numBins).
* @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
* for categorical features.
- * @param isUnordered Array index by feature, with value true for unordered categorical features.
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
- featureArity: Array[Int],
- isUnordered: Array[Boolean]): TreePoint = {
+ featureArity: Array[Int]): TreePoint = {
val numFeatures = labeledPoint.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
- isUnordered(featureIndex), bins)
+ bins)
featureIndex += 1
}
new TreePoint(labeledPoint.label, arr)
@@ -96,14 +92,12 @@ private[tree] object TreePoint {
* Find bin for one (labeledPoint, feature).
*
* @param featureArity 0 for continuous features; number of categories for categorical features.
- * @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
featureArity: Int,
- isUnorderedFeature: Boolean,
bins: Array[Array[Bin]]): Int = {
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index a25e625a4017a..060fd5b859a51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -17,11 +17,21 @@
package org.apache.spark.mllib.tree.model
+import scala.collection.mutable
+
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.{Algo, FeatureType}
import org.apache.spark.mllib.tree.configuration.Algo._
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
/**
* :: Experimental ::
@@ -31,7 +41,7 @@ import org.apache.spark.rdd.RDD
* @param algo algorithm type -- classification or regression
*/
@Experimental
-class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable {
+class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable with Saveable {
/**
* Predict values for a single data point using the model trained.
@@ -98,4 +108,183 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
header + topNode.subtreeToString(2)
}
+ override def save(sc: SparkContext, path: String): Unit = {
+ DecisionTreeModel.SaveLoadV1_0.save(sc, path, this)
+ }
+
+ override protected def formatVersion: String = "1.0"
+}
+
+object DecisionTreeModel extends Loader[DecisionTreeModel] {
+
+ private[tree] object SaveLoadV1_0 {
+
+ def thisFormatVersion = "1.0"
+
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.DecisionTreeModel"
+
+ case class PredictData(predict: Double, prob: Double) {
+ def toPredict: Predict = new Predict(predict, prob)
+ }
+
+ object PredictData {
+ def apply(p: Predict): PredictData = PredictData(p.predict, p.prob)
+
+ def apply(r: Row): PredictData = PredictData(r.getDouble(0), r.getDouble(1))
+ }
+
+ case class SplitData(
+ feature: Int,
+ threshold: Double,
+ featureType: Int,
+ categories: Seq[Double]) { // TODO: Change to List once SPARK-3365 is fixed
+ def toSplit: Split = {
+ new Split(feature, threshold, FeatureType(featureType), categories.toList)
+ }
+ }
+
+ object SplitData {
+ def apply(s: Split): SplitData = {
+ SplitData(s.feature, s.threshold, s.featureType.id, s.categories)
+ }
+
+ def apply(r: Row): SplitData = {
+ SplitData(r.getInt(0), r.getDouble(1), r.getInt(2), r.getAs[Seq[Double]](3))
+ }
+ }
+
+ /** Model data for model import/export */
+ case class NodeData(
+ treeId: Int,
+ nodeId: Int,
+ predict: PredictData,
+ impurity: Double,
+ isLeaf: Boolean,
+ split: Option[SplitData],
+ leftNodeId: Option[Int],
+ rightNodeId: Option[Int],
+ infoGain: Option[Double])
+
+ object NodeData {
+ def apply(treeId: Int, n: Node): NodeData = {
+ NodeData(treeId, n.id, PredictData(n.predict), n.impurity, n.isLeaf,
+ n.split.map(SplitData.apply), n.leftNode.map(_.id), n.rightNode.map(_.id),
+ n.stats.map(_.gain))
+ }
+
+ def apply(r: Row): NodeData = {
+ val split = if (r.isNullAt(5)) None else Some(SplitData(r.getStruct(5)))
+ val leftNodeId = if (r.isNullAt(6)) None else Some(r.getInt(6))
+ val rightNodeId = if (r.isNullAt(7)) None else Some(r.getInt(7))
+ val infoGain = if (r.isNullAt(8)) None else Some(r.getDouble(8))
+ NodeData(r.getInt(0), r.getInt(1), PredictData(r.getStruct(2)), r.getDouble(3),
+ r.getBoolean(4), split, leftNodeId, rightNodeId, infoGain)
+ }
+ }
+
+ def save(sc: SparkContext, path: String, model: DecisionTreeModel): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~
+ ("algo" -> model.algo.toString) ~ ("numNodes" -> model.numNodes)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val nodes = model.topNode.subtreeIterator.toSeq
+ val dataRDD: DataFrame = sc.parallelize(nodes)
+ .map(NodeData.apply(0, _))
+ .toDF()
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String, algo: String, numNodes: Int): DecisionTreeModel = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ // Load Parquet data.
+ val dataRDD = sqlContext.parquetFile(datapath)
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[NodeData](dataRDD.schema)
+ val nodes = dataRDD.map(NodeData.apply)
+ // Build node data into a tree.
+ val trees = constructTrees(nodes)
+ assert(trees.size == 1,
+ "Decision tree should contain exactly one tree but got ${trees.size} trees.")
+ val model = new DecisionTreeModel(trees(0), Algo.fromString(algo))
+ assert(model.numNodes == numNodes, s"Unable to load DecisionTreeModel data from: $datapath." +
+ s" Expected $numNodes nodes but found ${model.numNodes}")
+ model
+ }
+
+ def constructTrees(nodes: RDD[NodeData]): Array[Node] = {
+ val trees = nodes
+ .groupBy(_.treeId)
+ .mapValues(_.toArray)
+ .collect()
+ .map { case (treeId, data) =>
+ (treeId, constructTree(data))
+ }.sortBy(_._1)
+ val numTrees = trees.size
+ val treeIndices = trees.map(_._1).toSeq
+ assert(treeIndices == (0 until numTrees),
+ s"Tree indices must start from 0 and increment by 1, but we found $treeIndices.")
+ trees.map(_._2)
+ }
+
+ /**
+ * Given a list of nodes from a tree, construct the tree.
+ * @param data array of all node data in a tree.
+ */
+ def constructTree(data: Array[NodeData]): Node = {
+ val dataMap: Map[Int, NodeData] = data.map(n => n.nodeId -> n).toMap
+ assert(dataMap.contains(1),
+ s"DecisionTree missing root node (id = 1).")
+ constructNode(1, dataMap, mutable.Map.empty)
+ }
+
+ /**
+ * Builds a node from the node data map and adds new nodes to the input nodes map.
+ */
+ private def constructNode(
+ id: Int,
+ dataMap: Map[Int, NodeData],
+ nodes: mutable.Map[Int, Node]): Node = {
+ if (nodes.contains(id)) {
+ return nodes(id)
+ }
+ val data = dataMap(id)
+ val node =
+ if (data.isLeaf) {
+ Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf)
+ } else {
+ val leftNode = constructNode(data.leftNodeId.get, dataMap, nodes)
+ val rightNode = constructNode(data.rightNodeId.get, dataMap, nodes)
+ val stats = new InformationGainStats(data.infoGain.get, data.impurity, leftNode.impurity,
+ rightNode.impurity, leftNode.predict, rightNode.predict)
+ new Node(data.nodeId, data.predict.toPredict, data.impurity, data.isLeaf,
+ data.split.map(_.toSplit), Some(leftNode), Some(rightNode), Some(stats))
+ }
+ nodes += node.id -> node
+ node
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): DecisionTreeModel = {
+ implicit val formats = DefaultFormats
+ val (loadedClassName, version, metadata) = Loader.loadMetadata(sc, path)
+ val algo = (metadata \ "algo").extract[String]
+ val numNodes = (metadata \ "numNodes").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path, algo, numNodes)
+ case _ => throw new Exception(
+ s"DecisionTreeModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
}
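A short usage sketch of the save/load pair added above; sc and model are assumed to exist already and the path is illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.tree.model.DecisionTreeModel

    object TreeSaveLoadSketch {
      def roundTrip(sc: SparkContext, model: DecisionTreeModel): DecisionTreeModel = {
        val path = "/tmp/myDecisionTreeModel"
        // save() writes JSON metadata (class, version, algo, numNodes) plus Parquet node data.
        model.save(sc, path)
        // load() reads the metadata, dispatches on (class, version), and rebuilds the tree.
        DecisionTreeModel.load(sc, path)
      }
    }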
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index 9a50ecb550c38..80990aa9a603f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -49,7 +49,9 @@ class InformationGainStats(
gain == other.gain &&
impurity == other.impurity &&
leftImpurity == other.leftImpurity &&
- rightImpurity == other.rightImpurity
+ rightImpurity == other.rightImpurity &&
+ leftPredict == other.leftPredict &&
+ rightPredict == other.rightPredict
}
case _ => false
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 2179da8dbe03e..d961081d185e9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -166,6 +166,11 @@ class Node (
}
}
+ /** Returns an iterator that traverses (DFS, left to right) the subtree of this node. */
+ private[tree] def subtreeIterator: Iterator[Node] = {
+ Iterator.single(this) ++ leftNode.map(_.subtreeIterator).getOrElse(Iterator.empty) ++
+ rightNode.map(_.subtreeIterator).getOrElse(Iterator.empty)
+ }
}
private[tree] object Node {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
index 004838ee5ba0e..ad4c0dbbfb3e5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -32,4 +32,11 @@ class Predict(
override def toString = {
"predict = %f, prob = %f".format(predict, prob)
}
+
+ override def equals(other: Any): Boolean = {
+ other match {
+ case p: Predict => predict == p.predict && prob == p.prob
+ case _ => false
+ }
+ }
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index 22997110de8dd..4897906aea5b3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -20,13 +20,20 @@ package org.apache.spark.mllib.tree.model
import scala.collection.mutable
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy._
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
/**
* :: Experimental ::
@@ -38,9 +45,42 @@ import org.apache.spark.rdd.RDD
@Experimental
class RandomForestModel(override val algo: Algo, override val trees: Array[DecisionTreeModel])
extends TreeEnsembleModel(algo, trees, Array.fill(trees.size)(1.0),
- combiningStrategy = if (algo == Classification) Vote else Average) {
+ combiningStrategy = if (algo == Classification) Vote else Average)
+ with Saveable {
require(trees.forall(_.algo == algo))
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ RandomForestModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object RandomForestModel extends Loader[RandomForestModel] {
+
+ override def load(sc: SparkContext, path: String): RandomForestModel = {
+ val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
+ assert(metadata.treeWeights.forall(_ == 1.0))
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new RandomForestModel(Algo.fromString(metadata.algo), trees)
+ case _ => throw new Exception(s"RandomForestModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.RandomForestModel"
+ }
+
}
/**
@@ -56,9 +96,42 @@ class GradientBoostedTreesModel(
override val algo: Algo,
override val trees: Array[DecisionTreeModel],
override val treeWeights: Array[Double])
- extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum) {
+ extends TreeEnsembleModel(algo, trees, treeWeights, combiningStrategy = Sum)
+ with Saveable {
require(trees.size == treeWeights.size)
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ TreeEnsembleModel.SaveLoadV1_0.save(sc, path, this,
+ GradientBoostedTreesModel.SaveLoadV1_0.thisClassName)
+ }
+
+ override protected def formatVersion: String = TreeEnsembleModel.SaveLoadV1_0.thisFormatVersion
+}
+
+object GradientBoostedTreesModel extends Loader[GradientBoostedTreesModel] {
+
+ override def load(sc: SparkContext, path: String): GradientBoostedTreesModel = {
+ val (loadedClassName, version, jsonMetadata) = Loader.loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, version) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ val metadata = TreeEnsembleModel.SaveLoadV1_0.readMetadata(jsonMetadata)
+ assert(metadata.combiningStrategy == Sum.toString)
+ val trees =
+ TreeEnsembleModel.SaveLoadV1_0.loadTrees(sc, path, metadata.treeAlgo)
+ new GradientBoostedTreesModel(Algo.fromString(metadata.algo), trees, metadata.treeWeights)
+ case _ => throw new Exception(s"GradientBoostedTreesModel.load did not recognize model" +
+ s" with (className, format version): ($loadedClassName, $version). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private object SaveLoadV1_0 {
+ // Hard-code class name string in case it changes in the future
+ def thisClassName = "org.apache.spark.mllib.tree.model.GradientBoostedTreesModel"
+ }
+
}
/**
@@ -176,3 +249,74 @@ private[tree] sealed class TreeEnsembleModel(
*/
def totalNumNodes: Int = trees.map(_.numNodes).sum
}
+
+private[tree] object TreeEnsembleModel {
+
+ object SaveLoadV1_0 {
+
+ import org.apache.spark.mllib.tree.model.DecisionTreeModel.SaveLoadV1_0.{NodeData, constructTrees}
+
+ def thisFormatVersion = "1.0"
+
+ case class Metadata(
+ algo: String,
+ treeAlgo: String,
+ combiningStrategy: String,
+ treeWeights: Array[Double])
+
+ /**
+ * Model data for model import/export.
+ * We have to duplicate NodeData here since Spark SQL does not yet support extracting subfields
+ * of nested fields; once that is possible, we can use something like:
+ * case class EnsembleNodeData(treeId: Int, node: NodeData),
+ * where NodeData is from DecisionTreeModel.
+ */
+ case class EnsembleNodeData(treeId: Int, node: NodeData)
+
+ def save(sc: SparkContext, path: String, model: TreeEnsembleModel, className: String): Unit = {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Create JSON metadata.
+ implicit val format = DefaultFormats
+ val ensembleMetadata = Metadata(model.algo.toString, model.trees(0).algo.toString,
+ model.combiningStrategy.toString, model.treeWeights)
+ val metadata = compact(render(
+ ("class" -> className) ~ ("version" -> thisFormatVersion) ~
+ ("metadata" -> Extraction.decompose(ensembleMetadata))))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Create Parquet data.
+ val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
+ tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
+ }.toDF()
+ dataRDD.saveAsParquetFile(Loader.dataPath(path))
+ }
+
+ /**
+ * Read metadata from the loaded JSON metadata.
+ */
+ def readMetadata(metadata: JValue): Metadata = {
+ implicit val formats = DefaultFormats
+ (metadata \ "metadata").extract[Metadata]
+ }
+
+ /**
+ * Load trees for an ensemble, and return them in order.
+ * @param path path to load the model from
+ * @param treeAlgo Algorithm for individual trees (which may differ from the ensemble's
+ * algorithm).
+ */
+ def loadTrees(
+ sc: SparkContext,
+ path: String,
+ treeAlgo: String): Array[DecisionTreeModel] = {
+ val datapath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val nodes = sqlContext.parquetFile(datapath).map(NodeData.apply)
+ val trees = constructTrees(nodes)
+ trees.map(new DecisionTreeModel(_, Algo.fromString(treeAlgo)))
+ }
+ }
+
+}
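A small sketch of the nested-metadata pattern used by TreeEnsembleModel.SaveLoadV1_0 above: a case class is decomposed into JSON under a "metadata" field on save and extracted back on load. The Metadata case class and values here are illustrative stand-ins.

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    object EnsembleMetadataSketch {
      case class Metadata(
        algo: String,
        treeAlgo: String,
        combiningStrategy: String,
        treeWeights: Array[Double])

      def main(args: Array[String]): Unit = {
        implicit val formats = DefaultFormats
        val m = Metadata("Classification", "Classification", "Vote", Array(1.0, 1.0))
        // Nest the decomposed case class under a "metadata" field, as save() does.
        val json = compact(render(
          ("class" -> "SomeEnsembleModel") ~ ("version" -> "1.0") ~
            ("metadata" -> Extraction.decompose(m))))
        // readMetadata() then pulls the nested object back out into the case class.
        val recovered = (parse(json) \ "metadata").extract[Metadata]
        assert(recovered.treeWeights.sameElements(m.treeWeights))
      }
    }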
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
index f7cba6c6cb628..308f7f3578e21 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/NumericParser.scala
@@ -19,7 +19,7 @@ package org.apache.spark.mllib.util
import java.util.StringTokenizer
-import scala.collection.mutable.{ArrayBuffer, ListBuffer}
+import scala.collection.mutable.{ArrayBuilder, ListBuffer}
import org.apache.spark.SparkException
@@ -51,7 +51,7 @@ private[mllib] object NumericParser {
}
private def parseArray(tokenizer: StringTokenizer): Array[Double] = {
- val values = ArrayBuffer.empty[Double]
+ val values = ArrayBuilder.make[Double]
var parsing = true
var allowComma = false
var token: String = null
@@ -67,14 +67,14 @@ private[mllib] object NumericParser {
}
} else {
// expecting a number
- values.append(parseDouble(token))
+ values += parseDouble(token)
allowComma = true
}
}
if (parsing) {
throw new SparkException(s"An array must end with ']'.")
}
- values.toArray
+ values.result()
}
private def parseTuple(tokenizer: StringTokenizer): Seq[_] = {
@@ -114,7 +114,7 @@ private[mllib] object NumericParser {
try {
java.lang.Double.parseDouble(s)
} catch {
- case e: Throwable =>
+ case e: NumberFormatException =>
throw new SparkException(s"Cannot parse a double from: $s", e)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
index 56b77a7d12e83..4458340497f0b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/modelSaveLoad.scala
@@ -20,13 +20,13 @@ package org.apache.spark.mllib.util
import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkContext
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.types.{DataType, StructType, StructField}
-
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
/**
* :: DeveloperApi ::
@@ -120,20 +120,11 @@ private[mllib] object Loader {
* Load metadata from the given path.
* @return (class name, version, metadata)
*/
- def loadMetadata(sc: SparkContext, path: String): (String, String, DataFrame) = {
- val sqlContext = new SQLContext(sc)
- val metadata = sqlContext.jsonFile(metadataPath(path))
- val (clazz, version) = try {
- val metadataArray = metadata.select("class", "version").take(1)
- assert(metadataArray.size == 1)
- metadataArray(0) match {
- case Row(clazz: String, version: String) => (clazz, version)
- }
- } catch {
- case e: Exception =>
- throw new Exception(s"Unable to load model metadata from: ${metadataPath(path)}")
- }
+ def loadMetadata(sc: SparkContext, path: String): (String, String, JValue) = {
+ implicit val formats = DefaultFormats
+ val metadata = parse(sc.textFile(metadataPath(path)).first())
+ val clazz = (metadata \ "class").extract[String]
+ val version = (metadata \ "version").extract[String]
(clazz, version, metadata)
}
-
}
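For reference, a minimal json4s round trip matching the one-JSON-line metadata format that the save methods write and that loadMetadata now parses; the class name and field values are illustrative.

    import org.json4s._
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    object MetadataJsonSketch {
      def main(args: Array[String]): Unit = {
        implicit val formats = DefaultFormats
        // Writers emit a single compact JSON line per model, e.g.:
        val line = compact(render(
          ("class" -> "org.example.MyModel") ~ ("version" -> "1.0") ~ ("numFeatures" -> 692)))
        // loadMetadata reads that line back (via sc.textFile(...).first()) and parses it:
        val metadata = parse(line)
        val clazz = (metadata \ "class").extract[String]
        val version = (metadata \ "version").extract[String]
        assert(clazz == "org.example.MyModel" && version == "1.0")
      }
    }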
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 50995ffef9ad5..0a8c9e5954676 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -45,7 +45,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
- dataset = jsql.applySchema(points, LabeledPoint.class);
+ dataset = jsql.createDataFrame(points, LabeledPoint.class);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index d4b664479255d..3f8e59de0f05c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -50,7 +50,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
index ac945ba6f23c1..640d2ec55e4e7 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaStreamingLogisticRegressionSuite.java
@@ -47,7 +47,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 40d5a92bb32af..0cc36c8d56d70 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -46,7 +46,7 @@ public void setUp() {
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
datasetRDD = jsc.parallelize(points, 2);
- dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset = jsql.createDataFrame(datasetRDD, LabeledPoint.class);
dataset.registerTempTable("dataset");
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 074b58c07df7a..0bb6b489f2757 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -45,7 +45,7 @@ public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ dataset = jsql.createDataFrame(jsc.parallelize(points, 2), LabeledPoint.class);
}
@After
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 851707c8a19c4..bd0edf2b9ea62 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -19,6 +19,7 @@
import java.io.Serializable;
import java.util.ArrayList;
+import java.util.List;
import org.junit.After;
import org.junit.Before;
@@ -28,6 +29,7 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -55,30 +57,18 @@ public void runFPGrowth() {
Lists.newArrayList("z".split(" ")),
Lists.newArrayList("x z y r q t p".split(" "))), 2);
- FPGrowth fpg = new FPGrowth();
-
- FPGrowthModel<String> model6 = fpg
- .setMinSupport(0.9)
- .setNumPartitions(1)
- .run(rdd);
- assertEquals(0, model6.javaFreqItemsets().count());
-
- FPGrowthModel<String> model3 = fpg
+ FPGrowthModel<String> model = new FPGrowth()
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd);
- assertEquals(18, model3.javaFreqItemsets().count());
- FPGrowthModel<String> model2 = fpg
- .setMinSupport(0.3)
- .setNumPartitions(4)
- .run(rdd);
- assertEquals(54, model2.javaFreqItemsets().count());
+ List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+ assertEquals(18, freqItemsets.size());
- FPGrowthModel<String> model1 = fpg
- .setMinSupport(0.1)
- .setNumPartitions(8)
- .run(rdd);
- assertEquals(625, model1.javaFreqItemsets().count());
+ for (FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
}
}
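The same new freqItemsets-based API exercised by the rewritten Java test above, shown as a brief Scala sketch; sc is assumed to be an existing SparkContext and the transactions are illustrative.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.fpm.FPGrowth

    object FPGrowthSketch {
      def run(sc: SparkContext): Unit = {
        val transactions = sc.parallelize(Seq(
          "r z h k p".split(" "),
          "z y x w v u t s".split(" "),
          "x z y r q t p".split(" ")), 2)
        val model = new FPGrowth()
          .setMinSupport(0.5)
          .setNumPartitions(2)
          .run(transactions)
        // Each frequent itemset exposes its items and its absolute frequency.
        model.freqItemsets.collect().foreach { itemset =>
          println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
        }
      }
    }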
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
index a4dd1ac39a3c8..899c4ea607869 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java
@@ -45,7 +45,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index cb7d57de35c34..bb86bafc0eb0a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.scalatest.FunSuite
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkException}
import org.apache.spark.ml.recommendation.ALS._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
@@ -358,8 +358,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
val alpha = als.getAlpha
- val model = als.fit(training)
- val predictions = model.transform(test)
+ val model = als.fit(training.toDF())
+ val predictions = model.transform(test.toDF())
.select("rating", "prediction")
.map { case Row(rating: Float, prediction: Float) =>
(rating.toDouble, prediction.toDouble)
@@ -455,4 +455,34 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
assert(isNonnegative(itemFactors))
// TODO: Validate the solution.
}
+
+ test("als partitioner is a projection") {
+ for (p <- Seq(1, 10, 100, 1000)) {
+ val part = new ALSPartitioner(p)
+ var k = 0
+ while (k < p) {
+ assert(k === part.getPartition(k))
+ assert(k === part.getPartition(k.toLong))
+ k += 1
+ }
+ }
+ }
+
+ test("partitioner in returned factors") {
+ val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01)
+ val (userFactors, itemFactors) = ALS.train(
+ ratings, rank = 2, maxIter = 4, numUserBlocks = 3, numItemBlocks = 4)
+ for ((tpe, factors) <- Seq(("User", userFactors), ("Item", itemFactors))) {
+ assert(factors.partitioner.isDefined, s"$tpe factors should have a partitioner.")
+ val part = factors.partitioner.get
+ factors.mapPartitionsWithIndex { (idx, items) =>
+ items.foreach { case (id, _) =>
+ if (part.getPartition(id) != idx) {
+ throw new SparkException(s"$tpe with ID $id should not be in partition $idx.")
+ }
+ }
+ Iterator.empty
+ }.count()
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
index c2cd56ea40adc..1b46a4012d731 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala
@@ -31,7 +31,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense(5.0, 10.0),
Vectors.dense(4.0, 11.0)
))
-
+
// expectations
val Ew = 1.0
val Emu = Vectors.dense(5.0, 10.0)
@@ -44,6 +44,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
}
+
}
test("two clusters") {
@@ -54,7 +55,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
))
-
+
// we set an initial gaussian to induce expected results
val initialGmm = new GaussianMixtureModel(
Array(0.5, 0.5),
@@ -63,7 +64,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
)
)
-
+
val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
@@ -72,7 +73,7 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
.setK(2)
.setInitialModel(initialGmm)
.run(data)
-
+
assert(gmm.weights(0) ~== Ew(0) absTol 1E-3)
assert(gmm.weights(1) ~== Ew(1) absTol 1E-3)
assert(gmm.gaussians(0).mu ~== Emu(0) absTol 1E-3)
@@ -80,4 +81,61 @@ class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext {
assert(gmm.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
assert(gmm.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
}
+
+ test("single cluster with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.sparse(3, Array(0, 2), Array(4.0, 2.0)),
+ Vectors.sparse(3, Array(0, 2), Array(2.0, 4.0)),
+ Vectors.sparse(3, Array(1), Array(6.0))
+ ))
+
+ val Ew = 1.0
+ val Emu = Vectors.dense(2.0, 2.0, 2.0)
+ val Esigma = Matrices.dense(3, 3,
+ Array(8.0 / 3.0, -4.0, 4.0 / 3.0, -4.0, 8.0, -4.0, 4.0 / 3.0, -4.0, 8.0 / 3.0)
+ )
+
+ val seeds = Array(42, 1994, 27, 11, 0)
+ seeds.foreach { seed =>
+ val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data)
+ assert(gmm.weights(0) ~== Ew absTol 1E-5)
+ assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5)
+ assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5)
+ }
+ }
+
+ test("two clusters with sparse data") {
+ val data = sc.parallelize(Array(
+ Vectors.dense(-5.1971), Vectors.dense(-2.5359), Vectors.dense(-3.8220),
+ Vectors.dense(-5.2211), Vectors.dense(-5.0602), Vectors.dense( 4.7118),
+ Vectors.dense( 6.8989), Vectors.dense( 3.4592), Vectors.dense( 4.6322),
+ Vectors.dense( 5.7048), Vectors.dense( 4.6567), Vectors.dense( 5.5026),
+ Vectors.dense( 4.5605), Vectors.dense( 5.2043), Vectors.dense( 6.2734)
+ ))
+
+ val sparseData = data.map(point => Vectors.sparse(1, Array(0), point.toArray))
+ // we set an initial gaussian to induce expected results
+ val initialGmm = new GaussianMixtureModel(
+ Array(0.5, 0.5),
+ Array(
+ new MultivariateGaussian(Vectors.dense(-1.0), Matrices.dense(1, 1, Array(1.0))),
+ new MultivariateGaussian(Vectors.dense(1.0), Matrices.dense(1, 1, Array(1.0)))
+ )
+ )
+ val Ew = Array(1.0 / 3.0, 2.0 / 3.0)
+ val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604))
+ val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644)))
+
+ val sparseGMM = new GaussianMixture()
+ .setK(2)
+ .setInitialModel(initialGmm)
+ .run(sparseData)
+
+ assert(sparseGMM.weights(0) ~== Ew(0) absTol 1E-3)
+ assert(sparseGMM.weights(1) ~== Ew(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).mu ~== Emu(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).mu ~== Emu(1) absTol 1E-3)
+ assert(sparseGMM.gaussians(0).sigma ~== Esigma(0) absTol 1E-3)
+ assert(sparseGMM.gaussians(1).sigma ~== Esigma(1) absTol 1E-3)
+ }
}
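
The two added tests feed sparse vectors through GaussianMixture on the Scala side. The Python wrapper touched elsewhere in this patch can exercise roughly the same path; a hedged sketch follows, assuming the 1.3-era GaussianMixture.train(rdd, k, ..., seed=...) signature and a model exposing weights and gaussians:

    from pyspark import SparkContext
    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.clustering import GaussianMixture

    sc = SparkContext(appName="gmm-sparse-sketch")
    # Mirrors the "single cluster with sparse data" test: three sparse points in R^3.
    data = sc.parallelize([
        Vectors.sparse(3, [0, 2], [4.0, 2.0]),
        Vectors.sparse(3, [0, 2], [2.0, 4.0]),
        Vectors.sparse(3, [1], [6.0]),
    ])
    model = GaussianMixture.train(data, 1, seed=42)
    print(model.weights[0], model.gaussians[0].mu)
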
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 03ecd9ca730be..6315c03a700f1 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -51,8 +51,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setK(2)
.run(sc.parallelize(similarities, 2))
val predictions = Array.fill(2)(mutable.Set.empty[Long])
- model.assignments.collect().foreach { case (i, c) =>
- predictions(c) += i
+ model.assignments.collect().foreach { a =>
+ predictions(a.cluster) += a.id
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
@@ -61,8 +61,8 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
.setInitializationMode("degree")
.run(sc.parallelize(similarities, 2))
val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
- model2.assignments.collect().foreach { case (i, c) =>
- predictions2(c) += i
+ model2.assignments.collect().foreach { a =>
+ predictions2(a.cluster) += a.id
}
assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 68128284b8608..bd5b9cc3afa10 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -46,8 +46,8 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
@@ -96,10 +96,10 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.setMinSupport(0.5)
.setNumPartitions(2)
.run(rdd)
- assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+ assert(model3.freqItemsets.first().items.getClass === Array(1).getClass,
"frequent itemsets should use primitive arrays")
- val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
- (items.toSet, count)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
}
val expected = Set(
(Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
index b0b78acd6df16..002cb253862b5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala
@@ -166,6 +166,14 @@ class BLASSuite extends FunSuite {
syr(alpha, y, dA)
}
}
+
+ val xSparse = new SparseVector(4, Array(0, 2, 3), Array(1.0, 3.0, 4.0))
+ val dD = new DenseMatrix(4, 4,
+ Array(0.0, 1.2, 2.2, 3.1, 1.2, 3.2, 5.3, 4.6, 2.2, 5.3, 1.8, 3.0, 3.1, 4.6, 3.0, 0.8))
+ syr(0.1, xSparse, dD)
+ val expectedSparse = new DenseMatrix(4, 4,
+ Array(0.1, 1.2, 2.5, 3.5, 1.2, 3.2, 5.3, 4.6, 2.5, 5.3, 2.7, 4.2, 3.5, 4.6, 4.2, 2.4))
+ assert(dD ~== expectedSparse absTol 1e-15)
}
test("gemm") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 9347eaf9221a8..4c162df810bb2 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -29,8 +29,10 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model._
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
@@ -188,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 3)
- assert(bins(0).length === 6)
+ assert(bins(0).length === 0)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -226,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))
- // Check bins.
-
- assert(bins(0)(0).category === Double.MinValue)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(0.0))
- assert(bins(1)(0).category === Double.MinValue)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(0)(1).category === Double.MinValue)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(0.0))
- assert(bins(0)(1).highSplit.categories.length === 1)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(1)(1).category === Double.MinValue)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 1)
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(0)(2).category === Double.MinValue)
- assert(bins(0)(2).lowSplit.categories.length === 1)
- assert(bins(0)(2).lowSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.length === 2)
- assert(bins(0)(2).highSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.contains(0.0))
- assert(bins(1)(2).category === Double.MinValue)
- assert(bins(1)(2).lowSplit.categories.length === 1)
- assert(bins(1)(2).lowSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.length === 2)
- assert(bins(1)(2).highSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.contains(0.0))
-
}
test("Multiclass classification with ordered categorical features: split and bin calculations") {
@@ -857,9 +824,32 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(topNode.leftNode.get.impurity === 0.0)
assert(topNode.rightNode.get.impurity === 0.0)
}
+
+ test("Node.subtreeIterator") {
+ val model = DecisionTreeSuite.createModel(Classification)
+ val nodeIds = model.topNode.subtreeIterator.map(_.id).toArray.sorted
+ assert(nodeIds === DecisionTreeSuite.createdModelNodeIds)
+ }
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = DecisionTreeSuite.createModel(algo)
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = DecisionTreeModel.load(sc, path)
+ DecisionTreeSuite.checkEqual(model, sameModel)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
}
-object DecisionTreeSuite {
+object DecisionTreeSuite extends FunSuite {
def validateClassifier(
model: DecisionTreeModel,
@@ -979,4 +969,95 @@ object DecisionTreeSuite {
arr
}
+ /** Create a leaf node with the given node ID */
+ private def createLeafNode(id: Int): Node = {
+ Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = true)
+ }
+
+ /**
+ * Create an internal node with the given node ID and feature type.
+ * Note: This does NOT set the child nodes.
+ */
+ private def createInternalNode(id: Int, featureType: FeatureType): Node = {
+ val node = Node(nodeIndex = id, new Predict(0.0, 1.0), impurity = 0.5, isLeaf = false)
+ featureType match {
+ case Continuous =>
+ node.split = Some(new Split(feature = 0, threshold = 0.5, Continuous,
+ categories = List.empty[Double]))
+ case Categorical =>
+ node.split = Some(new Split(feature = 1, threshold = 0.0, Categorical,
+ categories = List(0.0, 1.0)))
+ }
+ // TODO: The information gain stats should be consistent with the same info stored in children.
+ node.stats = Some(new InformationGainStats(gain = 0.1, impurity = 0.2,
+ leftImpurity = 0.3, rightImpurity = 0.4, new Predict(1.0, 0.4), new Predict(0.0, 0.6)))
+ node
+ }
+
+ /**
+ * Create a tree model. This is deterministic and contains a variety of node and feature types.
+ */
+ private[tree] def createModel(algo: Algo): DecisionTreeModel = {
+ val topNode = createInternalNode(id = 1, Continuous)
+ val (node2, node3) = (createLeafNode(id = 2), createInternalNode(id = 3, Categorical))
+ val (node6, node7) = (createLeafNode(id = 6), createLeafNode(id = 7))
+ topNode.leftNode = Some(node2)
+ topNode.rightNode = Some(node3)
+ node3.leftNode = Some(node6)
+ node3.rightNode = Some(node7)
+ new DecisionTreeModel(topNode, algo)
+ }
+
+ /** Sorted Node IDs matching the model returned by [[createModel()]] */
+ private val createdModelNodeIds = Array(1, 2, 3, 6, 7)
+
+ /**
+ * Check that the two trees are exactly the same.
+ * Note: This deliberately does not override Node.equals, since that could cause problems
+ * if users make mistakes such as creating loops of Nodes.
+ * If the trees are not equal, this throws an AssertionError whose message includes both trees.
+ */
+ private[tree] def checkEqual(a: DecisionTreeModel, b: DecisionTreeModel): Unit = {
+ try {
+ assert(a.algo === b.algo)
+ checkEqual(a.topNode, b.topNode)
+ } catch {
+ case ex: Exception =>
+ throw new AssertionError("checkEqual failed since the two trees were not identical.\n" +
+ "TREE A:\n" + a.toDebugString + "\n" +
+ "TREE B:\n" + b.toDebugString + "\n", ex)
+ }
+ }
+
+ /**
+ * Check that the two nodes and their descendants are exactly the same, throwing an
+ * AssertionError otherwise.
+ * Note: This deliberately does not override Node.equals, since that could cause problems
+ * if users make mistakes such as creating loops of Nodes.
+ */
+ private def checkEqual(a: Node, b: Node): Unit = {
+ assert(a.id === b.id)
+ assert(a.predict === b.predict)
+ assert(a.impurity === b.impurity)
+ assert(a.isLeaf === b.isLeaf)
+ assert(a.split === b.split)
+ (a.stats, b.stats) match {
+ // TODO: Check other fields besides the information gain.
+ case (Some(aStats), Some(bStats)) => assert(aStats.gain === bStats.gain)
+ case (None, None) =>
+ case _ => throw new AssertionError(
+ s"Only one instance has stats defined. (a.stats: ${a.stats}, b.stats: ${b.stats})")
+ }
+ (a.leftNode, b.leftNode) match {
+ case (Some(aNode), Some(bNode)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has leftNode defined. " +
+ s"(a.leftNode: ${a.leftNode}, b.leftNode: ${b.leftNode})")
+ }
+ (a.rightNode, b.rightNode) match {
+ case (Some(aNode: Node), Some(bNode: Node)) => checkEqual(aNode, bNode)
+ case (None, None) =>
+ case _ => throw new AssertionError("Only one instance has rightNode defined. " +
+ s"(a.rightNode: ${a.rightNode}, b.rightNode: ${b.rightNode})")
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index e8341a5d0d104..b437aeaaf0547 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -24,8 +24,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy}
import org.apache.spark.mllib.tree.impurity.Variance
import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss}
-
+import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[GradientBoostedTrees]].
@@ -35,32 +37,30 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
test("Regression with continuous features: SquaredError") {
GradientBoostedTreesSuite.testCombinations.foreach {
case (numIterations, learningRate, subsamplingRate) =>
- GradientBoostedTreesSuite.randomSeeds.foreach { randomSeed =>
- val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
-
- val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
- val boostingStrategy =
- new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
-
- val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
-
- assert(gbt.trees.size === numIterations)
- try {
- EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
- } catch {
- case e: java.lang.AssertionError =>
- println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
- s" subsamplingRate=$subsamplingRate")
- throw e
- }
+ val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+ val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty, subsamplingRate = subsamplingRate)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, SquaredError, numIterations, learningRate)
- val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
- val dt = DecisionTree.train(remappedInput, treeStrategy)
+ val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
- // Make sure trees are the same.
- assert(gbt.trees.head.toString == dt.toString)
+ assert(gbt.trees.size === numIterations)
+ try {
+ EnsembleTestHelper.validateRegressor(gbt, GradientBoostedTreesSuite.data, 0.06)
+ } catch {
+ case e: java.lang.AssertionError =>
+ println(s"FAILED for numIterations=$numIterations, learningRate=$learningRate," +
+ s" subsamplingRate=$subsamplingRate")
+ throw e
}
+
+ val remappedInput = rdd.map(x => new LabeledPoint((x.label * 2) - 1, x.features))
+ val dt = DecisionTree.train(remappedInput, treeStrategy)
+
+ // Make sure trees are the same.
+ assert(gbt.trees.head.toString == dt.toString)
}
}
@@ -133,14 +133,73 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext {
BoostingStrategy.defaultParams(algo)
}
}
+
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(Regression)).toArray
+ val treeWeights = Array(0.1, 0.3, 1.1)
+
+ Array(Classification, Regression).foreach { algo =>
+ val model = new GradientBoostedTreesModel(algo, trees, treeWeights)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = GradientBoostedTreesModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ assert(model.treeWeights === sameModel.treeWeights)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+
+ test("runWithValidation stops early and performs better on a validation dataset") {
+ // Set numIterations large enough so that it stops early.
+ val numIterations = 20
+ val trainRdd = sc.parallelize(GradientBoostedTreesSuite.trainData, 2)
+ val validateRdd = sc.parallelize(GradientBoostedTreesSuite.validateData, 2)
+
+ val algos = Array(Regression, Regression, Classification)
+ val losses = Array(SquaredError, AbsoluteError, LogLoss)
+ algos.zip(losses).foreach { case (algo, loss) =>
+ val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+ categoricalFeaturesInfo = Map.empty)
+ val boostingStrategy =
+ new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+ val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+ .runWithValidation(trainRdd, validateRdd)
+ assert(gbtValidate.numTrees !== numIterations)
+
+ // Test that it performs better on the validation dataset.
+ val gbt = GradientBoostedTrees.train(trainRdd, boostingStrategy)
+ val (errorWithoutValidation, errorWithValidation) = {
+ if (algo == Classification) {
+ val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+ (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+ } else {
+ (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
+ }
+ }
+ assert(errorWithValidation <= errorWithoutValidation)
+ }
+ }
+
}
-object GradientBoostedTreesSuite {
+private object GradientBoostedTreesSuite {
// Combinations for estimators, learning rates and subsamplingRate
val testCombinations = Array((10, 1.0, 1.0), (10, 0.1, 1.0), (10, 0.5, 0.75), (10, 0.1, 0.75))
- val randomSeeds = Array(681283, 4398)
-
val data = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 10, 100)
+ val trainData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 120)
+ val validateData = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 20, 80)
}
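
runWithValidation grows the ensemble one tree at a time but stops as soon as the error on the validation RDD stops improving by more than validationTol, which is what the new test asserts via gbtValidate.numTrees !== numIterations and the error comparison. A minimal Python sketch of that control flow, with hypothetical train_one_tree/evaluate helpers standing in for the MLlib internals:

    def run_with_validation(train_one_tree, evaluate, train_data, validation_data,
                            num_iterations, validation_tol):
        """Boosting with early stopping on a validation set (illustration only)."""
        ensemble = []
        best_error = float("inf")
        for _ in range(num_iterations):
            ensemble.append(train_one_tree(train_data, ensemble))
            error = evaluate(ensemble, validation_data)
            # Stop once validation error no longer improves by more than the tolerance.
            if best_error - error <= validation_tol:
                ensemble.pop()  # keep the best ensemble seen so far
                break
            best_error = error
        return ensemble
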
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 55e963977b54f..ee3bc98486862 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -27,8 +27,10 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata
import org.apache.spark.mllib.tree.impurity.{Gini, Variance}
-import org.apache.spark.mllib.tree.model.Node
+import org.apache.spark.mllib.tree.model.{Node, RandomForestModel}
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
+
/**
* Test suite for [[RandomForest]].
@@ -212,6 +214,26 @@ class RandomForestSuite extends FunSuite with MLlibTestSparkContext {
assert(rf1.toDebugString != rf2.toDebugString)
}
-}
-
+ test("model save/load") {
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ Array(Classification, Regression).foreach { algo =>
+ val trees = Range(0, 3).map(_ => DecisionTreeSuite.createModel(algo)).toArray
+ val model = new RandomForestModel(algo, trees)
+
+ // Save model, load it back, and compare.
+ try {
+ model.save(sc, path)
+ val sameModel = RandomForestModel.load(sc, path)
+ assert(model.algo == sameModel.algo)
+ model.trees.zip(sameModel.trees).foreach { case (treeA, treeB) =>
+ DecisionTreeSuite.checkEqual(treeA, treeB)
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+ }
+}
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index acec8f18f2b5c..39b99f54f6dbc 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -33,6 +33,8 @@
http://spark.apache.org/network-yarn
+
+ provided
@@ -47,7 +49,6 @@
org.apache.hadoophadoop-client
- provided
diff --git a/pom.xml b/pom.xml
index f6f176d2004b7..bb355bf735bee 100644
--- a/pom.xml
+++ b/pom.xml
@@ -342,7 +342,7 @@
-
+
@@ -395,7 +395,7 @@
provided
-
+
org.apache.commonscommons-lang3
@@ -404,7 +404,7 @@
commons-codeccommons-codec
- 1.5
+ 1.10org.apache.commons
@@ -619,19 +619,6 @@
2.2.1test
-
- org.easymock
- easymockclassextension
- 3.1
- test
-
-
-
- asm
- asm
- 3.3.1
- test
- org.mockitomockito-all
@@ -1096,6 +1083,12 @@
scala-maven-plugin3.2.0
+
+ eclipse-add-source
+
+ add-source
+
+ scala-compile-firstprocess-resources
@@ -1178,13 +1171,19 @@
${project.build.directory}/surefire-reports-Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m
+
+
+ ${test_classpath}
+ true${session.executionRootDirectory}1falsefalse
- ${test_classpath}truefalse
@@ -1572,7 +1571,7 @@
2.3.02.5.0
- 0.9.2
+ 0.9.30.98.7-hadoop23.1.1hadoop2
@@ -1585,7 +1584,7 @@
2.4.02.5.0
- 0.9.2
+ 0.9.30.98.7-hadoop23.1.1hadoop2
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4065a562a1a18..ee6229aa6bbe1 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -148,6 +148,11 @@ object MimaExcludes {
ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.mllib.linalg.VectorUDT"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.serialize"),
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.linalg.VectorUDT.sqlType")
+ ) ++ Seq(
+ // SPARK-4682
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.RealClock"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.Clock"),
+ ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.TestClock")
)
case v if v.startsWith("1.2") =>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 95f8dfa3d270f..e4b1b96527fbd 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -361,9 +361,16 @@ object Unidoc {
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn),
+ inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn),
+
+ // Skip actual catalyst, but include the subproject.
+ // Catalyst is not public API and contains quasiquotes which break scaladoc.
+ unidocAllSources in (ScalaUnidoc, unidoc) := {
+ (unidocAllSources in (ScalaUnidoc, unidoc)).value
+ .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst")))
+ },
// Skip class names containing $ and some internal packages in Javadocs
unidocAllSources in (JavaUnidoc, unidoc) := {
@@ -376,6 +383,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("executor")))
.map(_.filterNot(_.getCanonicalPath.contains("python")))
.map(_.filterNot(_.getCanonicalPath.contains("collection")))
+ .map(_.filterNot(_.getCanonicalPath.contains("sql/catalyst")))
},
// Javadoc options: create a window title, and group key packages on index page
@@ -411,6 +419,10 @@ object TestSettings {
lazy val settings = Seq (
// Fork new JVMs for tests and set Java options for those
fork := true,
+ // Setting SPARK_DIST_CLASSPATH is a simple way to make sure any child processes
+ // launched by the tests have access to the correct test-time classpath.
+ envVars in Test += ("SPARK_DIST_CLASSPATH" ->
+ (fullClasspath in Test).value.files.map(_.getAbsolutePath).mkString(":").stripSuffix(":")),
javaOptions in Test += "-Dspark.test.home=" + sparkHome,
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
@@ -423,10 +435,6 @@ object TestSettings {
javaOptions in Test += "-ea",
javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
.split(" ").toSeq,
- // This places test scope jars on the classpath of executors during tests.
- javaOptions in Test +=
- "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files.
- map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
javaOptions += "-Xmx3g",
// Show full stack trace and duration in test cases.
testOptions in Test += Tests.Argument("-oDF"),
diff --git a/python/docs/conf.py b/python/docs/conf.py
index b00dce95d65b4..163987dd8e5fa 100644
--- a/python/docs/conf.py
+++ b/python/docs/conf.py
@@ -48,16 +48,16 @@
# General information about the project.
project = u'PySpark'
-copyright = u'2014, Author'
+copyright = u''
# The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the
# built documents.
#
# The short X.Y version.
-version = '1.3-SNAPSHOT'
+version = 'master'
# The full version, including alpha/beta/rc tags.
-release = '1.3-SNAPSHOT'
+release = os.environ.get('RELEASE_VERSION', version)
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
@@ -97,6 +97,10 @@
# If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False
+# -- Options for autodoc --------------------------------------------------
+
+# Look at the first line of the docstring for function and method signatures.
+autodoc_docstring_signature = True
# -- Options for HTML output ----------------------------------------------
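
autodoc_docstring_signature matters here because the new keyword_only-decorated constructors and setParams methods hide their real signatures behind *args/**kwargs; repeating the signature as the first docstring line lets Sphinx document the keyword arguments anyway. A small illustration with a hypothetical function:

    def example(a=1, b=2):
        """
        example(a=1, b=2)

        With autodoc_docstring_signature = True, Sphinx can take the first
        docstring line above as the documented signature, which is useful when
        the callable's real signature is obscured by a decorator.
        """
        return a + b
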
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index 21f66ca344a3c..b706c5e376ef4 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -7,7 +7,7 @@ pyspark.mllib.classification module
.. automodule:: pyspark.mllib.classification
:members:
:undoc-members:
- :show-inheritance:
+ :inherited-members:
pyspark.mllib.clustering module
-------------------------------
@@ -15,7 +15,6 @@ pyspark.mllib.clustering module
.. automodule:: pyspark.mllib.clustering
:members:
:undoc-members:
- :show-inheritance:
pyspark.mllib.feature module
-------------------------------
@@ -39,7 +38,6 @@ pyspark.mllib.random module
.. automodule:: pyspark.mllib.random
:members:
:undoc-members:
- :show-inheritance:
pyspark.mllib.recommendation module
-----------------------------------
@@ -47,7 +45,6 @@ pyspark.mllib.recommendation module
.. automodule:: pyspark.mllib.recommendation
:members:
:undoc-members:
- :show-inheritance:
pyspark.mllib.regression module
-------------------------------
@@ -55,7 +52,7 @@ pyspark.mllib.regression module
.. automodule:: pyspark.mllib.regression
:members:
:undoc-members:
- :show-inheritance:
+ :inherited-members:
pyspark.mllib.stat module
-------------------------
@@ -63,7 +60,6 @@ pyspark.mllib.stat module
.. automodule:: pyspark.mllib.stat
:members:
:undoc-members:
- :show-inheritance:
pyspark.mllib.tree module
-------------------------
@@ -71,7 +67,7 @@ pyspark.mllib.tree module
.. automodule:: pyspark.mllib.tree
:members:
:undoc-members:
- :show-inheritance:
+ :inherited-members:
pyspark.mllib.util module
-------------------------
@@ -79,4 +75,3 @@ pyspark.mllib.util module
.. automodule:: pyspark.mllib.util
:members:
:undoc-members:
- :show-inheritance:
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index 80c6f02a9df41..6259379ed05b7 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -7,7 +7,6 @@ Module Context
.. automodule:: pyspark.sql
:members:
:undoc-members:
- :show-inheritance:
pyspark.sql.types module
@@ -15,4 +14,10 @@ pyspark.sql.types module
.. automodule:: pyspark.sql.types
:members:
:undoc-members:
- :show-inheritance:
+
+
+pyspark.sql.functions module
+----------------------------
+.. automodule:: pyspark.sql.functions
+ :members:
+ :undoc-members:
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index d3efcdf221d82..5f70ac6ed8fe6 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -22,17 +22,17 @@
- :class:`SparkContext`:
Main entry point for Spark functionality.
- - L{RDD}
+ - :class:`RDD`:
A Resilient Distributed Dataset (RDD), the basic abstraction in Spark.
- - L{Broadcast}
+ - :class:`Broadcast`:
A broadcast variable that gets reused across tasks.
- - L{Accumulator}
+ - :class:`Accumulator`:
An "add-only" shared variable that tasks can only add values to.
- - L{SparkConf}
+ - :class:`SparkConf`:
For configuring Spark.
- - L{SparkFiles}
+ - :class:`SparkFiles`:
Access files shipped with jobs.
- - L{StorageLevel}
+ - :class:`StorageLevel`:
Finer-grained cache persistence levels.
"""
@@ -45,6 +45,7 @@
from pyspark.accumulators import Accumulator, AccumulatorParam
from pyspark.broadcast import Broadcast
from pyspark.serializers import MarshalSerializer, PickleSerializer
+from pyspark.status import *
from pyspark.profiler import Profiler, BasicProfiler
# for back compatibility
@@ -53,5 +54,5 @@
__all__ = [
"SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast",
"Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer",
- "Profiler", "BasicProfiler",
+ "StatusTracker", "SparkJobInfo", "SparkStageInfo", "Profiler", "BasicProfiler",
]
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index bf1f61c8504ed..6011caf9f1c5a 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -32,6 +32,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
+from pyspark.status import StatusTracker
from pyspark.profiler import ProfilerCollector, BasicProfiler
from py4j.java_collections import ListConverter
@@ -64,6 +65,8 @@ class SparkContext(object):
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
+ PACKAGE_EXTENSIONS = ('.zip', '.egg', '.jar')
+
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
gateway=None, jsc=None, profiler_cls=BasicProfiler):
@@ -185,7 +188,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
- if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
+ if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
self._python_includes.append(filename)
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
@@ -705,7 +708,7 @@ def addPyFile(self, path):
self.addFile(path)
(dirname, filename) = os.path.split(path) # dirname may be directory or HDFS/S3 prefix
- if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
+ if filename[-4:].lower() in self.PACKAGE_EXTENSIONS:
self._python_includes.append(filename)
# for tests in local mode
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
@@ -808,6 +811,12 @@ def cancelAllJobs(self):
"""
self._jsc.sc().cancelAllJobs()
+ def statusTracker(self):
+ """
+ Return a :class:`StatusTracker` object.
+ """
+ return StatusTracker(self._jsc.statusTracker())
+
def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
"""
Executes the given partitionFunc on the specified set of partitions,
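
statusTracker() wraps the JVM-side status tracker so job and stage progress can be read from Python. A hedged usage sketch, assuming the method names added in pyspark.status by this patch (getJobIdsForGroup, getJobInfo, getStageInfo) and their info tuples:

    from pyspark import SparkContext

    sc = SparkContext(appName="status-sketch")
    sc.parallelize(range(1000), 8).map(lambda x: x * x).count()  # run one job

    tracker = sc.statusTracker()
    for job_id in tracker.getJobIdsForGroup():
        job = tracker.getJobInfo(job_id)
        print(job_id, job.status)
        for stage_id in job.stageIds:
            stage = tracker.getStageInfo(stage_id)
            if stage is not None:
                print("  stage %d: %d/%d tasks done"
                      % (stage_id, stage.numCompletedTasks, stage.numTasks))
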
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index a0a028446d5fd..936857e75c7e9 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -17,19 +17,20 @@
import atexit
import os
-import sys
+import select
import signal
import shlex
+import socket
import platform
from subprocess import Popen, PIPE
-from threading import Thread
from py4j.java_gateway import java_import, JavaGateway, GatewayClient
+from pyspark.serializers import read_int
+
def launch_gateway():
SPARK_HOME = os.environ["SPARK_HOME"]
- gateway_port = -1
if "PYSPARK_GATEWAY_PORT" in os.environ:
gateway_port = int(os.environ["PYSPARK_GATEWAY_PORT"])
else:
@@ -41,36 +42,42 @@ def launch_gateway():
submit_args = submit_args if submit_args is not None else ""
submit_args = shlex.split(submit_args)
command = [os.path.join(SPARK_HOME, script)] + submit_args + ["pyspark-shell"]
+
+ # Start a socket that will be used by PythonGatewayServer to communicate its port to us
+ callback_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ callback_socket.bind(('127.0.0.1', 0))
+ callback_socket.listen(1)
+ callback_host, callback_port = callback_socket.getsockname()
+ env = dict(os.environ)
+ env['_PYSPARK_DRIVER_CALLBACK_HOST'] = callback_host
+ env['_PYSPARK_DRIVER_CALLBACK_PORT'] = str(callback_port)
+
+ # Launch the Java gateway.
+ # We open a pipe to stdin so that the Java gateway can die when the pipe is broken
if not on_windows:
# Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN)
- env = dict(os.environ)
env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits
- proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env)
+ proc = Popen(command, stdin=PIPE, preexec_fn=preexec_func, env=env)
else:
# preexec_fn not supported on Windows
- proc = Popen(command, stdout=PIPE, stdin=PIPE)
+ proc = Popen(command, stdin=PIPE, env=env)
- try:
- # Determine which ephemeral port the server started on:
- gateway_port = proc.stdout.readline()
- gateway_port = int(gateway_port)
- except ValueError:
- # Grab the remaining lines of stdout
- (stdout, _) = proc.communicate()
- exit_code = proc.poll()
- error_msg = "Launching GatewayServer failed"
- error_msg += " with exit code %d!\n" % exit_code if exit_code else "!\n"
- error_msg += "Warning: Expected GatewayServer to output a port, but found "
- if gateway_port == "" and stdout == "":
- error_msg += "no output.\n"
- else:
- error_msg += "the following:\n\n"
- error_msg += "--------------------------------------------------------------\n"
- error_msg += gateway_port + stdout
- error_msg += "--------------------------------------------------------------\n"
- raise Exception(error_msg)
+ gateway_port = None
+ # We use select() here in order to avoid blocking indefinitely if the subprocess dies
+ # before connecting
+ while gateway_port is None and proc.poll() is None:
+ timeout = 1 # (seconds)
+ readable, _, _ = select.select([callback_socket], [], [], timeout)
+ if callback_socket in readable:
+ gateway_connection = callback_socket.accept()[0]
+ # Determine which ephemeral port the server started on:
+ gateway_port = read_int(gateway_connection.makefile())
+ gateway_connection.close()
+ callback_socket.close()
+ if gateway_port is None:
+ raise Exception("Java gateway process exited before sending the driver its port number")
# In Windows, ensure the Java child processes do not linger after Python has exited.
# In UNIX-based systems, the child process can kill itself on broken pipe (i.e. when
@@ -88,21 +95,6 @@ def killChild():
Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
atexit.register(killChild)
- # Create a thread to echo output from the GatewayServer, which is required
- # for Java log output to show up:
- class EchoOutputThread(Thread):
-
- def __init__(self, stream):
- Thread.__init__(self)
- self.daemon = True
- self.stream = stream
-
- def run(self):
- while True:
- line = self.stream.readline()
- sys.stderr.write(line)
- EchoOutputThread(proc.stdout).start()
-
# Connect to the gateway
gateway = JavaGateway(GatewayClient(port=gateway_port), auto_convert=False)
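
The select()/accept() loop above is one half of a small handshake: the driver listens on an ephemeral port, publishes it through _PYSPARK_DRIVER_CALLBACK_HOST and _PYSPARK_DRIVER_CALLBACK_PORT, and the launched JVM connects back and writes the Py4J gateway port as a single int that read_int can decode (assumed here to be a 4-byte big-endian value). For illustration only, here is what the child's side of that protocol would look like in Python; in reality PythonGatewayServer on the JVM does this:

    import os
    import socket
    import struct

    def report_gateway_port(port):
        """Illustration of the callback protocol; the real sender is the JVM."""
        host = os.environ["_PYSPARK_DRIVER_CALLBACK_HOST"]
        callback_port = int(os.environ["_PYSPARK_DRIVER_CALLBACK_PORT"])
        conn = socket.create_connection((host, callback_port))
        try:
            conn.sendall(struct.pack("!i", port))  # matches read_int() on the driver
        finally:
            conn.close()
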
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index b4a844713745a..efc1ef9396412 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -35,8 +35,8 @@
def _do_python_join(rdd, other, numPartitions, dispatch):
- vs = rdd.map(lambda (k, v): (k, (1, v)))
- ws = other.map(lambda (k, v): (k, (2, v)))
+ vs = rdd.mapValues(lambda v: (1, v))
+ ws = other.mapValues(lambda v: (2, v))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))
@@ -98,8 +98,8 @@ def dispatch(seq):
def python_cogroup(rdds, numPartitions):
def make_mapper(i):
- return lambda (k, v): (k, (i, v))
- vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+ return lambda v: (i, v)
+ vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)
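
Besides dropping tuple-unpacking lambdas (which Python 3 removes), using mapValues instead of map tells Spark the keys are untouched, so the tagged RDDs keep any existing partitioner and an already co-partitioned join can avoid an extra shuffle before groupByKey. A tiny sketch of the equivalence, to be run with an active SparkContext sc:

    pairs = sc.parallelize([("a", 1), ("b", 2)])
    # Old style relied on tuple unpacking in the lambda:
    #   pairs.map(lambda (k, v): (k, (1, v)))
    # Equivalent, Python 3 friendly, and partitioning-preserving:
    tagged = pairs.mapValues(lambda v: (1, v))
    print(tagged.collect())  # [('a', (1, 1)), ('b', (1, 2))]
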
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 6bd2aa8e47837..4ff7463498cce 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -15,10 +15,11 @@
# limitations under the License.
#
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel
from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\
HasRegParam
+from pyspark.mllib.common import inherit_doc
__all__ = ['LogisticRegression', 'LogisticRegressionModel']
@@ -32,22 +33,46 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
>>> from pyspark.sql import Row
>>> from pyspark.mllib.linalg import Vectors
- >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \
- Row(label=1.0, features=Vectors.dense(1.0)), \
- Row(label=0.0, features=Vectors.sparse(1, [], []))]))
- >>> lr = LogisticRegression() \
- .setMaxIter(5) \
- .setRegParam(0.01)
- >>> model = lr.fit(dataset)
- >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))]))
+ >>> df = sc.parallelize([
+ ... Row(label=1.0, features=Vectors.dense(1.0)),
+ ... Row(label=0.0, features=Vectors.sparse(1, [], []))]).toDF()
+ >>> lr = LogisticRegression(maxIter=5, regParam=0.01)
+ >>> model = lr.fit(df)
+ >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF()
>>> print model.transform(test0).head().prediction
0.0
- >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]))
+ >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF()
>>> print model.transform(test1).head().prediction
1.0
+ >>> lr.setParams("vector")
+ Traceback (most recent call last):
+ ...
+ TypeError: Method setParams forces keyword arguments.
"""
_java_class = "org.apache.spark.ml.classification.LogisticRegression"
+ @keyword_only
+ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxIter=100, regParam=0.1):
+ """
+ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, regParam=0.1)
+ """
+ super(LogisticRegression, self).__init__()
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
+ maxIter=100, regParam=0.1):
+ """
+ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
+ maxIter=100, regParam=0.1)
+ Sets params for logistic regression.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set_params(**kwargs)
+
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index e088acd0ca82d..433b4fb5d22bf 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -16,8 +16,9 @@
#
from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
+from pyspark.mllib.common import inherit_doc
__all__ = ['Tokenizer', 'HashingTF']
@@ -29,18 +30,45 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
splits it by white spaces.
>>> from pyspark.sql import Row
- >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")]))
- >>> tokenizer = Tokenizer() \
- .setInputCol("text") \
- .setOutputCol("words")
- >>> print tokenizer.transform(dataset).head()
+ >>> df = sc.parallelize([Row(text="a b c")]).toDF()
+ >>> tokenizer = Tokenizer(inputCol="text", outputCol="words")
+ >>> print tokenizer.transform(df).head()
Row(text=u'a b c', words=[u'a', u'b', u'c'])
- >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head()
+ >>> # Change a parameter.
+ >>> print tokenizer.setParams(outputCol="tokens").transform(df).head()
Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+ >>> # Temporarily modify a parameter.
+ >>> print tokenizer.transform(df, {tokenizer.outputCol: "words"}).head()
+ Row(text=u'a b c', words=[u'a', u'b', u'c'])
+ >>> print tokenizer.transform(df).head()
+ Row(text=u'a b c', tokens=[u'a', u'b', u'c'])
+ >>> # Must use keyword arguments to specify params.
+ >>> tokenizer.setParams("text")
+ Traceback (most recent call last):
+ ...
+ TypeError: Method setParams forces keyword arguments.
"""
_java_class = "org.apache.spark.ml.feature.Tokenizer"
+ @keyword_only
+ def __init__(self, inputCol="input", outputCol="output"):
+ """
+ __init__(self, inputCol="input", outputCol="output")
+ """
+ super(Tokenizer, self).__init__()
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, inputCol="input", outputCol="output"):
+ """
+ setParams(self, inputCol="input", outputCol="output")
+ Sets params for this Tokenizer.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set_params(**kwargs)
+
@inherit_doc
class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
@@ -49,20 +77,37 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
hashing trick.
>>> from pyspark.sql import Row
- >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])]))
- >>> hashingTF = HashingTF() \
- .setNumFeatures(10) \
- .setInputCol("words") \
- .setOutputCol("features")
- >>> print hashingTF.transform(dataset).head().features
+ >>> df = sc.parallelize([Row(words=["a", "b", "c"])]).toDF()
+ >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ >>> print hashingTF.transform(df).head().features
+ (10,[7,8,9],[1.0,1.0,1.0])
+ >>> print hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
(10,[7,8,9],[1.0,1.0,1.0])
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
- >>> print hashingTF.transform(dataset, params).head().vector
+ >>> print hashingTF.transform(df, params).head().vector
(5,[2,3,4],[1.0,1.0,1.0])
"""
_java_class = "org.apache.spark.ml.feature.HashingTF"
+ @keyword_only
+ def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+ """
+ __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+ """
+ super(HashingTF, self).__init__()
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+ """
+ setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+ Sets params for this HashingTF.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set_params(**kwargs)
+
if __name__ == "__main__":
import doctest
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 5566792cead48..e3a53dd780c4c 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -80,3 +80,11 @@ def _dummy():
dummy = Params()
dummy.uid = "undefined"
return dummy
+
+ def _set_params(self, **kwargs):
+ """
+ Sets params.
+ """
+ for param, value in kwargs.iteritems():
+ self.paramMap[getattr(self, param)] = value
+ return self
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 2d239f8c802a0..5233c5801e2e6 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -18,7 +18,8 @@
from abc import ABCMeta, abstractmethod
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import inherit_doc
+from pyspark.ml.util import keyword_only
+from pyspark.mllib.common import inherit_doc
__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
@@ -89,10 +90,16 @@ class Pipeline(Estimator):
identity transformer.
"""
- def __init__(self):
+ @keyword_only
+ def __init__(self, stages=[]):
+ """
+ __init__(self, stages=[])
+ """
super(Pipeline, self).__init__()
#: Param for pipeline stages.
self.stages = Param(self, "stages", "pipeline stages")
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
def setStages(self, value):
"""
@@ -110,6 +117,15 @@ def getStages(self):
if self.stages in self.paramMap:
return self.paramMap[self.stages]
+ @keyword_only
+ def setParams(self, stages=[]):
+ """
+ setParams(self, stages=[])
+ Sets params for Pipeline.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set_params(**kwargs)
+
def fit(self, dataset, params={}):
paramMap = self._merge_params(params)
stages = paramMap[self.stages]
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index b1caa84b6306a..6f7f39c40eb5a 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -15,21 +15,22 @@
# limitations under the License.
#
+from functools import wraps
import uuid
-def inherit_doc(cls):
- for name, func in vars(cls).items():
- # only inherit docstring for public functions
- if name.startswith("_"):
- continue
- if not func.__doc__:
- for parent in cls.__bases__:
- parent_func = getattr(parent, name, None)
- if parent_func and getattr(parent_func, "__doc__", None):
- func.__doc__ = parent_func.__doc__
- break
- return cls
+def keyword_only(func):
+ """
+ A decorator that forces keyword arguments in the wrapped method
+ and saves actual input keyword arguments in `_input_kwargs`.
+ """
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ if len(args) > 1:
+ raise TypeError("Method %s forces keyword arguments." % func.__name__)
+ wrapper._input_kwargs = kwargs
+ return func(*args, **kwargs)
+ return wrapper
class Identifiable(object):
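
keyword_only rejects positional arguments and stashes whatever keywords were actually passed on the wrapper as _input_kwargs; __init__ and setParams in the new ML wrappers read that attribute to forward only user-supplied parameters. A self-contained illustration with a hypothetical class (the decorator body is copied from the patch so the snippet runs on its own):

    from functools import wraps

    def keyword_only(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            if len(args) > 1:
                raise TypeError("Method %s forces keyword arguments." % func.__name__)
            wrapper._input_kwargs = kwargs
            return func(*args, **kwargs)
        return wrapper

    class Example(object):
        @keyword_only
        def setParams(self, alpha=1.0, beta=2.0):
            # Only the keywords the caller actually supplied end up here.
            self.params = self.setParams._input_kwargs
            return self

    print(Example().setParams(alpha=0.5).params)   # {'alpha': 0.5}
    # Example().setParams(0.5) raises TypeError: forces keyword arguments.
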
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 9e12ddc3d9b8f..4bae96f678388 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -21,7 +21,7 @@
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
from pyspark.ml.pipeline import Estimator, Transformer
-from pyspark.ml.util import inherit_doc
+from pyspark.mllib.common import inherit_doc
def _jvm():
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index c3217620e3c4e..6449800d9c120 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -19,7 +19,7 @@
Python bindings for MLlib.
"""
-# MLlib currently needs and NumPy 1.4+, so complain if lower
+# MLlib currently needs NumPy 1.4+, so complain if lower
import numpy
if numpy.version.version < '1.4':
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index f6b97abb1723c..949db5705abd7 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -152,7 +152,7 @@ def predictSoft(self, x):
class GaussianMixture(object):
"""
- Estimate model parameters with the expectation-maximization algorithm.
+ Learning algorithm for Gaussian Mixtures using the expectation-maximization algorithm.
:param data: RDD of data points
:param k: Number of components
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 3c5ee66cd8b64..621591c26b77f 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -134,3 +134,20 @@ def __del__(self):
def call(self, name, *a):
"""Call method of java_model"""
return callJavaFunc(self._sc, getattr(self._java_model, name), *a)
+
+
+def inherit_doc(cls):
+ """
+ A decorator that makes a class inherit documentation from its parents.
+ """
+ for name, func in vars(cls).items():
+ # only inherit docstring for public functions
+ if name.startswith("_"):
+ continue
+ if not func.__doc__:
+ for parent in cls.__bases__:
+ parent_func = getattr(parent, name, None)
+ if parent_func and getattr(parent_func, "__doc__", None):
+ func.__doc__ = parent_func.__doc__
+ break
+ return cls
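
inherit_doc fills in missing docstrings on public methods from the first parent that has one, which is why the regression and classification model classes elsewhere in this patch are annotated with @inherit_doc instead of repeating their parents' documentation. A self-contained illustration with hypothetical classes (the decorator body mirrors the one added above):

    def inherit_doc(cls):
        for name, func in vars(cls).items():
            if name.startswith("_"):
                continue
            if not func.__doc__:
                for parent in cls.__bases__:
                    parent_func = getattr(parent, name, None)
                    if parent_func and getattr(parent_func, "__doc__", None):
                        func.__doc__ = parent_func.__doc__
                        break
        return cls

    class Base(object):
        def predict(self, x):
            """Predict the label for a single feature vector x."""
            raise NotImplementedError

    @inherit_doc
    class Child(Base):
        def predict(self, x):   # no docstring of its own
            return 0.0

    print(Child.predict.__doc__)  # Predict the label for a single feature vector x.
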
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 10df6288065b8..0ffe092a07365 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -58,7 +58,8 @@ class Normalizer(VectorTransformer):
For any 1 <= `p` < float('inf'), normalizes samples using
sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm.
- For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization.
+ For `p` = float('inf'), max(abs(vector)) will be used as norm for
+ normalization.
>>> v = Vectors.dense(range(3))
>>> nor = Normalizer(1)
@@ -120,9 +121,14 @@ def transform(self, vector):
"""
Applies standardization transformation on a vector.
+ Note: In Python, transform cannot currently be used within
+ an RDD transformation or action.
+ Call transform directly on the RDD instead.
+
:param vector: Vector or RDD of Vector to be standardized.
- :return: Standardized vector. If the variance of a column is zero,
- it will return default `0.0` for the column with zero variance.
+ :return: Standardized vector. If the variance of a column is
+ zero, it will return default `0.0` for the column with
+ zero variance.
"""
return JavaVectorTransformer.transform(self, vector)
@@ -148,9 +154,10 @@ def __init__(self, withMean=False, withStd=True):
"""
:param withMean: False by default. Centers the data with mean
before scaling. It will build a dense output, so this
- does not work on sparse input and will raise an exception.
- :param withStd: True by default. Scales the data to unit standard
- deviation.
+ does not work on sparse input and will raise an
+ exception.
+ :param withStd: True by default. Scales the data to unit
+ standard deviation.
"""
if not (withMean or withStd):
warnings.warn("Both withMean and withStd are false. The model does nothing.")
@@ -159,10 +166,11 @@ def __init__(self, withMean=False, withStd=True):
def fit(self, dataset):
"""
- Computes the mean and variance and stores as a model to be used for later scaling.
+ Computes the mean and variance and stores as a model to be used
+ for later scaling.
- :param data: The data used to compute the mean and variance to build
- the transformation model.
+ :param data: The data used to compute the mean and variance
+ to build the transformation model.
:return: a StandardScalarModel
"""
dataset = dataset.map(_convert_to_vector)
@@ -174,7 +182,8 @@ class HashingTF(object):
"""
.. note:: Experimental
- Maps a sequence of terms to their term frequencies using the hashing trick.
+ Maps a sequence of terms to their term frequencies using the hashing
+ trick.
Note: the terms must be hashable (can not be dict/set/list...).
@@ -195,8 +204,9 @@ def indexOf(self, term):
def transform(self, document):
"""
- Transforms the input document (list of terms) to term frequency vectors,
- or transform the RDD of document to RDD of term frequency vectors.
+ Transforms the input document (list of terms) to term frequency
+ vectors, or transform the RDD of document to RDD of term
+ frequency vectors.
"""
if isinstance(document, RDD):
return document.map(self.transform)
@@ -220,7 +230,12 @@ def transform(self, x):
the terms which occur in fewer than `minDocFreq`
documents will have an entry of 0.
- :param x: an RDD of term frequency vectors or a term frequency vector
+ Note: In Python, transform cannot currently be used within
+ an RDD transformation or action.
+ Call transform directly on the RDD instead.
+
+ :param x: an RDD of term frequency vectors or a term frequency
+ vector
:return: an RDD of TF-IDF vectors or a TF-IDF vector
"""
if isinstance(x, RDD):
@@ -241,9 +256,9 @@ class IDF(object):
of documents that contain term `t`.
This implementation supports filtering out terms which do not appear
- in a minimum number of documents (controlled by the variable `minDocFreq`).
- For terms that are not in at least `minDocFreq` documents, the IDF is
- found as 0, resulting in TF-IDFs of 0.
+ in a minimum number of documents (controlled by the variable
+ `minDocFreq`). For terms that are not in at least `minDocFreq`
+ documents, the IDF is found as 0, resulting in TF-IDFs of 0.
>>> n = 4
>>> freqs = [Vectors.sparse(n, (1, 3), (1.0, 2.0)),
@@ -325,15 +340,16 @@ class Word2Vec(object):
The vector representation can be used as features in
natural language processing and machine learning algorithms.
- We used skip-gram model in our implementation and hierarchical softmax
- method to train the model. The variable names in the implementation
- matches the original C implementation.
+ We used the skip-gram model in our implementation and the
+ hierarchical softmax method to train the model. The variable names
+ in the implementation match the original C implementation.
- For original C implementation, see https://code.google.com/p/word2vec/
+ For original C implementation,
+ see https://code.google.com/p/word2vec/
For research papers, see
Efficient Estimation of Word Representations in Vector Space
- and
- Distributed Representations of Words and Phrases and their Compositionality.
+ and Distributed Representations of Words and Phrases and their
+ Compositionality.
>>> sentence = "a b " * 100 + "a c " * 10
>>> localDoc = [sentence, sentence]
@@ -374,15 +390,16 @@ def setLearningRate(self, learningRate):
def setNumPartitions(self, numPartitions):
"""
- Sets number of partitions (default: 1). Use a small number for accuracy.
+ Sets number of partitions (default: 1). Use a small number for
+ accuracy.
"""
self.numPartitions = numPartitions
return self
def setNumIterations(self, numIterations):
"""
- Sets number of iterations (default: 1), which should be smaller than or equal to number of
- partitions.
+ Sets number of iterations (default: 1), which should be smaller
+ than or equal to number of partitions.
"""
self.numIterations = numIterations
return self
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 210060140fd91..66617abb85670 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,7 +18,7 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc
+from pyspark.mllib.common import callMLlibFunc, inherit_doc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -31,8 +31,11 @@ class LabeledPoint(object):
The features and labels of a data point.
:param label: Label for this data point.
- :param features: Vector of features for this point (NumPy array, list,
- pyspark.mllib.linalg.SparseVector, or scipy.sparse column matrix)
+ :param features: Vector of features for this point (NumPy array,
+ list, pyspark.mllib.linalg.SparseVector, or scipy.sparse
+ column matrix)
+
+ Note: 'label' and 'features' are accessible as class attributes.
"""
def __init__(self, label, features):
@@ -69,6 +72,7 @@ def __repr__(self):
return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept)
+@inherit_doc
class LinearRegressionModelBase(LinearModel):
"""A linear regression model.
@@ -89,6 +93,7 @@ def predict(self, x):
return self.weights.dot(x) + self.intercept
+@inherit_doc
class LinearRegressionModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit.
@@ -162,7 +167,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
- are activated or not).
+ are activated or not). (default: False)
"""
def train(rdd, i):
return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
@@ -172,6 +177,7 @@ def train(rdd, i):
return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
+@inherit_doc
class LassoModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit with an
@@ -218,6 +224,7 @@ def train(rdd, i):
return _regression_train_wrapper(train, LassoModel, data, initialWeights)
+@inherit_doc
class RidgeRegressionModel(LinearRegressionModelBase):
"""A linear regression model derived from a least-squares fit with an
diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py
index b686d955a0080..e3e128513e0d7 100644
--- a/python/pyspark/mllib/stat/__init__.py
+++ b/python/pyspark/mllib/stat/__init__.py
@@ -21,5 +21,7 @@
from pyspark.mllib.stat._statistics import *
from pyspark.mllib.stat.distribution import MultivariateGaussian
+from pyspark.mllib.stat.test import ChiSqTestResult
-__all__ = ["Statistics", "MultivariateStatisticalSummary", "MultivariateGaussian"]
+__all__ = ["Statistics", "MultivariateStatisticalSummary", "ChiSqTestResult",
+ "MultivariateGaussian"]
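With ChiSqTestResult now exported from pyspark.mllib.stat, the result type returned by Statistics.chiSqTest can be imported directly. A small usage sketch, assuming an active SparkContext (needed for the JVM-backed test) and illustrative counts:

    from pyspark.mllib.linalg import Vectors
    from pyspark.mllib.stat import Statistics, ChiSqTestResult

    observed = Vectors.dense([4.0, 6.0, 5.0])
    result = Statistics.chiSqTest(observed)      # goodness-of-fit against uniform
    print(isinstance(result, ChiSqTestResult))   # True
    print(result.pValue)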
diff --git a/python/pyspark/mllib/stat/distribution.py b/python/pyspark/mllib/stat/distribution.py
index 07792e1532046..46f7a1d2f277a 100644
--- a/python/pyspark/mllib/stat/distribution.py
+++ b/python/pyspark/mllib/stat/distribution.py
@@ -22,7 +22,8 @@
class MultivariateGaussian(namedtuple('MultivariateGaussian', ['mu', 'sigma'])):
- """ Represents a (mu, sigma) tuple
+ """Represents a (mu, sigma) tuple
+
>>> m = MultivariateGaussian(Vectors.dense([11,12]),DenseMatrix(2, 2, (1.0, 3.0, 5.0, 2.0)))
>>> (m.mu, m.sigma.toArray())
(DenseVector([11.0, 12.0]), array([[ 1., 5.],[ 3., 2.]]))
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 49e5c9d58e5db..06207a076eece 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -335,7 +335,7 @@ def test_infer_schema(self):
sqlCtx = SQLContext(self.sc)
rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
srdd = sqlCtx.inferSchema(rdd)
- schema = srdd.schema()
+ schema = srdd.schema
field = [f for f in schema.fields if f.name == "features"][0]
self.assertEqual(field.dataType, self.udt)
vectors = srdd.map(lambda p: p.features).collect()
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index aae48f213246b..73618f0449ad4 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -20,12 +20,12 @@
import random
from pyspark import SparkContext, RDD
-from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
+from pyspark.mllib.common import callMLlibFunc, inherit_doc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel',
- 'RandomForest', 'GradientBoostedTrees']
+ 'RandomForest', 'GradientBoostedTreesModel', 'GradientBoostedTrees']
class TreeEnsembleModel(JavaModelWrapper):
@@ -33,6 +33,10 @@ def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
the model trained.
+
+ Note: In Python, predict cannot currently be used within an RDD
+ transformation or action.
+ Call predict directly on the RDD instead.
"""
if isinstance(x, RDD):
return self.call("predict", x.map(_convert_to_vector))
@@ -48,7 +52,8 @@ def numTrees(self):
def totalNumNodes(self):
"""
- Get total number of nodes, summed over all trees in the ensemble.
+ Get total number of nodes, summed over all trees in the
+ ensemble.
"""
return self.call("totalNumNodes")
@@ -71,6 +76,10 @@ def predict(self, x):
"""
Predict the label of one or more examples.
+ Note: In Python, predict cannot currently be used within an RDD
+ transformation or action.
+ Call predict directly on the RDD instead.
+
:param x: Data point (feature vector),
or an RDD of data points (feature vectors).
"""
@@ -99,7 +108,8 @@ class DecisionTree(object):
"""
.. note:: Experimental
- Learning algorithm for a decision tree model for classification or regression.
+ Learning algorithm for a decision tree model for classification or
+ regression.
"""
@classmethod
@@ -176,17 +186,17 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
:param data: Training data: RDD of LabeledPoint.
Labels are real numbers.
- :param categoricalFeaturesInfo: Map from categorical feature index
- to number of categories.
- Any feature not in this map
- is treated as continuous.
+ :param categoricalFeaturesInfo: Map from categorical feature
+ index to number of categories.
+ Any feature not in this map is treated as continuous.
:param impurity: Supported values: "variance"
:param maxDepth: Max depth of tree.
- E.g., depth 0 means 1 leaf node.
- Depth 1 means 1 internal node + 2 leaf nodes.
- :param maxBins: Number of bins used for finding splits at each node.
- :param minInstancesPerNode: Min number of instances required at child
- nodes to create the parent split
+ E.g., depth 0 means 1 leaf node.
+ Depth 1 means 1 internal node + 2 leaf nodes.
+ :param maxBins: Number of bins used for finding splits at each
+ node.
+ :param minInstancesPerNode: Min number of instances required at
+ child nodes to create the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
@@ -216,6 +226,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+@inherit_doc
class RandomForestModel(TreeEnsembleModel):
"""
.. note:: Experimental
@@ -228,7 +239,8 @@ class RandomForest(object):
"""
.. note:: Experimental
- Learning algorithm for a random forest model for classification or regression.
+ Learning algorithm for a random forest model for classification or
+ regression.
"""
supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
@@ -255,26 +267,33 @@ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
Method to train a decision tree model for binary or multiclass
classification.
- :param data: Training dataset: RDD of LabeledPoint. Labels should take
- values {0, 1, ..., numClasses-1}.
+ :param data: Training dataset: RDD of LabeledPoint. Labels
+ should take values {0, 1, ..., numClasses-1}.
:param numClasses: number of classes for classification.
- :param categoricalFeaturesInfo: Map storing arity of categorical features.
- E.g., an entry (n -> k) indicates that feature n is categorical
- with k categories indexed from 0: {0, 1, ..., k-1}.
+ :param categoricalFeaturesInfo: Map storing arity of categorical
+ features. E.g., an entry (n -> k) indicates that
+ feature n is categorical with k categories indexed
+ from 0: {0, 1, ..., k-1}.
:param numTrees: Number of trees in the random forest.
- :param featureSubsetStrategy: Number of features to consider for splits at
- each node.
- Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
- If "auto" is set, this parameter is set based on numTrees:
- if numTrees == 1, set to "all";
- if numTrees > 1 (forest) set to "sqrt".
- :param impurity: Criterion used for information gain calculation.
+ :param featureSubsetStrategy: Number of features to consider for
+ splits at each node.
+ Supported: "auto" (default), "all", "sqrt", "log2",
+ "onethird".
+ If "auto" is set, this parameter is set based on
+ numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "sqrt".
+ :param impurity: Criterion used for information gain
+ calculation.
Supported values: "gini" (recommended) or "entropy".
- :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node;
- depth 1 means 1 internal node + 2 leaf nodes. (default: 4)
- :param maxBins: maximum number of bins used for splitting features
+ :param maxDepth: Maximum depth of the tree.
+ E.g., depth 0 means 1 leaf node; depth 1 means
+ 1 internal node + 2 leaf nodes. (default: 4)
+ :param maxBins: maximum number of bins used for splitting
+ features
(default: 100)
- :param seed: Random seed for bootstrapping and choosing feature subsets.
+ :param seed: Random seed for bootstrapping and choosing feature
+ subsets.
:return: RandomForestModel that can be used for prediction
Example usage:
@@ -336,19 +355,24 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
{0, 1, ..., k-1}.
:param numTrees: Number of trees in the random forest.
:param featureSubsetStrategy: Number of features to consider for
- splits at each node.
- Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
- If "auto" is set, this parameter is set based on numTrees:
- if numTrees == 1, set to "all";
- if numTrees > 1 (forest) set to "onethird" for regression.
- :param impurity: Criterion used for information gain calculation.
- Supported values: "variance".
- :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
- leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- (default: 4)
- :param maxBins: maximum number of bins used for splitting features
- (default: 100)
- :param seed: Random seed for bootstrapping and choosing feature subsets.
+ splits at each node.
+ Supported: "auto" (default), "all", "sqrt", "log2",
+ "onethird".
+ If "auto" is set, this parameter is set based on
+ numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "onethird" for
+ regression.
+ :param impurity: Criterion used for information gain
+ calculation.
+ Supported values: "variance".
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
+ 1 leaf node; depth 1 means 1 internal node + 2 leaf
+ nodes. (default: 4)
+ :param maxBins: maximum number of bins used for splitting
+ features (default: 100)
+ :param seed: Random seed for bootstrapping and choosing feature
+ subsets.
:return: RandomForestModel that can be used for prediction
Example usage:
@@ -381,6 +405,7 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt
featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
+@inherit_doc
class GradientBoostedTreesModel(TreeEnsembleModel):
"""
.. note:: Experimental
@@ -393,7 +418,8 @@ class GradientBoostedTrees(object):
"""
.. note:: Experimental
- Learning algorithm for a gradient boosted trees model for classification or regression.
+ Learning algorithm for a gradient boosted trees model for
+ classification or regression.
"""
@classmethod
@@ -409,24 +435,29 @@ def _train(cls, data, algo, categoricalFeaturesInfo,
def trainClassifier(cls, data, categoricalFeaturesInfo,
loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3):
"""
- Method to train a gradient-boosted trees model for classification.
+ Method to train a gradient-boosted trees model for
+ classification.
- :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}.
+ :param data: Training dataset: RDD of LabeledPoint.
+ Labels should take values {0, 1}.
:param categoricalFeaturesInfo: Map storing arity of categorical
features. E.g., an entry (n -> k) indicates that feature
n is categorical with k categories indexed from 0:
{0, 1, ..., k-1}.
- :param loss: Loss function used for minimization during gradient boosting.
- Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}.
+ :param loss: Loss function used for minimization during gradient
+ boosting. Supported: {"logLoss" (default),
+ "leastSquaresError", "leastAbsoluteError"}.
:param numIterations: Number of iterations of boosting.
(default: 100)
- :param learningRate: Learning rate for shrinking the contribution of each estimator.
- The learning rate should be between in the interval (0, 1]
- (default: 0.1)
- :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
- leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- (default: 3)
- :return: GradientBoostedTreesModel that can be used for prediction
+ :param learningRate: Learning rate for shrinking the
+ contribution of each estimator. The learning rate
+ should be in the interval (0, 1].
+ (default: 0.1)
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
+ 1 leaf node; depth 1 means 1 internal node + 2 leaf
+ nodes. (default: 3)
+ :return: GradientBoostedTreesModel that can be used for
+ prediction
Example usage:
@@ -470,17 +501,20 @@ def trainRegressor(cls, data, categoricalFeaturesInfo,
features. E.g., an entry (n -> k) indicates that feature
n is categorical with k categories indexed from 0:
{0, 1, ..., k-1}.
- :param loss: Loss function used for minimization during gradient boosting.
- Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}.
+ :param loss: Loss function used for minimization during gradient
+ boosting. Supported: {"logLoss" (default),
+ "leastSquaresError", "leastAbsoluteError"}.
:param numIterations: Number of iterations of boosting.
(default: 100)
- :param learningRate: Learning rate for shrinking the contribution of each estimator.
- The learning rate should be between in the interval (0, 1]
- (default: 0.1)
- :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
- leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- (default: 3)
- :return: GradientBoostedTreesModel that can be used for prediction
+ :param learningRate: Learning rate for shrinking the
+ contribution of each estimator. The learning rate
+ should be in the interval (0, 1].
+ (default: 0.1)
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means
+ 1 leaf node; depth 1 means 1 internal node + 2 leaf
+ nodes. (default: 3)
+ :return: GradientBoostedTreesModel that can be used for
+ prediction
Example usage:
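The notes added above state that predict must be called on the RDD itself rather than inside an RDD transformation or action. A minimal sketch of that pattern, assuming an active SparkContext `sc` and hypothetical toy data:

    from pyspark.mllib.regression import LabeledPoint
    from pyspark.mllib.tree import RandomForest

    data = sc.parallelize([LabeledPoint(0.0, [0.0]),
                           LabeledPoint(1.0, [1.0])])
    model = RandomForest.trainClassifier(data, numClasses=2,
                                         categoricalFeaturesInfo={}, numTrees=3)

    # correct: predict over a whole RDD of feature vectors
    predictions = model.predict(data.map(lambda p: p.features))
    print(predictions.collect())

    # incorrect (per the note): data.map(lambda p: model.predict(p.features))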
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bd4f16e058045..cb12fed98c53d 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -111,6 +111,19 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+class Partitioner(object):
+ def __init__(self, numPartitions, partitionFunc):
+ self.numPartitions = numPartitions
+ self.partitionFunc = partitionFunc
+
+ def __eq__(self, other):
+ return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
+ and self.partitionFunc == other.partitionFunc)
+
+ def __call__(self, k):
+ return self.partitionFunc(k) % self.numPartitions
+
+
class RDD(object):
"""
@@ -126,7 +139,7 @@ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSeri
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
- self._partitionFunc = None
+ self.partitioner = None
def _pickled(self):
return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
@@ -450,14 +463,17 @@ def union(self, other):
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
- return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
- return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
- self.ctx.serializer)
+ rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+ if (self.partitioner == other.partitioner and
+ self.getNumPartitions() == rdd.getNumPartitions()):
+ rdd.partitioner = self.partitioner
+ return rdd
def intersection(self, other):
"""
@@ -1588,6 +1604,9 @@ def partitionBy(self, numPartitions, partitionFunc=portable_hash):
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ partitioner = Partitioner(numPartitions, partitionFunc)
+ if self.partitioner == partitioner:
+ return self
# Transferring O(n) objects to Java is too expensive.
# Instead, we'll form the hash buckets in Python,
@@ -1632,18 +1651,16 @@ def add_shuffle_key(split, iterator):
yield pack_long(split)
yield outputSerializer.dumps(items)
- keyed = self.mapPartitionsWithIndex(add_shuffle_key)
+ keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
- jrdd = pairRDD.partitionBy(partitioner).values()
+ jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
+ jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique,
- # even if partitionFunc is a lambda:
- rdd._partitionFunc = partitionFunc
+ rdd.partitioner = partitioner
return rdd
# TODO: add control over map-side aggregation
@@ -1689,7 +1706,7 @@ def combineLocally(iterator):
merger.mergeValues(iterator)
return merger.iteritems()
- locally_combined = self.mapPartitions(combineLocally)
+ locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
@@ -1698,7 +1715,7 @@ def _mergeCombiners(iterator):
merger.mergeCombiners(iterator)
return merger.iteritems()
- return shuffled.mapPartitions(_mergeCombiners, True)
+ return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -1933,7 +1950,7 @@ def batch_as(rdd, batchSize):
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
- if my_batch != other_batch:
+ if my_batch != other_batch or not my_batch:
# use the smallest batchSize for both of them
batchSize = min(my_batch, other_batch)
if batchSize <= 0:
@@ -2077,8 +2094,8 @@ def lookup(self, key):
"""
values = self.filter(lambda (k, v): k == key).values()
- if self._partitionFunc is not None:
- return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+ if self.partitioner is not None:
+ return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
return values.collect()
@@ -2094,6 +2111,7 @@ def _to_java_object_rdd(self):
def countApprox(self, timeout, confidence=0.95):
"""
.. note:: Experimental
+
Approximate version of count() that returns a potentially incomplete
result within a timeout, even if not all tasks have finished.
@@ -2107,6 +2125,7 @@ def countApprox(self, timeout, confidence=0.95):
def sumApprox(self, timeout, confidence=0.95):
"""
.. note:: Experimental
+
Approximate operation to return the sum within a timeout
or meet the confidence.
@@ -2123,6 +2142,7 @@ def sumApprox(self, timeout, confidence=0.95):
def meanApprox(self, timeout, confidence=0.95):
"""
.. note:: Experimental
+
Approximate operation to return the mean within a timeout
or meet the confidence.
@@ -2139,6 +2159,7 @@ def meanApprox(self, timeout, confidence=0.95):
def countApproxDistinct(self, relativeSD=0.05):
"""
.. note:: Experimental
+
Return approximate number of distinct elements in the RDD.
The algorithm used is based on streamlib's implementation of
@@ -2243,7 +2264,7 @@ def pipeline_func(split, iterator):
self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
- self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self.partitioner = prev.partitioner if self.preservesPartitioning else None
self._broadcast = None
def __del__(self):
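The new Partitioner class and the early return in partitionBy mean that re-partitioning an RDD with an equivalent partitioner becomes a no-op. A small sketch of the expected behaviour, assuming an active SparkContext `sc`:

    pairs = sc.parallelize([(1, "a"), (2, "b"), (3, "c")])

    first = pairs.partitionBy(4)       # shuffles once, records its Partitioner
    second = first.partitionBy(4)      # same numPartitions and partitionFunc

    print(first is second)             # expected: True, no second shuffle
    print(first.partitioner == second.partitioner)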
diff --git a/python/pyspark/shell.py b/python/pyspark/shell.py
index 89cf76920e353..1a02fece9c5a5 100644
--- a/python/pyspark/shell.py
+++ b/python/pyspark/shell.py
@@ -31,13 +31,18 @@
import atexit
import os
import platform
+
+import py4j
+
import pyspark
from pyspark.context import SparkContext
+from pyspark.sql import SQLContext, HiveContext
from pyspark.storagelevel import StorageLevel
-# this is the equivalent of ADD_JARS
-add_files = (os.environ.get("ADD_FILES").split(',')
- if os.environ.get("ADD_FILES") is not None else None)
+# this is the deprecated equivalent of ADD_JARS
+add_files = None
+if os.environ.get("ADD_FILES") is not None:
+ add_files = os.environ.get("ADD_FILES").split(',')
if os.environ.get("SPARK_EXECUTOR_URI"):
SparkContext.setSystemProperty("spark.executor.uri", os.environ["SPARK_EXECUTOR_URI"])
@@ -45,6 +50,13 @@
sc = SparkContext(appName="PySparkShell", pyFiles=add_files)
atexit.register(lambda: sc.stop())
+try:
+ # Try to access HiveConf, it will raise exception if Hive is not added
+ sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ sqlCtx = HiveContext(sc)
+except py4j.protocol.Py4JError:
+ sqlCtx = SQLContext(sc)
+
print("""Welcome to
____ __
/ __/__ ___ _____/ /__
@@ -56,9 +68,10 @@
platform.python_version(),
platform.python_build()[0],
platform.python_build()[1]))
-print("SparkContext available as sc.")
+print("SparkContext available as sc, %s available as sqlCtx." % sqlCtx.__class__.__name__)
if add_files is not None:
+ print("Warning: ADD_FILES environment variable is deprecated, use --py-files argument instead")
print("Adding files: [%s]" % ", ".join(add_files))
# The ./bin/pyspark script stores the old PYTHONSTARTUP value in OLD_PYTHONSTARTUP,
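The shell now picks HiveContext when the Hive classes are on the classpath and falls back to SQLContext otherwise. A standalone script can mirror the same probe; the sketch below assumes an existing SparkContext `sc` and only illustrates the pattern used above:

    from py4j.protocol import Py4JError
    from pyspark.sql import SQLContext, HiveContext

    try:
        # accessing HiveConf raises if Hive support is not available
        sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
        sqlCtx = HiveContext(sc)
    except Py4JError:
        sqlCtx = SQLContext(sc)

    print("%s available as sqlCtx" % sqlCtx.__class__.__name__)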
diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py
index 0a5ba00393aab..b9ffd6945ea7e 100644
--- a/python/pyspark/sql/__init__.py
+++ b/python/pyspark/sql/__init__.py
@@ -34,9 +34,8 @@
from pyspark.sql.context import SQLContext, HiveContext
from pyspark.sql.types import Row
-from pyspark.sql.dataframe import DataFrame, GroupedData, Column, Dsl, SchemaRDD
+from pyspark.sql.dataframe import DataFrame, GroupedData, Column, SchemaRDD
__all__ = [
'SQLContext', 'HiveContext', 'DataFrame', 'GroupedData', 'Column', 'Row',
- 'Dsl',
]
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 49f016a9cf2e9..795ef0dbc4c47 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -17,20 +17,45 @@
import warnings
import json
-from array import array
from itertools import imap
from py4j.protocol import Py4JError
+from py4j.java_collections import MapConverter
-from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.rdd import RDD, _prepare_for_python_RDD
from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
-from pyspark.sql.types import StringType, StructType, _verify_type, \
+from pyspark.sql.types import Row, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _python_to_sql_converter
from pyspark.sql.dataframe import DataFrame
+try:
+ import pandas
+ has_pandas = True
+except ImportError:
+ has_pandas = False
+
__all__ = ["SQLContext", "HiveContext"]
+def _monkey_patch_RDD(sqlCtx):
+ def toDF(self, schema=None, sampleRatio=None):
+ """
+ Convert current :class:`RDD` into a :class:`DataFrame`
+
+ This is a shorthand for `sqlCtx.createDataFrame(rdd, schema, sampleRatio)`
+
+ :param schema: a StructType or list of names of columns
+ :param sampleRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> rdd.toDF().collect()
+ [Row(name=u'Alice', age=1)]
+ """
+ return sqlCtx.createDataFrame(self, schema, sampleRatio)
+
+ RDD.toDF = toDF
+
+
class SQLContext(object):
"""Main entry point for Spark SQL functionality.
@@ -42,27 +67,20 @@ class SQLContext(object):
def __init__(self, sparkContext, sqlContext=None):
"""Create a new SQLContext.
+ It will add a method called `toDF` to :class:`RDD`, which can be
+ used to convert an RDD into a DataFrame; it is a shorthand for
+ :func:`SQLContext.createDataFrame`.
+
:param sparkContext: The SparkContext to wrap.
:param sqlContext: An optional JVM Scala SQLContext. If set, we do not instantiate a new
SQLContext in the JVM, instead we make all calls to this object.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- TypeError:...
-
- >>> bad_rdd = sc.parallelize([1,2,3])
- >>> sqlCtx.inferSchema(bad_rdd) # doctest: +IGNORE_EXCEPTION_DETAIL
- Traceback (most recent call last):
- ...
- ValueError:...
-
>>> from datetime import datetime
+ >>> sqlCtx = SQLContext(sc)
>>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L,
... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1),
... time=datetime(2014, 8, 1, 14, 1, 5))])
- >>> df = sqlCtx.inferSchema(allTypes)
+ >>> df = allTypes.toDF()
>>> df.registerTempTable("allTypes")
>>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a '
... 'from allTypes where b and i > 0').collect()
@@ -75,6 +93,7 @@ def __init__(self, sparkContext, sqlContext=None):
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
self._scala_SQLContext = sqlContext
+ _monkey_patch_RDD(self)
@property
def _ssql_ctx(self):
@@ -87,6 +106,18 @@ def _ssql_ctx(self):
self._scala_SQLContext = self._jvm.SQLContext(self._jsc.sc())
return self._scala_SQLContext
+ def setConf(self, key, value):
+ """Sets the given Spark SQL configuration property.
+ """
+ self._ssql_ctx.setConf(key, value)
+
+ def getConf(self, key, defaultValue):
+ """Returns the value of Spark SQL configuration property for the given key.
+
+ If the key is not set, returns defaultValue.
+ """
+ return self._ssql_ctx.getConf(key, defaultValue)
+
def registerFunction(self, name, f, returnType=StringType()):
"""Registers a lambda function as a UDF so it can be used in SQL statements.
@@ -97,6 +128,7 @@ def registerFunction(self, name, f, returnType=StringType()):
>>> sqlCtx.registerFunction("stringLengthString", lambda x: len(x))
>>> sqlCtx.sql("SELECT stringLengthString('test')").collect()
[Row(c0=u'4')]
+
>>> from pyspark.sql.types import IntegerType
>>> sqlCtx.registerFunction("stringLengthInt", lambda x: len(x), IntegerType())
>>> sqlCtx.sql("SELECT stringLengthInt('test')").collect()
@@ -115,9 +147,37 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
returnType.json())
+ def _inferSchema(self, rdd, samplingRatio=None):
+ first = rdd.first()
+ if not first:
+ raise ValueError("The first row in RDD is empty, "
+ "can not infer schema")
+ if type(first) is dict:
+ warnings.warn("Using RDD of dict to inferSchema is deprecated, "
+ "please use pyspark.sql.Row instead")
+
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ raise ValueError("Some of types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+ if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+ return schema
+
def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
+ .. note::
+ Deprecated in 1.3, use :func:`createDataFrame` instead
+
When samplingRatio is specified, the schema is inferred by looking
at the types of each row in the sampled dataset. Otherwise, the
first 100 rows of the RDD are inspected. Nested collections are
@@ -137,59 +197,12 @@ def inferSchema(self, rdd, samplingRatio=None):
>>> df = sqlCtx.inferSchema(rdd)
>>> df.collect()[0]
Row(field1=1, field2=u'row1')
-
- >>> NestedRow = Row("f1", "f2")
- >>> nestedRdd1 = sc.parallelize([
- ... NestedRow(array('i', [1, 2]), {"row1": 1.0}),
- ... NestedRow(array('i', [2, 3]), {"row2": 2.0})])
- >>> df = sqlCtx.inferSchema(nestedRdd1)
- >>> df.collect()
- [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})]
-
- >>> nestedRdd2 = sc.parallelize([
- ... NestedRow([[1, 2], [2, 3]], [1, 2]),
- ... NestedRow([[2, 3], [3, 4]], [2, 3])])
- >>> df = sqlCtx.inferSchema(nestedRdd2)
- >>> df.collect()
- [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])]
-
- >>> from collections import namedtuple
- >>> CustomRow = namedtuple('CustomRow', 'field1 field2')
- >>> rdd = sc.parallelize(
- ... [CustomRow(field1=1, field2="row1"),
- ... CustomRow(field1=2, field2="row2"),
- ... CustomRow(field1=3, field2="row3")])
- >>> df = sqlCtx.inferSchema(rdd)
- >>> df.collect()[0]
- Row(field1=1, field2=u'row1')
"""
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
- first = rdd.first()
- if not first:
- raise ValueError("The first row in RDD is empty, "
- "can not infer schema")
- if type(first) is dict:
- warnings.warn("Using RDD of dict to inferSchema is deprecated,"
- "please use pyspark.sql.Row instead")
-
- if samplingRatio is None:
- schema = _infer_schema(first)
- if _has_nulltype(schema):
- for row in rdd.take(100)[1:]:
- schema = _merge_type(schema, _infer_schema(row))
- if not _has_nulltype(schema):
- break
- else:
- warnings.warn("Some of types cannot be determined by the "
- "first 100 rows, please try again with sampling")
- else:
- if samplingRatio > 0.99:
- rdd = rdd.sample(False, float(samplingRatio))
- schema = rdd.map(_infer_schema).reduce(_merge_type)
-
+ schema = self._inferSchema(rdd, samplingRatio)
converter = _create_converter(schema)
rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
@@ -198,6 +211,9 @@ def applySchema(self, rdd, schema):
"""
Applies the given schema to the given RDD of L{tuple} or L{list}.
+ .. note::
+ Deprecated in 1.3, use :func:`createDataFrame` instead
+
These tuples or lists can contain complex nested structures like
lists, maps or nested rows.
@@ -211,63 +227,15 @@ def applySchema(self, rdd, schema):
>>> schema = StructType([StructField("field1", IntegerType(), False),
... StructField("field2", StringType(), False)])
>>> df = sqlCtx.applySchema(rdd2, schema)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
- >>> df2 = sqlCtx.sql("SELECT * from table1")
- >>> df2.collect()
- [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
-
- >>> from datetime import date, datetime
- >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
- ... date(2010, 1, 1),
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3], None)])
- >>> schema = StructType([
- ... StructField("byte1", ByteType(), False),
- ... StructField("byte2", ByteType(), False),
- ... StructField("short1", ShortType(), False),
- ... StructField("short2", ShortType(), False),
- ... StructField("int", IntegerType(), False),
- ... StructField("float", FloatType(), False),
- ... StructField("date", DateType(), False),
- ... StructField("time", TimestampType(), False),
- ... StructField("map",
- ... MapType(StringType(), IntegerType(), False), False),
- ... StructField("struct",
- ... StructType([StructField("b", ShortType(), False)]), False),
- ... StructField("list", ArrayType(ByteType(), False), False),
- ... StructField("null", DoubleType(), True)])
- >>> df = sqlCtx.applySchema(rdd, schema)
- >>> results = df.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
- ... x.time, x.map["a"], x.struct.b, x.list, x.null))
- >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
- (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
- datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
-
- >>> df.registerTempTable("table2")
- >>> sqlCtx.sql(
- ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
- ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " +
- ... "float + 1.5 as float FROM table2").collect()
- [Row(byte1=126, byte2=-127, short1=-32767, short2=32766, int=2147483646, float=2.5)]
-
- >>> from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
- >>> rdd = sc.parallelize([(127, -32768, 1.0,
- ... datetime(2010, 1, 1, 1, 1, 1),
- ... {"a": 1}, (2,), [1, 2, 3])])
- >>> abstract = "byte short float time map{} struct(b) list[]"
- >>> schema = _parse_schema_abstract(abstract)
- >>> typedSchema = _infer_schema_type(rdd.first(), schema)
- >>> df = sqlCtx.applySchema(rdd, typedSchema)
>>> df.collect()
- [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])]
+ [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
"""
if isinstance(rdd, DataFrame):
raise TypeError("Cannot apply schema to DataFrame")
if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
+ raise TypeError("schema should be StructType, but got %s" % schema)
# take the first few rows to verify schema
rows = rdd.take(10)
@@ -287,18 +255,102 @@ def applySchema(self, rdd, schema):
df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return DataFrame(df, self)
- def registerRDDAsTable(self, rdd, tableName):
+ def createDataFrame(self, data, schema=None, samplingRatio=None):
+ """
+ Create a DataFrame from an RDD of tuple/list, list or pandas.DataFrame.
+
+ `schema` could be :class:`StructType` or a list of column names.
+
+ When `schema` is a list of column names, the type of each column
+ will be inferred from `data`.
+
+ When `schema` is None, it will try to infer the column names and
+ types from `data`, which should be an RDD of :class:`Row`,
+ namedtuple, or dict.
+
+ If schema inference is needed, `samplingRatio` is used to determine
+ how many rows will be used to do the inference. The first row will
+ be used if `samplingRatio` is None.
+
+ :param data: an RDD of Row/tuple/list/dict, list, or pandas.DataFrame
+ :param schema: a StructType or list of names of columns
+ :param samplingRatio: the sample ratio of rows used for inferring
+ :return: a DataFrame
+
+ >>> l = [('Alice', 1)]
+ >>> sqlCtx.createDataFrame(l).collect()
+ [Row(_1=u'Alice', _2=1)]
+ >>> sqlCtx.createDataFrame(l, ['name', 'age']).collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> d = [{'name': 'Alice', 'age': 1}]
+ >>> sqlCtx.createDataFrame(d).collect()
+ [Row(age=1, name=u'Alice')]
+
+ >>> rdd = sc.parallelize(l)
+ >>> sqlCtx.createDataFrame(rdd).collect()
+ [Row(_1=u'Alice', _2=1)]
+ >>> df = sqlCtx.createDataFrame(rdd, ['name', 'age'])
+ >>> df.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql import Row
+ >>> Person = Row('name', 'age')
+ >>> person = rdd.map(lambda r: Person(*r))
+ >>> df2 = sqlCtx.createDataFrame(person)
+ >>> df2.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> from pyspark.sql.types import *
+ >>> schema = StructType([
+ ... StructField("name", StringType(), True),
+ ... StructField("age", IntegerType(), True)])
+ >>> df3 = sqlCtx.createDataFrame(rdd, schema)
+ >>> df3.collect()
+ [Row(name=u'Alice', age=1)]
+
+ >>> sqlCtx.createDataFrame(df.toPandas()).collect() # doctest: +SKIP
+ [Row(name=u'Alice', age=1)]
+ """
+ if isinstance(data, DataFrame):
+ raise TypeError("data is already a DataFrame")
+
+ if has_pandas and isinstance(data, pandas.DataFrame):
+ if schema is None:
+ schema = list(data.columns)
+ data = [r.tolist() for r in data.to_records(index=False)]
+
+ if not isinstance(data, RDD):
+ try:
+ # data could be list, tuple, generator ...
+ data = self._sc.parallelize(data)
+ except Exception:
+ raise ValueError("cannot create an RDD from type: %s" % type(data))
+
+ if schema is None:
+ return self.inferSchema(data, samplingRatio)
+
+ if isinstance(schema, (list, tuple)):
+ first = data.first()
+ if not isinstance(first, (list, tuple)):
+ raise ValueError("each row in `rdd` should be list or tuple, "
+ "but got %r" % type(first))
+ row_cls = Row(*schema)
+ schema = self._inferSchema(data.map(lambda r: row_cls(*r)), samplingRatio)
+
+ return self.applySchema(data, schema)
+
+ def registerDataFrameAsTable(self, rdd, tableName):
"""Registers the given RDD as a temporary table in the catalog.
Temporary tables exist only during the lifetime of this instance of
SQLContext.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> sqlCtx.registerDataFrameAsTable(df, "table1")
"""
if (rdd.__class__ is DataFrame):
df = rdd._jdf
- self._ssql_ctx.registerRDDAsTable(df, tableName)
+ self._ssql_ctx.registerDataFrameAsTable(df, tableName)
else:
raise ValueError("Can only register DataFrame as table")
@@ -308,18 +360,16 @@ def parquetFile(self, *paths):
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
>>> shutil.rmtree(parquetFile)
- >>> df = sqlCtx.inferSchema(rdd)
>>> df.saveAsParquetFile(parquetFile)
>>> df2 = sqlCtx.parquetFile(parquetFile)
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
gateway = self._sc._gateway
- jpath = paths[0]
- jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths) - 1)
- for i in range(1, len(paths)):
+ jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths))
+ for i in range(0, len(paths)):
jpaths[i] = paths[i]
- jdf = self._ssql_ctx.parquetFile(jpath, jpaths)
+ jdf = self._ssql_ctx.parquetFile(jpaths)
return DataFrame(jdf, self)
def jsonFile(self, path, schema=None, samplingRatio=1.0):
@@ -336,46 +386,28 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0):
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
>>> shutil.rmtree(jsonFile)
- >>> ofn = open(jsonFile, 'w')
- >>> for json in jsonStrings:
- ... print>>ofn, json
- >>> ofn.close()
+ >>> with open(jsonFile, 'w') as f:
+ ... f.writelines(jsonStrings)
>>> df1 = sqlCtx.jsonFile(jsonFile)
- >>> sqlCtx.registerRDDAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema())
- >>> sqlCtx.registerRDDAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.printSchema()
+ root
+ |-- field1: long (nullable = true)
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field4: long (nullable = true)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonFile(jsonFile, schema)
- >>> sqlCtx.registerRDDAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))])
+ >>> df2 = sqlCtx.jsonFile(jsonFile, schema)
+ >>> df2.printSchema()
+ root
+ |-- field2: string (nullable = true)
+ |-- field3: struct (nullable = true)
+ | |-- field5: array (nullable = true)
+ | | |-- element: integer (containsNull = true)
"""
if schema is None:
df = self._ssql_ctx.jsonFile(path, samplingRatio)
@@ -394,48 +426,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
determine the schema.
>>> df1 = sqlCtx.jsonRDD(json)
- >>> sqlCtx.registerRDDAsTable(df1, "table1")
- >>> df2 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table1")
- >>> for r in df2.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
-
- >>> df3 = sqlCtx.jsonRDD(json, df1.schema())
- >>> sqlCtx.registerRDDAsTable(df3, "table2")
- >>> df4 = sqlCtx.sql(
- ... "SELECT field1 AS f1, field2 as f2, field3 as f3, "
- ... "field6 as f4 from table2")
- >>> for r in df4.collect():
- ... print r
- Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None)
- Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')])
- Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None)
+ >>> df1.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
+
+ >>> df2 = sqlCtx.jsonRDD(json, df1.schema)
+ >>> df2.first()
+ Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None)
>>> from pyspark.sql.types import *
>>> schema = StructType([
- ... StructField("field2", StringType(), True),
+ ... StructField("field2", StringType()),
... StructField("field3",
- ... StructType([
- ... StructField("field5",
- ... ArrayType(IntegerType(), False), True)]), False)])
- >>> df5 = sqlCtx.jsonRDD(json, schema)
- >>> sqlCtx.registerRDDAsTable(df5, "table3")
- >>> df6 = sqlCtx.sql(
- ... "SELECT field2 AS f1, field3.field5 as f2, "
- ... "field3.field5[0] as f3 from table3")
- >>> df6.collect()
- [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
-
- >>> sqlCtx.jsonRDD(sc.parallelize(['{}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
- >>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
- ... '{"key0": {"key1": "value1"}}'])).collect()
- [Row(key0=None), Row(key0=Row(key1=u'value1'))]
+ ... StructType([StructField("field5", ArrayType(IntegerType()))]))
+ ... ])
+ >>> df3 = sqlCtx.jsonRDD(json, schema)
+ >>> df3.first()
+ Row(field2=u'row1', field3=Row(field5=None))
+
"""
def func(iterator):
@@ -455,11 +462,65 @@ def func(iterator):
df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return DataFrame(df, self)
+ def load(self, path=None, source=None, schema=None, **options):
+ """Returns the dataset in a data source as a DataFrame.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.load(source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.load(source, scala_datatype, joptions)
+ return DataFrame(df, self)
+
+ def createExternalTable(self, tableName, path=None, source=None,
+ schema=None, **options):
+ """Creates an external table based on the dataset in a data source.
+
+ It returns the DataFrame associated with the external table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Optionally, a schema can be provided as the schema of the returned DataFrame and
+ created external table.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ if schema is None:
+ df = self._ssql_ctx.createExternalTable(tableName, source, joptions)
+ else:
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
+ df = self._ssql_ctx.createExternalTable(tableName, source, scala_datatype,
+ joptions)
+ return DataFrame(df, self)
+
def sql(self, sqlQuery):
"""Return a L{DataFrame} representing the result of the given query.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> sqlCtx.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1")
>>> df2.collect()
[Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')]
@@ -469,14 +530,47 @@ def sql(self, sqlQuery):
def table(self, tableName):
"""Returns the specified table as a L{DataFrame}.
- >>> df = sqlCtx.inferSchema(rdd)
- >>> sqlCtx.registerRDDAsTable(df, "table1")
+ >>> sqlCtx.registerDataFrameAsTable(df, "table1")
>>> df2 = sqlCtx.table("table1")
>>> sorted(df.collect()) == sorted(df2.collect())
True
"""
return DataFrame(self._ssql_ctx.table(tableName), self)
+ def tables(self, dbName=None):
+ """Returns a DataFrame containing names of tables in the given database.
+
+ If `dbName` is not specified, the current database will be used.
+
+ The returned DataFrame has two columns, tableName and isTemporary
+ (a column with BooleanType indicating if a table is a temporary one or not).
+
+ >>> sqlCtx.registerDataFrameAsTable(df, "table1")
+ >>> df2 = sqlCtx.tables()
+ >>> df2.filter("tableName = 'table1'").first()
+ Row(tableName=u'table1', isTemporary=True)
+ """
+ if dbName is None:
+ return DataFrame(self._ssql_ctx.tables(), self)
+ else:
+ return DataFrame(self._ssql_ctx.tables(dbName), self)
+
+ def tableNames(self, dbName=None):
+ """Returns a list of names of tables in the database `dbName`.
+
+ If `dbName` is not specified, the current database will be used.
+
+ >>> sqlCtx.registerDataFrameAsTable(df, "table1")
+ >>> "table1" in sqlCtx.tableNames()
+ True
+ >>> "table1" in sqlCtx.tableNames("db")
+ True
+ """
+ if dbName is None:
+ return [name for name in self._ssql_ctx.tableNames()]
+ else:
+ return [name for name in self._ssql_ctx.tableNames(dbName)]
+
def cacheTable(self, tableName):
"""Caches the specified table in-memory."""
self._ssql_ctx.cacheTable(tableName)
@@ -485,6 +579,10 @@ def uncacheTable(self, tableName):
"""Removes the specified table from the in-memory cache."""
self._ssql_ctx.uncacheTable(tableName)
+ def clearCache(self):
+ """Removes all cached tables from the in-memory cache. """
+ self._ssql_ctx.clearCache()
+
class HiveContext(SQLContext):
@@ -521,93 +619,6 @@ def _get_hive_ctx(self):
return self._jvm.HiveContext(self._jsc.sc())
-def _create_row(fields, values):
- row = Row(*values)
- row.__FIELDS__ = fields
- return row
-
-
-class Row(tuple):
-
- """
- A row in L{DataFrame}. The fields in it can be accessed like attributes.
-
- Row can be used to create a row object by using named arguments,
- the fields will be sorted by names.
-
- >>> row = Row(name="Alice", age=11)
- >>> row
- Row(age=11, name='Alice')
- >>> row.name, row.age
- ('Alice', 11)
-
- Row also can be used to create another Row like class, then it
- could be used to create Row objects, such as
-
- >>> Person = Row("name", "age")
- >>> Person
- <Row(name, age)>
- >>> Person("Alice", 11)
- Row(name='Alice', age=11)
- """
-
- def __new__(self, *args, **kwargs):
- if args and kwargs:
- raise ValueError("Can not use both args "
- "and kwargs to create Row")
- if args:
- # create row class or objects
- return tuple.__new__(self, args)
-
- elif kwargs:
- # create row objects
- names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
- row.__FIELDS__ = names
- return row
-
- else:
- raise ValueError("No args or kwargs")
-
- def asDict(self):
- """
- Return as an dict
- """
- if not hasattr(self, "__FIELDS__"):
- raise TypeError("Cannot convert a Row class into dict")
- return dict(zip(self.__FIELDS__, self))
-
- # let obect acs like class
- def __call__(self, *args):
- """create new Row object"""
- return _create_row(self, args)
-
- def __getattr__(self, item):
- if item.startswith("__"):
- raise AttributeError(item)
- try:
- # it will be slow when it has many fields,
- # but this will not be used in normal cases
- idx = self.__FIELDS__.index(item)
- return self[idx]
- except IndexError:
- raise AttributeError(item)
-
- def __reduce__(self):
- if hasattr(self, "__FIELDS__"):
- return (_create_row, (self.__FIELDS__, tuple(self)))
- else:
- return tuple.__reduce__(self)
-
- def __repr__(self):
- if hasattr(self, "__FIELDS__"):
- return "Row(%s)" % ", ".join("%s=%r" % (k, v)
- for k, v in zip(self.__FIELDS__, self))
- else:
- return "<Row(%s)>" % ", ".join(self)
-
-
def _test():
import doctest
from pyspark.context import SparkContext
@@ -617,11 +628,13 @@ def _test():
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- globs['rdd'] = sc.parallelize(
+ globs['rdd'] = rdd = sc.parallelize(
[Row(field1=1, field2="row1"),
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ _monkey_patch_RDD(sqlCtx)
+ globs['df'] = rdd.toDF()
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
@@ -632,7 +645,8 @@ def _test():
globs['jsonStrings'] = jsonStrings
globs['json'] = sc.parallelize(jsonStrings)
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.context, globs=globs, optionflags=doctest.ELLIPSIS)
+ pyspark.sql.context, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
globs['sc'].stop()
if failure_count:
exit(-1)
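Beyond the doctests, the new SQLContext helpers above (createDataFrame, registerDataFrameAsTable, setConf/getConf, tables, tableNames, clearCache) compose roughly as in the sketch below, assuming an active SparkContext `sc`; the table name and data are illustrative:

    from pyspark.sql import SQLContext, Row

    sqlCtx = SQLContext(sc)
    df = sqlCtx.createDataFrame([Row(name="Alice", age=1)])
    sqlCtx.registerDataFrameAsTable(df, "people")

    sqlCtx.setConf("spark.sql.shuffle.partitions", "10")
    print(sqlCtx.getConf("spark.sql.shuffle.partitions", "200"))   # '10'

    print(sqlCtx.tableNames())                                     # ['people']
    print(sqlCtx.tables().filter("tableName = 'people'").first())
    sqlCtx.clearCache()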
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index cda704eea75f5..5c3b7377c33b5 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -21,21 +21,19 @@
import random
import os
from tempfile import NamedTemporaryFile
-from itertools import imap
from py4j.java_collections import ListConverter, MapConverter
from pyspark.context import SparkContext
-from pyspark.rdd import RDD, _prepare_for_python_RDD
-from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
- UTF8Deserializer
+from pyspark.rdd import RDD
+from pyspark.serializers import BatchedSerializer, PickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import *
from pyspark.sql.types import _create_cls, _parse_datatype_json_string
-__all__ = ["DataFrame", "GroupedData", "Column", "Dsl", "SchemaRDD"]
+__all__ = ["DataFrame", "GroupedData", "Column", "SchemaRDD"]
class DataFrame(object):
@@ -76,6 +74,7 @@ def __init__(self, jdf, sql_ctx):
self.sql_ctx = sql_ctx
self._sc = sql_ctx and sql_ctx._sc
self.is_cached = False
+ self._schema = None # initialized lazily
@property
def rdd(self):
@@ -86,7 +85,7 @@ def rdd(self):
if not hasattr(self, '_lazy_rdd'):
jrdd = self._jdf.javaToPython()
rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer()))
- schema = self.schema()
+ schema = self.schema
def applySchema(it):
cls = _create_cls(schema)
@@ -97,7 +96,7 @@ def applySchema(it):
return self._lazy_rdd
def toJSON(self, use_unicode=False):
- """Convert a DataFrame into a MappedRDD of JSON documents; one document per row.
+ """Convert a :class:`DataFrame` into a MappedRDD of JSON documents; one document per row.
>>> df.toJSON().first()
'{"age":2,"name":"Alice"}'
@@ -109,7 +108,7 @@ def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
Files that are written out using this method can be read back in as
- a DataFrame using the L{SQLContext.parquetFile} method.
+ a :class:`DataFrame` using the L{SQLContext.parquetFile} method.
>>> import tempfile, shutil
>>> parquetFile = tempfile.mkdtemp()
@@ -140,24 +139,94 @@ def registerAsTable(self, name):
self.registerTempTable(name)
def insertInto(self, tableName, overwrite=False):
- """Inserts the contents of this DataFrame into the specified table.
+ """Inserts the contents of this :class:`DataFrame` into the specified table.
Optionally overwriting any existing data.
"""
self._jdf.insertInto(tableName, overwrite)
- def saveAsTable(self, tableName):
- """Creates a new table with the contents of this DataFrame."""
- self._jdf.saveAsTable(tableName)
+ def _java_save_mode(self, mode):
+ """Returns the Java save mode based on the Python save mode represented by a string.
+ """
+ jSaveMode = self._sc._jvm.org.apache.spark.sql.SaveMode
+ jmode = jSaveMode.ErrorIfExists
+ mode = mode.lower()
+ if mode == "append":
+ jmode = jSaveMode.Append
+ elif mode == "overwrite":
+ jmode = jSaveMode.Overwrite
+ elif mode == "ignore":
+ jmode = jSaveMode.Ignore
+ elif mode == "error":
+ pass
+ else:
+ raise ValueError(
+ "Only 'append', 'overwrite', 'ignore', and 'error' are acceptable save modes.")
+ return jmode
+
+ def saveAsTable(self, tableName, source=None, mode="append", **options):
+ """Saves the contents of the :class:`DataFrame` to a data source as a table.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the saveAsTable operation when
+ table already exists in the data source. There are four modes:
+
+ * append: Contents of this :class:`DataFrame` are expected to be appended \
+ to existing table.
+ * overwrite: Data in the existing table is expected to be overwritten by \
+ the contents of this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of the \
+ :class:`DataFrame` and to not change the existing table.
+ """
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self.sql_ctx._sc._gateway._gateway_client)
+ self._jdf.saveAsTable(tableName, source, jmode, joptions)
+
+ def save(self, path=None, source=None, mode="append", **options):
+ """Saves the contents of the :class:`DataFrame` to a data source.
+
+ The data source is specified by the `source` and a set of `options`.
+ If `source` is not specified, the default data source configured by
+ spark.sql.sources.default will be used.
+
+ Additionally, mode is used to specify the behavior of the save operation when
+ data already exists in the data source. There are four modes:
+
+ * append: Contents of this :class:`DataFrame` are expected to be appended to existing data.
+ * overwrite: Existing data is expected to be overwritten by the contents of this DataFrame.
+ * error: An exception is expected to be thrown.
+ * ignore: The save operation is expected to not save the contents of \
+ the :class:`DataFrame` and to not change the existing data.
+ """
+ if path is not None:
+ options["path"] = path
+ if source is None:
+ source = self.sql_ctx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ jmode = self._java_save_mode(mode)
+ joptions = MapConverter().convert(options,
+ self._sc._gateway._gateway_client)
+ self._jdf.save(source, jmode, joptions)
+ @property
def schema(self):
- """Returns the schema of this DataFrame (represented by
+ """Returns the schema of this :class:`DataFrame` (represented by
a L{StructType}).
- >>> df.schema()
+ >>> df.schema
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
- return _parse_datatype_json_string(self._jdf.schema().json())
+ if self._schema is None:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ return self._schema
def printSchema(self):
"""Prints out the schema in the tree format.
@@ -170,6 +239,55 @@ def printSchema(self):
"""
print (self._jdf.schema().treeString())
+ def explain(self, extended=False):
+ """
+ Prints the plans (logical and physical) to the console for
+ debugging purposes.
+
+ If extended is False, only prints the physical plan.
+
+ >>> df.explain()
+ PhysicalRDD [age#0,name#1], MapPartitionsRDD[...] at mapPartitions at SQLContext.scala:...
+
+ >>> df.explain(True)
+ == Parsed Logical Plan ==
+ ...
+ == Analyzed Logical Plan ==
+ ...
+ == Optimized Logical Plan ==
+ ...
+ == Physical Plan ==
+ ...
+ == RDD ==
+ """
+ if extended:
+ print self._jdf.queryExecution().toString()
+ else:
+ print self._jdf.queryExecution().executedPlan().toString()
+
+ def isLocal(self):
+ """
+ Returns True if the `collect` and `take` methods can be run locally
+ (without any Spark executors).
+ """
+ return self._jdf.isLocal()
+
+ def show(self, n=20):
+ """
+ Print the first n rows.
+
+ >>> df
+ DataFrame[age: int, name: string]
+ >>> df.show()
+ age name
+ 2 Alice
+ 5 Bob
+ """
+ print self._jdf.showString(n).encode('utf8', 'ignore')
+
+ def __repr__(self):
+ return "DataFrame[%s]" % (", ".join("%s: %s" % c for c in self.dtypes))
+
def count(self):
"""Return the number of elements in this RDD.
@@ -200,7 +318,7 @@ def collect(self):
with open(tempFile.name, 'rb') as tempFile:
rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile))
os.unlink(tempFile.name)
- cls = _create_cls(self.schema())
+ cls = _create_cls(self.schema)
return [cls(r) for r in rs]
def limit(self, num):
@@ -226,18 +344,32 @@ def take(self, num):
return self.limit(num).collect()
def map(self, f):
- """ Return a new RDD by applying a function to each Row, it's a
- shorthand for df.rdd.map()
+ """ Return a new RDD by applying a function to each Row
+
+ It's a shorthand for df.rdd.map()
>>> df.map(lambda p: p.name).collect()
[u'Alice', u'Bob']
"""
return self.rdd.map(f)
+ def flatMap(self, f):
+ """ Return a new RDD by first applying a function to all elements of this,
+ and then flattening the results.
+
+ It's a shorthand for df.rdd.flatMap()
+
+ >>> df.flatMap(lambda p: p.name).collect()
+ [u'A', u'l', u'i', u'c', u'e', u'B', u'o', u'b']
+ """
+ return self.rdd.flatMap(f)
+
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition.
+ It's a shorthand for df.rdd.mapPartitions()
+
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
>>> rdd.mapPartitions(f).sum()
@@ -245,6 +377,31 @@ def mapPartitions(self, f, preservesPartitioning=False):
"""
return self.rdd.mapPartitions(f, preservesPartitioning)
+ def foreach(self, f):
+ """
+ Applies a function to all rows of this DataFrame.
+
+ It's a shorthand for df.rdd.foreach()
+
+ >>> def f(person):
+ ... print person.name
+ >>> df.foreach(f)
+ """
+ return self.rdd.foreach(f)
+
+ def foreachPartition(self, f):
+ """
+ Applies a function to each partition of this DataFrame.
+
+ It's a shorthand for df.rdd.foreachPartition()
+
+ >>> def f(people):
+ ... for person in people:
+ ... print person.name
+ >>> df.foreachPartition(f)
+ """
+ return self.rdd.foreachPartition(f)
+
def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
@@ -278,9 +435,20 @@ def unpersist(self, blocking=True):
def repartition(self, numPartitions):
""" Return a new :class:`DataFrame` that has exactly `numPartitions`
partitions.
+
+ >>> df.repartition(10).rdd.getNumPartitions()
+ 10
"""
- rdd = self._jdf.repartition(numPartitions, None)
- return DataFrame(rdd, self.sql_ctx)
+ return DataFrame(self._jdf.repartition(numPartitions), self.sql_ctx)
+
+ def distinct(self):
+ """
+ Return a new :class:`DataFrame` containing the distinct rows in this DataFrame.
+
+ >>> df.distinct().count()
+ 2L
+ """
+ return DataFrame(self._jdf.distinct(), self.sql_ctx)
def sample(self, withReplacement, fraction, seed=None):
"""
@@ -294,29 +462,14 @@ def sample(self, withReplacement, fraction, seed=None):
rdd = self._jdf.sample(withReplacement, fraction, long(seed))
return DataFrame(rdd, self.sql_ctx)
- # def takeSample(self, withReplacement, num, seed=None):
- # """Return a fixed-size sampled subset of this DataFrame.
- #
- # >>> df = sqlCtx.inferSchema(rdd)
- # >>> df.takeSample(False, 2, 97)
- # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')]
- # """
- # seed = seed if seed is not None else random.randint(0, sys.maxint)
- # with SCCallSiteSync(self.context) as css:
- # bytesInJava = self._jdf \
- # .takeSampleToPython(withReplacement, num, long(seed)) \
- # .iterator()
- # cls = _create_cls(self.schema())
- # return map(cls, self._collect_iterator_through_file(bytesInJava))
-
@property
def dtypes(self):
"""Return all column names and their data types as a list.
>>> df.dtypes
- [('age', 'integer'), ('name', 'string')]
+ [('age', 'int'), ('name', 'string')]
"""
- return [(str(f.name), f.dataType.jsonValue()) for f in self.schema().fields]
+ return [(str(f.name), f.dataType.simpleString()) for f in self.schema.fields]
@property
def columns(self):
@@ -325,12 +478,12 @@ def columns(self):
>>> df.columns
[u'age', u'name']
"""
- return [f.name for f in self.schema().fields]
+ return [f.name for f in self.schema.fields]
def join(self, other, joinExprs=None, joinType=None):
"""
- Join with another DataFrame, using the given join expression.
- The following performs a full outer join between `df1` and `df2`::
+ Join with another :class:`DataFrame`, using the given join expression.
+ The following performs a full outer join between `df1` and `df2`.
:param other: Right side of the join
:param joinExprs: Join expression
@@ -352,13 +505,18 @@ def join(self, other, joinExprs=None, joinType=None):
return DataFrame(jdf, self.sql_ctx)
def sort(self, *cols):
- """ Return a new :class:`DataFrame` sorted by the specified column.
+ """ Return a new :class:`DataFrame` sorted by the specified column(s).
:param cols: The columns or expressions used for sorting
>>> df.sort(df.age.desc()).collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
- >>> df.sortBy(df.age.desc()).collect()
+ >>> df.orderBy(df.age.desc()).collect()
+ [Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
+ >>> from pyspark.sql.functions import *
+ >>> df.sort(asc("age")).collect()
+ [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
+ >>> df.orderBy(desc("age"), "name").collect()
[Row(age=5, name=u'Bob'), Row(age=2, name=u'Alice')]
"""
if not cols:
@@ -368,7 +526,7 @@ def sort(self, *cols):
jdf = self._jdf.sort(self._sc._jvm.PythonUtils.toSeq(jcols))
return DataFrame(jdf, self.sql_ctx)
- sortBy = sort
+ orderBy = sort
def head(self, n=None):
""" Return the first `n` rows or the first row if n is None.
@@ -394,7 +552,7 @@ def first(self):
def __getitem__(self, item):
""" Return the column by given name
- >>> df['age'].collect()
+ >>> df.select(df['age']).collect()
[Row(age=2), Row(age=5)]
>>> df[ ["name", "age"]].collect()
[Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
@@ -403,7 +561,7 @@ def __getitem__(self, item):
"""
if isinstance(item, basestring):
jc = self._jdf.apply(item)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
elif isinstance(item, Column):
return self.filter(item)
elif isinstance(item, list):
@@ -414,19 +572,17 @@ def __getitem__(self, item):
def __getattr__(self, name):
""" Return the column by given name
- >>> df.age.collect()
+ >>> df.select(df.age).collect()
[Row(age=2), Row(age=5)]
"""
if name.startswith("__"):
raise AttributeError(name)
jc = self._jdf.apply(name)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
def select(self, *cols):
""" Selecting a set of expressions.
- >>> df.select().collect()
- [Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('*').collect()
[Row(age=2, name=u'Alice'), Row(age=5, name=u'Bob')]
>>> df.select('name', 'age').collect()
@@ -434,8 +590,6 @@ def select(self, *cols):
>>> df.select(df.name, (df.age + 10).alias('age')).collect()
[Row(name=u'Alice', age=12), Row(name=u'Bob', age=15)]
"""
- if not cols:
- cols = ["*"]
jcols = ListConverter().convert([_to_java_column(c) for c in cols],
self._sc._gateway._gateway_client)
jdf = self._jdf.select(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jcols))
@@ -447,7 +601,7 @@ def selectExpr(self, *expr):
`select` that accepts SQL expressions.
>>> df.selectExpr("age * 2", "abs(age)").collect()
- [Row(('age * 2)=4, Abs('age)=2), Row(('age * 2)=10, Abs('age)=5)]
+ [Row((age * 2)=4, Abs(age)=2), Row((age * 2)=10, Abs(age)=5)]
"""
jexpr = ListConverter().convert(expr, self._sc._gateway._gateway_client)
jdf = self._jdf.selectExpr(self._sc._jvm.PythonUtils.toSeq(jexpr))
@@ -455,7 +609,7 @@ def selectExpr(self, *expr):
def filter(self, condition):
""" Filtering rows using the given condition, which could be
- Column expression or string of SQL expression.
+ a :class:`Column` expression or a SQL expression string.
where() is an alias for filter().
@@ -502,14 +656,14 @@ def agg(self, *exprs):
>>> df.agg({"age": "max"}).collect()
[Row(MAX(age#0)=5)]
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.min(df.age)).collect()
+ >>> from pyspark.sql import functions as F
+ >>> df.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=2)]
"""
return self.groupBy().agg(*exprs)
def unionAll(self, other):
- """ Return a new DataFrame containing union of rows in this
+ """ Return a new :class:`DataFrame` containing union of rows in this
frame and another frame.
This is equivalent to `UNION ALL` in SQL.
@@ -532,19 +686,30 @@ def subtract(self, other):
"""
return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx)
- def addColumn(self, colName, col):
+ def withColumn(self, colName, col):
""" Return a new :class:`DataFrame` by adding a column.
- >>> df.addColumn('age2', df.age + 2).collect()
+ >>> df.withColumn('age2', df.age + 2).collect()
[Row(age=2, name=u'Alice', age2=4), Row(age=5, name=u'Bob', age2=7)]
"""
return self.select('*', col.alias(colName))
- def to_pandas(self):
+ def withColumnRenamed(self, existing, new):
+ """ Rename an existing column to a new name
+
+ >>> df.withColumnRenamed('age', 'age2').collect()
+ [Row(age2=2, name=u'Alice'), Row(age2=5, name=u'Bob')]
+ """
+ cols = [Column(_to_java_column(c)).alias(new)
+ if c == existing else c
+ for c in self.columns]
+ return self.select(*cols)
+
+ def toPandas(self):
"""
Collect all the rows and return a `pandas.DataFrame`.
- >>> df.to_pandas() # doctest: +SKIP
+ >>> df.toPandas() # doctest: +SKIP
age name
0 2 Alice
1 5 Bob
@@ -570,6 +735,18 @@ def _api(self):
return _api
+def df_varargs_api(f):
+ def _api(self, *args):
+ jargs = ListConverter().convert(args,
+ self.sql_ctx._sc._gateway._gateway_client)
+ name = f.__name__
+ jdf = getattr(self._jdf, name)(self.sql_ctx._sc._jvm.PythonUtils.toSeq(jargs))
+ return DataFrame(jdf, self.sql_ctx)
+ _api.__name__ = f.__name__
+ _api.__doc__ = f.__doc__
+ return _api
+
+
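# A self-contained, pure-Python sketch of the df_varargs_api pattern above: the
# decorated method body is only a docstring and the wrapper forwards the method
# name plus *args to a backend call, which in the real decorator is the
# same-named method on the underlying Java DataFrame. The names below are
# illustrative only.
def _varargs_api_demo(f):
    def _api(self, *args):
        return self._backend(f.__name__, args)
    _api.__name__ = f.__name__
    _api.__doc__ = f.__doc__
    return _api


class _FakeGroupedData(object):
    def _backend(self, name, args):
        return "%s(%s)" % (name, ", ".join(args))

    @_varargs_api_demo
    def mean(self, *cols):
        """Docstring only; the decorator supplies the body."""

print(_FakeGroupedData().mean("age", "height"))   # prints: mean(age, height)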
class GroupedData(object):
"""
@@ -592,10 +769,11 @@ def agg(self, *exprs):
name to aggregate methods.
>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"age": "max"}).collect()
- [Row(name=u'Bob', MAX(age#0)=5), Row(name=u'Alice', MAX(age#0)=2)]
- >>> from pyspark.sql import Dsl
- >>> gdf.agg(Dsl.min(df.age)).collect()
+ >>> gdf.agg({"*": "count"}).collect()
+ [Row(name=u'Bob', COUNT(1)=1), Row(name=u'Alice', COUNT(1)=1)]
+
+ >>> from pyspark.sql import functions as F
+ >>> gdf.agg(F.min(df.age)).collect()
[Row(MIN(age#0)=5), Row(MIN(age#0)=2)]
"""
assert exprs, "exprs should not be empty"
@@ -619,40 +797,70 @@ def count(self):
[Row(age=2, count=1), Row(age=5, count=1)]
"""
- @dfapi
- def mean(self):
+ @df_varargs_api
+ def mean(self, *cols):
"""Compute the average value for each numeric columns
- for each group. This is an alias for `avg`."""
+ for each group. This is an alias for `avg`.
- @dfapi
- def avg(self):
+ >>> df.groupBy().mean('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().mean('age', 'height').collect()
+ [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
+ """
+
+ @df_varargs_api
+ def avg(self, *cols):
"""Compute the average value for each numeric columns
- for each group."""
+ for each group.
- @dfapi
- def max(self):
+ >>> df.groupBy().avg('age').collect()
+ [Row(AVG(age#0)=3.5)]
+ >>> df3.groupBy().avg('age', 'height').collect()
+ [Row(AVG(age#4L)=3.5, AVG(height#5L)=82.5)]
+ """
+
+ @df_varargs_api
+ def max(self, *cols):
"""Compute the max value for each numeric columns for
- each group. """
+ each group.
- @dfapi
- def min(self):
+ >>> df.groupBy().max('age').collect()
+ [Row(MAX(age#0)=5)]
+ >>> df3.groupBy().max('age', 'height').collect()
+ [Row(MAX(age#4L)=5, MAX(height#5L)=85)]
+ """
+
+ @df_varargs_api
+ def min(self, *cols):
"""Compute the min value for each numeric column for
- each group."""
+ each group.
- @dfapi
- def sum(self):
+ >>> df.groupBy().min('age').collect()
+ [Row(MIN(age#0)=2)]
+ >>> df3.groupBy().min('age', 'height').collect()
+ [Row(MIN(age#4L)=2, MIN(height#5L)=80)]
+ """
+
+ @df_varargs_api
+ def sum(self, *cols):
"""Compute the sum for each numeric columns for each
- group."""
+ group.
+
+ >>> df.groupBy().sum('age').collect()
+ [Row(SUM(age#0)=7)]
+ >>> df3.groupBy().sum('age', 'height').collect()
+ [Row(SUM(age#4L)=7, SUM(height#5L)=165)]
+ """
def _create_column_from_literal(literal):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.lit(literal)
+ return sc._jvm.functions.lit(literal)
def _create_column_from_name(name):
sc = SparkContext._active_spark_context
- return sc._jvm.Dsl.col(name)
+ return sc._jvm.functions.col(name)
def _to_java_column(col):
@@ -667,15 +875,16 @@ def _unary_op(name, doc="unary operator"):
""" Create a method for given unary operator """
def _(self):
jc = getattr(self._jc, name)()
- return Column(jc, self.sql_ctx)
+ return Column(jc)
_.__doc__ = doc
return _
-def _dsl_op(name, doc=''):
+def _func_op(name, doc=''):
def _(self):
- jc = getattr(self._sc._jvm.Dsl, name)(self._jc)
- return Column(jc, self.sql_ctx)
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(self._jc)
+ return Column(jc)
_.__doc__ = doc
return _
@@ -686,7 +895,7 @@ def _bin_op(name, doc="binary operator"):
def _(self, other):
jc = other._jc if isinstance(other, Column) else other
njc = getattr(self._jc, name)(jc)
- return Column(njc, self.sql_ctx)
+ return Column(njc)
_.__doc__ = doc
return _
@@ -697,19 +906,20 @@ def _reverse_op(name, doc="binary operator"):
def _(self, other):
jother = _create_column_from_literal(other)
jc = getattr(jother, name)(self._jc)
- return Column(jc, self.sql_ctx)
+ return Column(jc)
_.__doc__ = doc
return _
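# A self-contained sketch of the operator-factory pattern used by _unary_op,
# _func_op, _bin_op and _reverse_op above: each factory builds a method that
# delegates to the wrapped value, so Column can define __add__, __neg__, etc. in
# a single line each; the wrappers no longer carry a sql_ctx now that Column is
# detached from DataFrame. The demo class below is illustrative only.
def _bin_op_demo(name, doc="binary operator"):
    def _(self, other):
        other = other.value if isinstance(other, _Wrapped) else other
        return _Wrapped(getattr(self.value, name)(other))
    _.__doc__ = doc
    return _


class _Wrapped(object):
    def __init__(self, value):
        self.value = value
    __add__ = _bin_op_demo("__add__")
    __mul__ = _bin_op_demo("__mul__")

print((_Wrapped(2) + _Wrapped(3)).value)   # 5
print((_Wrapped(2) * 10).value)            # 20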
-class Column(DataFrame):
+class Column(object):
"""
A column in a DataFrame.
- `Column` instances can be created by::
+ :class:`Column` instances can be created by::
# 1. Select a column out of a DataFrame
+
df.colName
df["colName"]
@@ -718,12 +928,11 @@ class Column(DataFrame):
1 / df.colName
"""
- def __init__(self, jc, sql_ctx=None):
+ def __init__(self, jc):
self._jc = jc
- super(Column, self).__init__(jc, sql_ctx)
# arithmetic operators
- __neg__ = _dsl_op("negate")
+ __neg__ = _func_op("negate")
__add__ = _bin_op("plus")
__sub__ = _bin_op("minus")
__mul__ = _bin_op("multiply")
@@ -747,7 +956,7 @@ def __init__(self, jc, sql_ctx=None):
# so use bitwise operators as boolean operators
__and__ = _bin_op('and')
__or__ = _bin_op('or')
- __invert__ = _dsl_op('not')
+ __invert__ = _func_op('not')
__rand__ = _bin_op("and")
__ror__ = _bin_op("or")
@@ -764,12 +973,12 @@ def __init__(self, jc, sql_ctx=None):
def substr(self, startPos, length):
"""
- Return a Column which is a substring of the column
+ Return a :class:`Column` which is a substring of the column
:param startPos: start position (int or Column)
:param length: length of the substring (int or Column)
- >>> df.name.substr(1, 3).collect()
+ >>> df.select(df.name.substr(1, 3).alias("col")).collect()
[Row(col=u'Ali'), Row(col=u'Bob')]
"""
if type(startPos) != type(length):
@@ -780,13 +989,15 @@ def substr(self, startPos, length):
jc = self._jc.substr(startPos._jc, length._jc)
else:
raise TypeError("Unexpected type: %s" % type(startPos))
- return Column(jc, self.sql_ctx)
+ return Column(jc)
__getslice__ = substr
# order
- asc = _unary_op("asc")
- desc = _unary_op("desc")
+ asc = _unary_op("asc", "Returns a sort expression based on the"
+ " ascending order of the given column name.")
+ desc = _unary_op("desc", "Returns a sort expression based on the"
+ " descending order of the given column name.")
isNull = _unary_op("isNull", "True if the current expression is null.")
isNotNull = _unary_op("isNotNull", "True if the current expression is not null.")
@@ -794,10 +1005,10 @@ def substr(self, startPos, length):
def alias(self, alias):
"""Return a alias for this column
- >>> df.age.alias("age2").collect()
+ >>> df.select(df.age.alias("age2")).collect()
[Row(age2=2), Row(age2=5)]
"""
- return Column(getattr(self._jc, "as")(alias), self.sql_ctx)
+ return Column(getattr(self._jc, "as")(alias))
def cast(self, dataType):
""" Convert the column into type `dataType`
@@ -807,147 +1018,19 @@ def cast(self, dataType):
>>> df.select(df.age.cast(StringType()).alias('ages')).collect()
[Row(ages=u'2'), Row(ages=u'5')]
"""
- if self.sql_ctx is None:
- sc = SparkContext._active_spark_context
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- else:
- ssql_ctx = self.sql_ctx._ssql_ctx
if isinstance(dataType, basestring):
jc = self._jc.cast(dataType)
elif isinstance(dataType, DataType):
+ sc = SparkContext._active_spark_context
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
jdt = ssql_ctx.parseDataType(dataType.json())
jc = self._jc.cast(jdt)
- return Column(jc, self.sql_ctx)
-
- def to_pandas(self):
- """
- Return a pandas.Series from the column
-
- >>> df.age.to_pandas() # doctest: +SKIP
- 0 2
- 1 5
- dtype: int64
- """
- import pandas as pd
- data = [c for c, in self.collect()]
- return pd.Series(data)
-
-
-def _aggregate_func(name, doc=""):
- """ Create a function for aggregator by name"""
- def _(col):
- sc = SparkContext._active_spark_context
- jc = getattr(sc._jvm.Dsl, name)(_to_java_column(col))
- return Column(jc)
- _.__name__ = name
- _.__doc__ = doc
- return staticmethod(_)
-
-
-class UserDefinedFunction(object):
- def __init__(self, func, returnType):
- self.func = func
- self.returnType = returnType
- self._broadcast = None
- self._judf = self._create_judf()
-
- def _create_judf(self):
- f = self.func # put it in closure `func`
- func = lambda _, it: imap(lambda x: f(*x), it)
- ser = AutoBatchedSerializer(PickleSerializer())
- command = (func, None, ser, ser)
- sc = SparkContext._active_spark_context
- pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
- ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
- jdt = ssql_ctx.parseDataType(self.returnType.json())
- judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
- includes, sc.pythonExec, broadcast_vars,
- sc._javaAccumulator, jdt)
- return judf
-
- def __del__(self):
- if self._broadcast is not None:
- self._broadcast.unpersist()
- self._broadcast = None
-
- def __call__(self, *cols):
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
-
-class Dsl(object):
- """
- A collections of builtin aggregators
- """
- DSLS = {
- 'lit': 'Creates a :class:`Column` of literal value.',
- 'col': 'Returns a :class:`Column` based on the given column name.',
- 'column': 'Returns a :class:`Column` based on the given column name.',
- 'upper': 'Converts a string expression to upper case.',
- 'lower': 'Converts a string expression to upper case.',
- 'sqrt': 'Computes the square root of the specified float value.',
- 'abs': 'Computes the absolutle value.',
-
- 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
- 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
- 'first': 'Aggregate function: returns the first value in a group.',
- 'last': 'Aggregate function: returns the last value in a group.',
- 'count': 'Aggregate function: returns the number of items in a group.',
- 'sum': 'Aggregate function: returns the sum of all values in the expression.',
- 'avg': 'Aggregate function: returns the average of the values in a group.',
- 'mean': 'Aggregate function: returns the average of the values in a group.',
- 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
- }
-
- for _name, _doc in DSLS.items():
- locals()[_name] = _aggregate_func(_name, _doc)
- del _name, _doc
-
- @staticmethod
- def countDistinct(col, *cols):
- """ Return a new Column for distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.countDistinct(df.age, df.name).alias('c')).collect()
- [Row(c=2)]
-
- >>> df.agg(Dsl.countDistinct("age", "name").alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- jcols = ListConverter().convert([_to_java_column(c) for c in cols],
- sc._gateway._gateway_client)
- jc = sc._jvm.Dsl.countDistinct(_to_java_column(col),
- sc._jvm.PythonUtils.toSeq(jcols))
- return Column(jc)
-
- @staticmethod
- def approxCountDistinct(col, rsd=None):
- """ Return a new Column for approxiate distinct count of (col, *cols)
-
- >>> from pyspark.sql import Dsl
- >>> df.agg(Dsl.approxCountDistinct(df.age).alias('c')).collect()
- [Row(c=2)]
- """
- sc = SparkContext._active_spark_context
- if rsd is None:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col))
else:
- jc = sc._jvm.Dsl.approxCountDistinct(_to_java_column(col), rsd)
+ raise TypeError("unexpected type: %s" % type(dataType))
return Column(jc)
- @staticmethod
- def udf(f, returnType=StringType()):
- """Create a user defined function (UDF)
-
- >>> slen = Dsl.udf(lambda s: len(s), IntegerType())
- >>> df.select(slen(df.name).alias('slen')).collect()
- [Row(slen=5), Row(slen=3)]
- """
- return UserDefinedFunction(f, returnType)
+ def __repr__(self):
+ return 'Column<%s>' % self._jc.toString().encode('utf8')
def _test():
@@ -958,13 +1041,16 @@ def _test():
globs = pyspark.sql.dataframe.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
- globs['sqlCtx'] = sqlCtx = SQLContext(sc)
- rdd2 = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)])
- rdd3 = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)])
- globs['df'] = sqlCtx.inferSchema(rdd2)
- globs['df2'] = sqlCtx.inferSchema(rdd3)
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')])\
+ .toDF(StructType([StructField('age', IntegerType()),
+ StructField('name', StringType())]))
+ globs['df2'] = sc.parallelize([Row(name='Tom', height=80), Row(name='Bob', height=85)]).toDF()
+ globs['df3'] = sc.parallelize([Row(name='Alice', age=2, height=80),
+ Row(name='Bob', age=5, height=85)]).toDF()
(failure_count, test_count) = doctest.testmod(
- pyspark.sql.dataframe, globs=globs, optionflags=doctest.ELLIPSIS)
+ pyspark.sql.dataframe, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
globs['sc'].stop()
if failure_count:
exit(-1)
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
new file mode 100644
index 0000000000000..5873f09ae3275
--- /dev/null
+++ b/python/pyspark/sql/functions.py
@@ -0,0 +1,174 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin functions
+"""
+
+from itertools import imap
+
+from py4j.java_collections import ListConverter
+
+from pyspark import SparkContext
+from pyspark.rdd import _prepare_for_python_RDD
+from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
+from pyspark.sql.types import StringType
+from pyspark.sql.dataframe import Column, _to_java_column
+
+
+__all__ = ['countDistinct', 'approxCountDistinct', 'udf']
+
+
+def _create_function(name, doc=""):
+ """ Create a function for aggregator by name"""
+ def _(col):
+ sc = SparkContext._active_spark_context
+ jc = getattr(sc._jvm.functions, name)(col._jc if isinstance(col, Column) else col)
+ return Column(jc)
+ _.__name__ = name
+ _.__doc__ = doc
+ return _
+
+
+_functions = {
+ 'lit': 'Creates a :class:`Column` of literal value.',
+ 'col': 'Returns a :class:`Column` based on the given column name.',
+ 'column': 'Returns a :class:`Column` based on the given column name.',
+ 'asc': 'Returns a sort expression based on the ascending order of the given column name.',
+ 'desc': 'Returns a sort expression based on the descending order of the given column name.',
+
+ 'upper': 'Converts a string expression to upper case.',
+ 'lower': 'Converts a string expression to lower case.',
+ 'sqrt': 'Computes the square root of the specified float value.',
+ 'abs': 'Computes the absolute value.',
+
+ 'max': 'Aggregate function: returns the maximum value of the expression in a group.',
+ 'min': 'Aggregate function: returns the minimum value of the expression in a group.',
+ 'first': 'Aggregate function: returns the first value in a group.',
+ 'last': 'Aggregate function: returns the last value in a group.',
+ 'count': 'Aggregate function: returns the number of items in a group.',
+ 'sum': 'Aggregate function: returns the sum of all values in the expression.',
+ 'avg': 'Aggregate function: returns the average of the values in a group.',
+ 'mean': 'Aggregate function: returns the average of the values in a group.',
+ 'sumDistinct': 'Aggregate function: returns the sum of distinct values in the expression.',
+}
+
+
+for _name, _doc in _functions.items():
+ globals()[_name] = _create_function(_name, _doc)
+del _name, _doc
+__all__ += _functions.keys()
+__all__.sort()
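# Usage sketch of the generated module-level functions; it assumes the doctest
# `df` set up in _test() below, with `age` and `name` columns. Note that
# importing max/min this way shadows the Python builtins in the local scope.
from pyspark.sql.functions import col, sqrt, max, min

df.select(sqrt(col("age"))).collect()       # sqrt/col are generated from _functions
df.agg(max("age"), min("age")).collect()    # aggregate variants also accept column names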
+
+
+def countDistinct(col, *cols):
+ """ Return a new Column for distinct count of `col` or `cols`
+
+ >>> df.agg(countDistinct(df.age, df.name).alias('c')).collect()
+ [Row(c=2)]
+
+ >>> df.agg(countDistinct("age", "name").alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols], sc._gateway._gateway_client)
+ jc = sc._jvm.functions.countDistinct(_to_java_column(col), sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def approxCountDistinct(col, rsd=None):
+ """ Return a new Column for approximate distinct count of `col`
+
+ >>> df.agg(approxCountDistinct(df.age).alias('c')).collect()
+ [Row(c=2)]
+ """
+ sc = SparkContext._active_spark_context
+ if rsd is None:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col))
+ else:
+ jc = sc._jvm.functions.approxCountDistinct(_to_java_column(col), rsd)
+ return Column(jc)
+
+
+class UserDefinedFunction(object):
+ """
+ User defined function in Python
+ """
+ def __init__(self, func, returnType):
+ self.func = func
+ self.returnType = returnType
+ self._broadcast = None
+ self._judf = self._create_judf()
+
+ def _create_judf(self):
+ f = self.func # put it in closure `func`
+ func = lambda _, it: imap(lambda x: f(*x), it)
+ ser = AutoBatchedSerializer(PickleSerializer())
+ command = (func, None, ser, ser)
+ sc = SparkContext._active_spark_context
+ pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self)
+ ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc())
+ jdt = ssql_ctx.parseDataType(self.returnType.json())
+ judf = sc._jvm.UserDefinedPythonFunction(f.__name__, bytearray(pickled_command), env,
+ includes, sc.pythonExec, broadcast_vars,
+ sc._javaAccumulator, jdt)
+ return judf
+
+ def __del__(self):
+ if self._broadcast is not None:
+ self._broadcast.unpersist()
+ self._broadcast = None
+
+ def __call__(self, *cols):
+ sc = SparkContext._active_spark_context
+ jcols = ListConverter().convert([_to_java_column(c) for c in cols],
+ sc._gateway._gateway_client)
+ jc = self._judf.apply(sc._jvm.PythonUtils.toSeq(jcols))
+ return Column(jc)
+
+
+def udf(f, returnType=StringType()):
+ """Create a user defined function (UDF)
+
+ >>> from pyspark.sql.types import IntegerType
+ >>> slen = udf(lambda s: len(s), IntegerType())
+ >>> df.select(slen(df.name).alias('slen')).collect()
+ [Row(slen=5), Row(slen=3)]
+ """
+ return UserDefinedFunction(f, returnType)
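# A further usage sketch: because UserDefinedFunction.__call__ accepts *cols, a
# UDF may take several columns. It assumes the doctest `df` below, with `age`
# and `name` columns; the `label` name is illustrative only.
from pyspark.sql.types import StringType

label = udf(lambda age, name: "%s:%d" % (name, age), StringType())
df.select(label(df.age, df.name).alias("label")).collect()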
+
+
+def _test():
+ import doctest
+ from pyspark.context import SparkContext
+ from pyspark.sql import Row, SQLContext
+ import pyspark.sql.functions
+ globs = pyspark.sql.functions.__dict__.copy()
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlCtx'] = SQLContext(sc)
+ globs['df'] = sc.parallelize([Row(name='Alice', age=2), Row(name='Bob', age=5)]).toDF()
+ (failure_count, test_count) = doctest.testmod(
+ pyspark.sql.functions, globs=globs,
+ optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index d25c6365ed067..2720439416682 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -24,6 +24,9 @@
import pydoc
import shutil
import tempfile
+import pickle
+
+import py4j
if sys.version_info[:2] <= (2, 6):
try:
@@ -34,10 +37,9 @@
else:
import unittest
-
-from pyspark.sql import SQLContext, Column
-from pyspark.sql.types import IntegerType, Row, ArrayType, StructType, StructField, \
- UserDefinedType, DoubleType, LongType
+from pyspark.sql import SQLContext, HiveContext, Column, Row
+from pyspark.sql.types import *
+from pyspark.sql.types import UserDefinedType, _infer_type
from pyspark.tests import ReusedPySparkTestCase
@@ -87,6 +89,14 @@ def __eq__(self, other):
other.x == self.x and other.y == self.y
+class DataTypeTests(unittest.TestCase):
+ # regression test for SPARK-6055
+ def test_data_type_eq(self):
+ lt = LongType()
+ lt2 = pickle.loads(pickle.dumps(LongType()))
+ self.assertEquals(lt, lt2)
+
+
class SQLTests(ReusedPySparkTestCase):
@classmethod
@@ -97,7 +107,7 @@ def setUpClass(cls):
cls.sqlCtx = SQLContext(cls.sc)
cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
rdd = cls.sc.parallelize(cls.testData)
- cls.df = cls.sqlCtx.inferSchema(rdd)
+ cls.df = rdd.toDF()
@classmethod
def tearDownClass(cls):
@@ -111,14 +121,14 @@ def test_udf(self):
def test_udf2(self):
self.sqlCtx.registerFunction("strlen", lambda string: len(string), IntegerType())
- self.sqlCtx.inferSchema(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
+ self.sqlCtx.createDataFrame(self.sc.parallelize([Row(a="test")])).registerTempTable("test")
[res] = self.sqlCtx.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect()
self.assertEqual(4, res[0])
def test_udf_with_array_type(self):
d = [Row(l=range(3), d={"key": range(5)})]
rdd = self.sc.parallelize(d)
- self.sqlCtx.inferSchema(rdd).registerTempTable("test")
+ self.sqlCtx.createDataFrame(rdd).registerTempTable("test")
self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType()))
self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType())
[(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect()
@@ -139,7 +149,7 @@ def test_basic_functions(self):
df = self.sqlCtx.jsonRDD(rdd)
df.count()
df.collect()
- df.schema()
+ df.schema
# cache and checkpoint
self.assertFalse(df.is_cached)
@@ -156,17 +166,17 @@ def test_basic_functions(self):
def test_apply_schema_to_row(self):
df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
- df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema())
+ df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema)
self.assertEqual(df.collect(), df2.collect())
rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
- df3 = self.sqlCtx.applySchema(rdd, df.schema())
+ df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
row = df.head()
self.assertEqual(1, len(row.l))
self.assertEqual(1, row.l[0].a)
@@ -185,28 +195,89 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual("2", row.d)
def test_infer_schema(self):
- d = [Row(l=[], d={}),
+ d = [Row(l=[], d={}, s=None),
Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sqlCtx.createDataFrame(rdd)
self.assertEqual([], df.map(lambda r: r.l).first())
self.assertEqual([None, ""], df.map(lambda r: r.s).collect())
df.registerTempTable("test")
result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
- df2 = self.sqlCtx.inferSchema(rdd, 1.0)
- self.assertEqual(df.schema(), df2.schema())
+ df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0)
+ self.assertEqual(df.schema, df2.schema)
self.assertEqual({}, df2.map(lambda r: r.d).first())
self.assertEqual([None, ""], df2.map(lambda r: r.s).collect())
df2.registerTempTable("test2")
result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
self.assertEqual(1, result.head()[0])
+ def test_infer_nested_schema(self):
+ NestedRow = Row("f1", "f2")
+ nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}),
+ NestedRow([2, 3], {"row2": 2.0})])
+ df = self.sqlCtx.inferSchema(nestedRdd1)
+ self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0])
+
+ nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]),
+ NestedRow([[2, 3], [3, 4]], [2, 3])])
+ df = self.sqlCtx.inferSchema(nestedRdd2)
+ self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0])
+
+ from collections import namedtuple
+ CustomRow = namedtuple('CustomRow', 'field1 field2')
+ rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"),
+ CustomRow(field1=2, field2="row2"),
+ CustomRow(field1=3, field2="row3")])
+ df = self.sqlCtx.inferSchema(rdd)
+ self.assertEquals(Row(field1=1, field2=u'row1'), df.first())
+
+ def test_apply_schema(self):
+ from datetime import date, datetime
+ rdd = self.sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ date(2010, 1, 1), datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3], None)])
+ schema = StructType([
+ StructField("byte1", ByteType(), False),
+ StructField("byte2", ByteType(), False),
+ StructField("short1", ShortType(), False),
+ StructField("short2", ShortType(), False),
+ StructField("int1", IntegerType(), False),
+ StructField("float1", FloatType(), False),
+ StructField("date1", DateType(), False),
+ StructField("time1", TimestampType(), False),
+ StructField("map1", MapType(StringType(), IntegerType(), False), False),
+ StructField("struct1", StructType([StructField("b", ShortType(), False)]), False),
+ StructField("list1", ArrayType(ByteType(), False), False),
+ StructField("null1", DoubleType(), True)])
+ df = self.sqlCtx.applySchema(rdd, schema)
+ results = df.map(lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int1, x.float1, x.date1,
+ x.time1, x.map1["a"], x.struct1.b, x.list1, x.null1))
+ r = (127, -128, -32768, 32767, 2147483647, 1.0, date(2010, 1, 1),
+ datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ self.assertEqual(r, results.first())
+
+ df.registerTempTable("table2")
+ r = self.sqlCtx.sql("SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " +
+ "short1 + 1 AS short1, short2 - 1 AS short2, int1 - 1 AS int1, " +
+ "float1 + 1.5 as float1 FROM table2").first()
+
+ self.assertEqual((126, -127, -32767, 32766, 2147483646, 2.5), tuple(r))
+
+ from pyspark.sql.types import _parse_schema_abstract, _infer_schema_type
+ rdd = self.sc.parallelize([(127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1),
+ {"a": 1}, (2,), [1, 2, 3])])
+ abstract = "byte1 short1 float1 time1 map1{} struct1(b) list1[]"
+ schema = _parse_schema_abstract(abstract)
+ typedSchema = _infer_schema_type(rdd.first(), schema)
+ df = self.sqlCtx.applySchema(rdd, typedSchema)
+ r = (127, -32768, 1.0, datetime(2010, 1, 1, 1, 1, 1), {"a": 1}, Row(b=2), [1, 2, 3])
+ self.assertEqual(r, tuple(df.first()))
+
def test_struct_in_map(self):
d = [Row(m={Row(i=1): Row(s="")})]
- rdd = self.sc.parallelize(d)
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sc.parallelize(d).toDF()
k, v = df.head().m.items()[0]
self.assertEqual(1, k.i)
self.assertEqual("", v.s)
@@ -214,8 +285,7 @@ def test_struct_in_map(self):
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
+ df = self.sc.parallelize([row]).toDF()
df.registerTempTable("test")
row = self.sqlCtx.sql("select l, d from test").head()
self.assertEqual(1, row.asDict()["l"][0].a)
@@ -224,9 +294,8 @@ def test_convert_row_to_dict(self):
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df = self.sqlCtx.inferSchema(rdd)
- schema = df.schema()
+ df = self.sc.parallelize([row]).toDF()
+ schema = df.schema
field = [f for f in schema.fields if f.name == "point"][0]
self.assertEqual(type(field.dataType), ExamplePointUDT)
df.registerTempTable("labeled_point")
@@ -239,15 +308,14 @@ def test_apply_schema_with_udt(self):
rdd = self.sc.parallelize([row])
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
- df = self.sqlCtx.applySchema(rdd, schema)
+ df = rdd.toDF(schema)
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
- rdd = self.sc.parallelize([row])
- df0 = self.sqlCtx.inferSchema(rdd)
+ df0 = self.sc.parallelize([row]).toDF()
output_dir = os.path.join(self.tempdir.name, "labeled_point")
df0.saveAsParquetFile(output_dir)
df1 = self.sqlCtx.parquetFile(output_dir)
@@ -281,10 +349,42 @@ def test_aggregator(self):
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
- from pyspark.sql import Dsl
- self.assertEqual((0, u'99'), tuple(g.agg(Dsl.first(df.key), Dsl.last(df.value)).first()))
- self.assertTrue(95 < g.agg(Dsl.approxCountDistinct(df.key)).first()[0])
- self.assertEqual(100, g.agg(Dsl.countDistinct(df.value)).first()[0])
+ from pyspark.sql import functions
+ self.assertEqual((0, u'99'),
+ tuple(g.agg(functions.first(df.key), functions.last(df.value)).first()))
+ self.assertTrue(95 < g.agg(functions.approxCountDistinct(df.key)).first()[0])
+ self.assertEqual(100, g.agg(functions.countDistinct(df.value)).first()[0])
+
+ def test_save_and_load(self):
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.save(tmpPath, "org.apache.spark.sql.json", "error")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json", schema)
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+
+ df.save(tmpPath, "org.apache.spark.sql.json", "overwrite")
+ actual = self.sqlCtx.load(tmpPath, "org.apache.spark.sql.json")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ df.save(source="org.apache.spark.sql.json", mode="overwrite", path=tmpPath,
+ noUse="this options will not be used in save.")
+ actual = self.sqlCtx.load(source="org.apache.spark.sql.json", path=tmpPath,
+ noUse="this options will not be used in load.")
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ actual = self.sqlCtx.load(path=tmpPath)
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
def test_help_command(self):
# Regression test for SPARK-5464
@@ -295,6 +395,102 @@ def test_help_command(self):
pydoc.render_doc(df.foo)
pydoc.render_doc(df.take(1))
+ def test_infer_long_type(self):
+ longrow = [Row(f1='a', f2=100000000000000)]
+ df = self.sc.parallelize(longrow).toDF()
+ self.assertEqual(df.schema.fields[1].dataType, LongType())
+
+ # saving this as Parquet caused issues as well.
+ output_dir = os.path.join(self.tempdir.name, "infer_long_type")
+ df.saveAsParquetFile(output_dir)
+ df1 = self.sqlCtx.parquetFile(output_dir)
+ self.assertEquals('a', df1.first().f1)
+ self.assertEquals(100000000000000, df1.first().f2)
+
+ self.assertEqual(_infer_type(1), LongType())
+ self.assertEqual(_infer_type(2**10), LongType())
+ self.assertEqual(_infer_type(2**20), LongType())
+ self.assertEqual(_infer_type(2**31 - 1), LongType())
+ self.assertEqual(_infer_type(2**31), LongType())
+ self.assertEqual(_infer_type(2**61), LongType())
+ self.assertEqual(_infer_type(2**71), LongType())
+
+
+class HiveContextSQLTests(ReusedPySparkTestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ try:
+ cls.sc._jvm.org.apache.hadoop.hive.conf.HiveConf()
+ except py4j.protocol.Py4JError:
+ cls.sqlCtx = None
+ return
+ os.unlink(cls.tempdir.name)
+ _scala_HiveContext =\
+ cls.sc._jvm.org.apache.spark.sql.hive.test.TestHiveContext(cls.sc._jsc.sc())
+ cls.sqlCtx = HiveContext(cls.sc, _scala_HiveContext)
+ cls.testData = [Row(key=i, value=str(i)) for i in range(100)]
+ cls.df = cls.sc.parallelize(cls.testData).toDF()
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
+ def test_save_and_load_table(self):
+ if self.sqlCtx is None:
+ return # no hive available, skipped
+
+ df = self.df
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "append", path=tmpPath)
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", tmpPath,
+ "org.apache.spark.sql.json")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ df.saveAsTable("savedJsonTable", "org.apache.spark.sql.json", "overwrite", path=tmpPath)
+ schema = StructType([StructField("value", StringType(), True)])
+ actual = self.sqlCtx.createExternalTable("externalJsonTable",
+ source="org.apache.spark.sql.json",
+ schema=schema, path=tmpPath,
+ noUse="this options will not be used")
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.select("value").collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.select("value").collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+
+ defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default",
+ "org.apache.spark.sql.parquet")
+ self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json")
+ df.saveAsTable("savedJsonTable", path=tmpPath, mode="overwrite")
+ actual = self.sqlCtx.createExternalTable("externalJsonTable", path=tmpPath)
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM savedJsonTable").collect()))
+ self.assertTrue(
+ sorted(df.collect()) ==
+ sorted(self.sqlCtx.sql("SELECT * FROM externalJsonTable").collect()))
+ self.assertTrue(sorted(df.collect()) == sorted(actual.collect()))
+ self.sqlCtx.sql("DROP TABLE savedJsonTable")
+ self.sqlCtx.sql("DROP TABLE externalJsonTable")
+ self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName)
+
+ shutil.rmtree(tmpPath)
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 41afefe48ee5e..31a861e1feb46 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -21,6 +21,7 @@
import warnings
import json
import re
+import weakref
from array import array
from operator import itemgetter
@@ -28,7 +29,7 @@
__all__ = [
"DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
"TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
- "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", ]
+ "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
class DataType(object):
@@ -42,8 +43,7 @@ def __hash__(self):
return hash(str(self))
def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@@ -52,6 +52,9 @@ def __ne__(self, other):
def typeName(cls):
return cls.__name__[:-4].lower()
+ def simpleString(self):
+ return self.typeName()
+
def jsonValue(self):
return self.typeName()
@@ -61,6 +64,8 @@ def json(self):
sort_keys=True)
+# This singleton pattern does not work with pickle; you will get
+# another object back after pickling and unpickling.
class PrimitiveTypeSingleton(type):
"""Metaclass for PrimitiveType"""
@@ -79,10 +84,6 @@ class PrimitiveType(DataType):
__metaclass__ = PrimitiveTypeSingleton
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
class NullType(PrimitiveType):
@@ -145,6 +146,12 @@ def __init__(self, precision=None, scale=None):
self.scale = scale
self.hasPrecisionInfo = precision is not None
+ def simpleString(self):
+ if self.hasPrecisionInfo:
+ return "decimal(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "decimal(10,0)"
+
def jsonValue(self):
if self.hasPrecisionInfo:
return "decimal(%d,%d)" % (self.precision, self.scale)
@@ -180,6 +187,8 @@ class ByteType(PrimitiveType):
The data type representing int values with 1 signed byte.
"""
+ def simpleString(self):
+ return 'tinyint'
class IntegerType(PrimitiveType):
@@ -188,6 +197,8 @@ class IntegerType(PrimitiveType):
The data type representing int values.
"""
+ def simpleString(self):
+ return 'int'
class LongType(PrimitiveType):
@@ -198,6 +209,8 @@ class LongType(PrimitiveType):
beyond the range of [-9223372036854775808, 9223372036854775807],
please use DecimalType.
"""
+ def simpleString(self):
+ return 'bigint'
class ShortType(PrimitiveType):
@@ -206,6 +219,8 @@ class ShortType(PrimitiveType):
The data type representing int values with 2 signed bytes.
"""
+ def simpleString(self):
+ return 'smallint'
class ArrayType(DataType):
@@ -225,14 +240,18 @@ def __init__(self, elementType, containsNull=True):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
- >>> ArrayType(StringType) == ArrayType(StringType, True)
+ >>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
+ >>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
+ assert isinstance(elementType, DataType), "elementType should be DataType"
self.elementType = elementType
self.containsNull = containsNull
+ def simpleString(self):
+ return 'array<%s>' % self.elementType.simpleString()
+
def __repr__(self):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
@@ -272,17 +291,22 @@ def __init__(self, keyType, valueType, valueContainsNull=True):
:param valueContainsNull: indicates whether values contain
null values.
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
+ >>> (MapType(StringType(), IntegerType())
+ ... == MapType(StringType(), IntegerType(), True))
True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
+ >>> (MapType(StringType(), IntegerType(), False)
+ ... == MapType(StringType(), FloatType()))
False
"""
+ assert isinstance(keyType, DataType), "keyType should be DataType"
+ assert isinstance(valueType, DataType), "valueType should be DataType"
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
+ def simpleString(self):
+ return 'map<%s,%s>' % (self.keyType.simpleString(), self.valueType.simpleString())
+
def __repr__(self):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
@@ -325,18 +349,22 @@ def __init__(self, name, dataType, nullable=True, metadata=None):
to simple type that can be serialized to JSON
automatically
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f1", StringType(), True))
True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f2", StringType(), True))
False
"""
+ assert isinstance(dataType, DataType), "dataType should be DataType"
self.name = name
self.dataType = dataType
self.nullable = nullable
self.metadata = metadata or {}
+ def simpleString(self):
+ return '%s:%s' % (self.name, self.dataType.simpleString())
+
def __repr__(self):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
@@ -367,18 +395,22 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True),
+ ... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
+ assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
self.fields = fields
+ def simpleString(self):
+ return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))
+
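# A quick check of the simpleString() forms added above; it needs only
# pyspark.sql.types, no SparkContext.
from pyspark.sql.types import (StructType, StructField, IntegerType, ArrayType,
                               MapType, StringType, DecimalType)

schema = StructType([
    StructField("age", IntegerType(), True),
    StructField("tags", ArrayType(StringType()), True),
    StructField("scores", MapType(StringType(), DecimalType(10, 2)), True),
])
print(schema.simpleString())
# struct<age:int,tags:array<string>,scores:map<string,decimal(10,2)>>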
def __repr__(self):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
@@ -435,6 +467,9 @@ def deserialize(self, datum):
"""
raise NotImplementedError("UDT must implement deserialize().")
+ def simpleString(self):
+ return 'null'
+
def json(self):
return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
@@ -473,20 +508,24 @@ def __eq__(self, other):
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
+ >>> import pickle
>>> def check_datatype(datatype):
+ ... pickled = pickle.loads(pickle.dumps(datatype))
+ ... assert datatype == pickled
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
- ... return datatype == python_datatype
- >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
- True
+ ... assert datatype == python_datatype
+ >>> for cls in _all_primitive_types.values():
+ ... check_datatype(cls())
+
>>> # Simple ArrayType.
>>> simple_arraytype = ArrayType(StringType(), True)
>>> check_datatype(simple_arraytype)
- True
+
>>> # Simple MapType.
>>> simple_maptype = MapType(StringType(), LongType())
>>> check_datatype(simple_maptype)
- True
+
>>> # Simple StructType.
>>> simple_structtype = StructType([
... StructField("a", DecimalType(), False),
@@ -494,7 +533,7 @@ def _parse_datatype_json_string(json_string):
... StructField("c", LongType(), True),
... StructField("d", BinaryType(), False)])
>>> check_datatype(simple_structtype)
- True
+
>>> # Complex StructType.
>>> complex_structtype = StructType([
... StructField("simpleArray", simple_arraytype, True),
@@ -503,22 +542,20 @@ def _parse_datatype_json_string(json_string):
... StructField("boolean", BooleanType(), False),
... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
- True
+
>>> # Complex ArrayType.
>>> complex_arraytype = ArrayType(complex_structtype, True)
>>> check_datatype(complex_arraytype)
- True
+
>>> # Complex MapType.
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
- True
+
>>> check_datatype(ExamplePointUDT())
- True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
- True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -551,7 +588,7 @@ def _parse_datatype_json_value(json_value):
_type_mappings = {
type(None): NullType,
bool: BooleanType,
- int: IntegerType,
+ int: LongType,
long: LongType,
float: DoubleType,
str: StringType,
@@ -572,7 +609,7 @@ def _infer_type(obj):
ExamplePointUDT
"""
if obj is None:
- raise ValueError("Can not infer type for None")
+ return NullType()
if hasattr(obj, '__UDT__'):
return obj.__UDT__
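# Quick check of the inference changes above: plain Python ints now map to
# LongType and None no longer raises. The private _infer_type helper is used
# purely for illustration.
from pyspark.sql.types import _infer_type, LongType, NullType

assert _infer_type(1) == LongType()
assert _infer_type(2 ** 40) == LongType()
assert _infer_type(None) == NullType()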
@@ -605,15 +642,14 @@ def _infer_schema(row):
if isinstance(row, dict):
items = sorted(row.items())
- elif isinstance(row, tuple):
+ elif isinstance(row, (tuple, list)):
if hasattr(row, "_fields"): # namedtuple
items = zip(row._fields, tuple(row))
elif hasattr(row, "__FIELDS__"): # Row
items = zip(row.__FIELDS__, tuple(row))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in row):
- items = row
else:
- raise ValueError("Can't infer schema from tuple")
+ names = ['_%d' % i for i in range(1, len(row) + 1)]
+ items = zip(names, row)
elif hasattr(row, "__dict__"): # object
items = sorted(row.__dict__.items())
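# Illustration of the tuple/list branch above: rows without field names now get
# generated names _1, _2, ... instead of raising. Again a private helper,
# _infer_schema, is used for demonstration only.
from pyspark.sql.types import _infer_schema

print(_infer_schema((1, "a")).simpleString())   # struct<_1:bigint,_2:string>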
@@ -755,8 +791,24 @@ def _merge_type(a, b):
return a
+def _need_converter(dataType):
+ if isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+ elif isinstance(dataType, NullType):
+ return True
+ else:
+ return False
+
+
def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
+ if not _need_converter(dataType):
+ return lambda x: x
+
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
@@ -775,29 +827,29 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
+ convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
def convert_struct(obj):
if obj is None:
return
- if isinstance(obj, tuple):
- if hasattr(obj, "_fields"):
- d = dict(zip(obj._fields, obj))
- elif hasattr(obj, "__FIELDS__"):
- d = dict(zip(obj.__FIELDS__, obj))
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- d = dict(obj)
+ if isinstance(obj, (tuple, list)):
+ if convert_fields:
+ return tuple(conv(v) for v, conv in zip(obj, converters))
else:
- raise ValueError("unexpected tuple: %s" % str(obj))
+ return tuple(obj)
- elif isinstance(obj, dict):
+ if isinstance(obj, dict):
d = obj
elif hasattr(obj, "__dict__"): # object
d = obj.__dict__
else:
raise ValueError("Unexpected obj: %s" % obj)
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ else:
+ return tuple([d.get(name) for name in names])
return convert_struct
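# Sketch of the _need_converter() fast path above: primitive and flat types need
# no per-row conversion, so _create_converter falls back to a plain pass-through,
# while nested structs still get real converters. Private helpers are used here
# for demonstration only.
from pyspark.sql.types import (_need_converter, _create_converter,
                               StructType, StructField, LongType, StringType,
                               ArrayType)

flat = StructType([StructField("a", LongType(), True),
                   StructField("b", StringType(), True)])
print(_need_converter(LongType()))               # False
print(_need_converter(ArrayType(LongType())))    # False
print(_need_converter(ArrayType(flat)))          # True: struct elements need handling
print(_create_converter(flat)((1, "x")))         # (1, 'x') is passed through unchanged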
@@ -847,20 +899,20 @@ def _parse_field_abstract(s):
Parse a field in schema abstract
>>> _parse_field_abstract("a")
- StructField(a,None,true)
+ StructField(a,NullType,true)
>>> _parse_field_abstract("b(c d)")
- StructField(b,StructType(...c,None,true),StructField(d...
+ StructField(b,StructType(...c,NullType,true),StructField(d...
>>> _parse_field_abstract("a[]")
- StructField(a,ArrayType(None,true),true)
+ StructField(a,ArrayType(NullType,true),true)
>>> _parse_field_abstract("a{[]}")
- StructField(a,MapType(None,ArrayType(None,true),true),true)
+ StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
"""
if set(_BRACKETS.keys()) & set(s):
idx = min((s.index(c) for c in _BRACKETS if c in s))
name = s[:idx]
return StructField(name, _parse_schema_abstract(s[idx:]), True)
else:
- return StructField(s, None, True)
+ return StructField(s, NullType(), True)
def _parse_schema_abstract(s):
@@ -874,11 +926,11 @@ def _parse_schema_abstract(s):
>>> _parse_schema_abstract("c{} d{a b}")
StructType...c,MapType...d,MapType...a...b...
>>> _parse_schema_abstract("a b(t)").fields[1]
- StructField(b,StructType(List(StructField(t,None,true))),true)
+ StructField(b,StructType(List(StructField(t,NullType,true))),true)
"""
s = s.strip()
if not s:
- return
+ return NullType()
elif s.startswith('('):
return _parse_schema_abstract(s[1:-1])
@@ -887,7 +939,7 @@ def _parse_schema_abstract(s):
return ArrayType(_parse_schema_abstract(s[1:-1]), True)
elif s.startswith('{'):
- return MapType(None, _parse_schema_abstract(s[1:-1]))
+ return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
parts = _split_schema_abstract(s)
fields = [_parse_field_abstract(p) for p in parts]
@@ -901,13 +953,13 @@ def _infer_schema_type(obj, dataType):
>>> schema = _parse_schema_abstract("a b c d")
>>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...DateType...
+ StructType...LongType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
- StructType...a,ArrayType...b,MapType(StringType,...c,IntegerType...
+ StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
- if dataType is None:
+ if dataType is NullType():
return _infer_type(obj)
if not obj:
@@ -960,7 +1012,7 @@ def _verify_type(obj, dataType):
>>> _verify_type(None, StructType([]))
>>> _verify_type("", StringType())
- >>> _verify_type(0, IntegerType())
+ >>> _verify_type(0, LongType())
>>> _verify_type(range(3), ArrayType(ShortType()))
>>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
@@ -990,7 +1042,7 @@ def _verify_type(obj, dataType):
return
_type = type(dataType)
- assert _type in _acceptable_types, "unkown datatype: %s" % dataType
+ assert _type in _acceptable_types, "unknown datatype: %s" % dataType
# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
@@ -1008,13 +1060,12 @@ def _verify_type(obj, dataType):
elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
- raise ValueError("Length of object (%d) does not match with"
+ raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
-
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
def _restore_object(dataType, obj):
@@ -1209,8 +1260,7 @@ def __new__(self, *args, **kwargs):
elif kwargs:
# create row objects
names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
+ row = tuple.__new__(self, [kwargs[n] for n in names])
row.__FIELDS__ = names
return row
diff --git a/python/pyspark/status.py b/python/pyspark/status.py
new file mode 100644
index 0000000000000..a6fa7dd3144d4
--- /dev/null
+++ b/python/pyspark/status.py
@@ -0,0 +1,96 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from collections import namedtuple
+
+__all__ = ["SparkJobInfo", "SparkStageInfo", "StatusTracker"]
+
+
+class SparkJobInfo(namedtuple("SparkJobInfo", "jobId stageIds status")):
+ """
+ Exposes information about Spark Jobs.
+ """
+
+
+class SparkStageInfo(namedtuple("SparkStageInfo",
+ "stageId currentAttemptId name numTasks numActiveTasks "
+ "numCompletedTasks numFailedTasks")):
+ """
+ Exposes information about Spark Stages.
+ """
+
+
+class StatusTracker(object):
+ """
+ Low-level status reporting APIs for monitoring job and stage progress.
+
+ These APIs intentionally provide very weak consistency semantics;
+ consumers of these APIs should be prepared to handle empty / missing
+ information. For example, a job's stage ids may be known but the status
+ API may not have any information about the details of those stages, so
+ `getStageInfo` could potentially return `None` for a valid stage id.
+
+ To limit memory usage, these APIs only provide information on recent
+ jobs / stages. These APIs will provide information for the last
+ `spark.ui.retainedStages` stages and `spark.ui.retainedJobs` jobs.
+ """
+ def __init__(self, jtracker):
+ self._jtracker = jtracker
+
+ def getJobIdsForGroup(self, jobGroup=None):
+ """
+ Return a list of all known jobs in a particular job group. If
+ `jobGroup` is None, then returns all known jobs that are not
+ associated with a job group.
+
+ The returned list may contain running, failed, and completed jobs,
+ and may vary across invocations of this method. This method does
+ not guarantee the order of the elements in its result.
+ """
+ return list(self._jtracker.getJobIdsForGroup(jobGroup))
+
+ def getActiveStageIds(self):
+ """
+ Returns an array containing the ids of all active stages.
+ """
+ return sorted(list(self._jtracker.getActiveStageIds()))
+
+ def getActiveJobsIds(self):
+ """
+ Returns an array containing the ids of all active jobs.
+ """
+ return sorted((list(self._jtracker.getActiveJobIds())))
+
+ def getJobInfo(self, jobId):
+ """
+ Returns a :class:`SparkJobInfo` object, or None if the job info
+ could not be found or was garbage collected.
+ """
+ job = self._jtracker.getJobInfo(jobId)
+ if job is not None:
+ return SparkJobInfo(jobId, job.stageIds(), str(job.status()))
+
+ def getStageInfo(self, stageId):
+ """
+ Returns a :class:`SparkStageInfo` object, or None if the stage
+ info could not be found or was garbage collected.
+ """
+ stage = self._jtracker.getStageInfo(stageId)
+ if stage is not None:
+ # TODO: fetch them in batch for better performance
+ attrs = [getattr(stage, f)() for f in SparkStageInfo._fields[1:]]
+ return SparkStageInfo(stageId, *attrs)
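(A usage sketch for the API above, mirroring what test_progress_api in pyspark/tests.py exercises later in this diff; assumes a pyspark build with this patch, a local master, and an arbitrary job-group name.)

    from pyspark import SparkContext

    sc = SparkContext("local[2]", "status-demo")      # illustrative app name
    sc.setJobGroup("demo-group", "status tracker demo")
    sc.parallelize(range(1000)).count()               # run something to track

    tracker = sc.statusTracker()
    for job_id in tracker.getJobIdsForGroup("demo-group"):
        job = tracker.getJobInfo(job_id)
        if job is None:                               # info may already be gone
            continue
        for stage_id in job.stageIds:
            stage = tracker.getStageInfo(stage_id)
            if stage is not None:
                print stage.name, stage.numCompletedTasks, stage.numTasks
    sc.stop()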
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index b06ab650370bd..2c73083c9f9a8 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -189,7 +189,7 @@ def awaitTermination(self, timeout=None):
if timeout is None:
self._jssc.awaitTermination()
else:
- self._jssc.awaitTermination(int(timeout * 1000))
+ self._jssc.awaitTerminationOrTimeout(int(timeout * 1000))
def awaitTerminationOrTimeout(self, timeout):
"""
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 2fe39392ff081..3fa42444239f7 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -578,7 +578,7 @@ def reduceFunc(t, a, b):
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
- g = a.cogroup(b, numPartitions)
+ g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py
index 19ad71f99d4d5..0002dc10e8a17 100644
--- a/python/pyspark/streaming/kafka.py
+++ b/python/pyspark/streaming/kafka.py
@@ -16,7 +16,7 @@
#
from py4j.java_collections import MapConverter
-from py4j.java_gateway import java_import, Py4JError
+from py4j.java_gateway import java_import, Py4JError, Py4JJavaError
from pyspark.storagelevel import StorageLevel
from pyspark.serializers import PairDeserializer, NoOpSerializer
@@ -50,8 +50,6 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
:param valueDecoder: A function used to decode value (default is utf8_decoder)
:return: A DStream object
"""
- java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils")
-
kafkaParams.update({
"zookeeper.connect": zkQuorum,
"group.id": groupId,
@@ -63,20 +61,34 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
jlevel = ssc._sc._getJavaStorageLevel(storageLevel)
- def getClassByName(name):
- return ssc._jvm.org.apache.spark.util.Utils.classForName(name)
-
try:
- array = getClassByName("[B")
- decoder = getClassByName("kafka.serializer.DefaultDecoder")
- jstream = ssc._jvm.KafkaUtils.createStream(ssc._jssc, array, array, decoder, decoder,
- jparam, jtopics, jlevel)
- except Py4JError, e:
+ # Use KafkaUtilsPythonHelper to access Scala's KafkaUtils (see SPARK-6027)
+ helperClass = ssc._jvm.java.lang.Thread.currentThread().getContextClassLoader()\
+ .loadClass("org.apache.spark.streaming.kafka.KafkaUtilsPythonHelper")
+ helper = helperClass.newInstance()
+ jstream = helper.createStream(ssc._jssc, jparam, jtopics, jlevel)
+ except Py4JJavaError, e:
# TODO: use --jar once it also work on driver
- if not e.message or 'call a package' in e.message:
- print "No kafka package, please put the assembly jar into classpath:"
- print " $ bin/spark-submit --driver-class-path external/kafka-assembly/target/" + \
- "scala-*/spark-streaming-kafka-assembly-*.jar"
+ if 'ClassNotFoundException' in str(e.java_exception):
+ print """
+________________________________________________________________________________________________
+
+ Spark Streaming's Kafka libraries not found in class path. Try one of the following.
+
+  1. Include the Kafka library and its dependencies in the
+ spark-submit command as
+
+ $ bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:%s ...
+
+ 2. Download the JAR of the artifact from Maven Central http://search.maven.org/,
+ Group Id = org.apache.spark, Artifact Id = spark-streaming-kafka-assembly, Version = %s.
+     Then, include the jar in the spark-submit command as
+
+ $ bin/spark-submit --jars ...
+
+________________________________________________________________________________________________
+
+""" % (ssc.sparkContext.version, ssc.sparkContext.version)
raise e
ser = PairDeserializer(NoOpSerializer(), NoOpSerializer())
stream = DStream(jstream, ssc, ser)
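(For context on what the friendlier error message above is guarding, a minimal receiver sketch; the ZooKeeper address, consumer group, and topic name are placeholders, and the --packages coordinate simply mirrors the message text.)

    # demo_kafka.py -- submit with:
    #   bin/spark-submit --packages org.apache.spark:spark-streaming-kafka:<spark version> demo_kafka.py
    from pyspark import SparkContext
    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils

    sc = SparkContext("local[2]", "kafka-demo")
    ssc = StreamingContext(sc, 2)   # 2 second batches

    # zkQuorum, groupId, and the topic->partition map are illustrative values.
    lines = KafkaUtils.createStream(ssc, "localhost:2181", "demo-consumer", {"events": 1})
    lines.map(lambda (k, v): v).pprint()

    ssc.start()
    ssc.awaitTermination()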
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index b5e28c498040b..06ba2b461d53e 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -543,6 +543,12 @@ def test_zip_with_different_serializers(self):
# regression test for bug in _reserializer()
self.assertEqual(cnt, t.zip(rdd).count())
+ def test_zip_with_different_object_sizes(self):
+        # regression test for SPARK-5973
+ a = self.sc.parallelize(range(10000)).map(lambda i: '*' * i)
+ b = self.sc.parallelize(range(10000, 20000)).map(lambda i: '*' * i)
+ self.assertEqual(10000, a.zip(b).count())
+
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)
# different number of partitions
@@ -727,7 +733,6 @@ def test_multiple_python_java_RDD_conversions(self):
(u'1', {u'director': u'David Lean'}),
(u'2', {u'director': u'Andrew Dominik'})
]
- from pyspark.rdd import RDD
data_rdd = self.sc.parallelize(data)
data_java_rdd = data_rdd._to_java_object_rdd()
data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
@@ -740,6 +745,43 @@ def test_multiple_python_java_RDD_conversions(self):
converted_rdd = RDD(data_python_rdd, self.sc)
self.assertEqual(2, converted_rdd.count())
+ def test_narrow_dependency_in_join(self):
+ rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
+ parted = rdd.partitionBy(2)
+ self.assertEqual(2, parted.union(parted).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
+
+ self.sc.setJobGroup("test1", "test", True)
+ tracker = self.sc.statusTracker()
+
+ d = sorted(parted.join(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test1")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test2", "test", True)
+ d = sorted(parted.join(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test2")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test3", "test", True)
+ d = sorted(parted.cogroup(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test3")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test4", "test", True)
+ d = sorted(parted.cogroup(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test4")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
class ProfilerTests(PySparkTestCase):
@@ -1404,31 +1446,59 @@ def setUp(self):
def tearDown(self):
shutil.rmtree(self.programDir)
- def createTempFile(self, name, content):
+ def createTempFile(self, name, content, dir=None):
"""
Create a temp file with the given name and content and return its path.
Strips leading spaces from content up to the first '|' in each line.
"""
pattern = re.compile(r'^ *\|', re.MULTILINE)
content = re.sub(pattern, '', content.strip())
- path = os.path.join(self.programDir, name)
+ if dir is None:
+ path = os.path.join(self.programDir, name)
+ else:
+ os.makedirs(os.path.join(self.programDir, dir))
+ path = os.path.join(self.programDir, dir, name)
with open(path, "w") as f:
f.write(content)
return path
- def createFileInZip(self, name, content):
+ def createFileInZip(self, name, content, ext=".zip", dir=None, zip_name=None):
"""
Create a zip archive containing a file with the given content and return its path.
Strips leading spaces from content up to the first '|' in each line.
"""
pattern = re.compile(r'^ *\|', re.MULTILINE)
content = re.sub(pattern, '', content.strip())
- path = os.path.join(self.programDir, name + ".zip")
+ if dir is None:
+ path = os.path.join(self.programDir, name + ext)
+ else:
+ path = os.path.join(self.programDir, dir, zip_name + ext)
zip = zipfile.ZipFile(path, 'w')
zip.writestr(name, content)
zip.close()
return path
+ def create_spark_package(self, artifact_name):
+ group_id, artifact_id, version = artifact_name.split(":")
+ self.createTempFile("%s-%s.pom" % (artifact_id, version), ("""
+            |<?xml version="1.0" encoding="UTF-8"?>
+            |<project xmlns="http://maven.apache.org/POM/4.0.0">
+            |    <modelVersion>4.0.0</modelVersion>
+            |    <groupId>%s</groupId>
+            |    <artifactId>%s</artifactId>
+            |    <version>%s</version>
+            |</project>
+ """ % (group_id, artifact_id, version)).lstrip(),
+ os.path.join(group_id, artifact_id, version))
+ self.createFileInZip("%s.py" % artifact_id, """
+ |def myfunc(x):
+ | return x + 1
+ """, ".jar", os.path.join(group_id, artifact_id, version),
+ "%s-%s" % (artifact_id, version))
+
def test_single_script(self):
"""Submit and test a single script file"""
script = self.createTempFile("test.py", """
@@ -1497,6 +1567,39 @@ def test_module_dependency_on_cluster(self):
self.assertEqual(0, proc.returncode)
self.assertIn("[2, 3, 4]", out)
+ def test_package_dependency(self):
+ """Submit and test a script with a dependency on a Spark Package"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print sc.parallelize([1, 2, 3]).map(myfunc).collect()
+ """)
+ self.create_spark_package("a:mylib:0.1")
+ proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, script], stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out)
+
+ def test_package_dependency_on_cluster(self):
+ """Submit and test a script with a dependency on a Spark Package on a cluster"""
+ script = self.createTempFile("test.py", """
+ |from pyspark import SparkContext
+ |from mylib import myfunc
+ |
+ |sc = SparkContext()
+ |print sc.parallelize([1, 2, 3]).map(myfunc).collect()
+ """)
+ self.create_spark_package("a:mylib:0.1")
+ proc = subprocess.Popen([self.sparkSubmit, "--packages", "a:mylib:0.1", "--repositories",
+ "file:" + self.programDir, "--master",
+ "local-cluster[1,1,512]", script], stdout=subprocess.PIPE)
+ out, err = proc.communicate()
+ self.assertEqual(0, proc.returncode)
+ self.assertIn("[2, 3, 4]", out)
+
def test_single_script_on_cluster(self):
"""Submit and test a single script on a cluster"""
script = self.createTempFile("test.py", """
@@ -1550,6 +1653,37 @@ def test_with_stop(self):
sc.stop()
self.assertEqual(SparkContext._active_spark_context, None)
+ def test_progress_api(self):
+ with SparkContext() as sc:
+ sc.setJobGroup('test_progress_api', '', True)
+
+ rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
+ t = threading.Thread(target=rdd.collect)
+ t.daemon = True
+ t.start()
+ # wait for scheduler to start
+ time.sleep(1)
+
+ tracker = sc.statusTracker()
+ jobIds = tracker.getJobIdsForGroup('test_progress_api')
+ self.assertEqual(1, len(jobIds))
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual(1, len(job.stageIds))
+ stage = tracker.getStageInfo(job.stageIds[0])
+ self.assertEqual(rdd.getNumPartitions(), stage.numTasks)
+
+ sc.cancelAllJobs()
+ t.join()
+ # wait for event listener to update the status
+ time.sleep(1)
+
+ job = tracker.getJobInfo(jobIds[0])
+ self.assertEqual('FAILED', job.status)
+ self.assertEqual([], tracker.getActiveJobsIds())
+ self.assertEqual([], tracker.getActiveStageIds())
+
+ sc.stop()
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/run-tests b/python/run-tests
index 58a26dd8ff088..a2c2f37a54eda 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -35,7 +35,7 @@ rm -rf metastore warehouse
function run_test() {
echo "Running test: $1" | tee -a $LOG_FILE
- SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 2>&1 | tee -a $LOG_FILE
+ SPARK_TESTING=1 time "$FWDIR"/bin/pyspark $1 > $LOG_FILE 2>&1
FAILED=$((PIPESTATUS[0]||$FAILED))
@@ -67,6 +67,7 @@ function run_sql_tests() {
run_test "pyspark/sql/types.py"
run_test "pyspark/sql/context.py"
run_test "pyspark/sql/dataframe.py"
+ run_test "pyspark/sql/functions.py"
run_test "pyspark/sql/tests.py"
}
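(run-tests now also doctests pyspark/sql/functions.py, the Python counterpart of the org.apache.spark.sql.functions rename applied to the REPL imports below; a small usage sketch, assuming a SQLContext named sqlContext is already available.)

    from pyspark.sql import functions as F

    df = sqlContext.createDataFrame([(1, "a"), (2, "b"), (2, "c")], ["key", "value"])
    df.groupBy("key").agg(F.count("value")).show()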
diff --git a/repl/pom.xml b/repl/pom.xml
index 3d4adf8fd5b03..b883344bf0ceb 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -33,8 +33,6 @@
repl
- /usr/share/spark
- root
scala-2.10/src/main/scala
scala-2.10/src/test/scala
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index b4db3df795177..8dc0e0c965923 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -1064,15 +1064,16 @@ class SparkILoop(
private def main(settings: Settings): Unit = process(settings)
}
-object SparkILoop {
+object SparkILoop extends Logging {
implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
private def echo(msg: String) = Console println msg
def getAddedJars: Array[String] = {
val envJars = sys.env.get("ADD_JARS")
- val propJars = sys.props.get("spark.jars").flatMap { p =>
- if (p == "") None else Some(p)
+ if (envJars.isDefined) {
+      logWarning("ADD_JARS environment variable is deprecated, use the --jars spark-submit argument instead")
}
+ val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) }
val jars = propJars.orElse(envJars).getOrElse("")
Utils.resolveURIs(jars).split(",").filter(_.nonEmpty)
}
diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 0cf2de6d399b0..05faef8786d2c 100644
--- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -137,7 +137,7 @@ private[repl] trait SparkILoopInit {
command("import org.apache.spark.SparkContext._")
command("import sqlContext.implicits._")
command("import sqlContext.sql")
- command("import org.apache.spark.sql.Dsl._")
+ command("import org.apache.spark.sql.functions._")
}
}
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 201f2672d5474..529914a2b6141 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -262,7 +262,7 @@ class ReplSuite extends FunSuite {
|val sqlContext = new org.apache.spark.sql.SQLContext(sc)
|import sqlContext.implicits._
|case class TestCaseClass(value: Int)
- |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect()
+ |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF.collect()
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
index dc25692749aad..2210fbaafeadb 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -51,6 +51,9 @@ object Main extends Logging {
def getAddedJars: Array[String] = {
val envJars = sys.env.get("ADD_JARS")
+ if (envJars.isDefined) {
+      logWarning("ADD_JARS environment variable is deprecated, use the --jars spark-submit argument instead")
+ }
val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) }
val jars = propJars.orElse(envJars).getOrElse("")
Utils.resolveURIs(jars).split(",").filter(_.nonEmpty)
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
index 1bd2a6991404b..7a5e94da5cbf3 100644
--- a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -77,7 +77,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
command("import org.apache.spark.SparkContext._")
command("import sqlContext.implicits._")
command("import sqlContext.sql")
- command("import org.apache.spark.sql.Dsl._")
+ command("import org.apache.spark.sql.functions._")
}
}
diff --git a/sbin/spark-daemon.sh b/sbin/spark-daemon.sh
index 89608bc41b71d..5e812a1d91c6b 100755
--- a/sbin/spark-daemon.sh
+++ b/sbin/spark-daemon.sh
@@ -129,8 +129,9 @@ case $option in
mkdir -p "$SPARK_PID_DIR"
if [ -f $pid ]; then
- if kill -0 `cat $pid` > /dev/null 2>&1; then
- echo $command running as process `cat $pid`. Stop it first.
+ TARGET_ID="$(cat "$pid")"
+ if [[ $(ps -p "$TARGET_ID" -o args=) =~ $command ]]; then
+ echo "$command running as process $TARGET_ID. Stop it first."
exit 1
fi
fi
@@ -141,7 +142,7 @@ case $option in
fi
spark_rotate_log "$log"
- echo starting $command, logging to $log
+ echo "starting $command, logging to $log"
if [ $option == spark-submit ]; then
source "$SPARK_HOME"/bin/utils.sh
gatherSparkSubmitOpts "$@"
@@ -154,7 +155,7 @@ case $option in
echo $newpid > $pid
sleep 2
# Check if the process has died; in that case we'll tail the log so the user can see
- if ! kill -0 $newpid >/dev/null 2>&1; then
+ if [[ ! $(ps -p "$newpid" -o args=) =~ $command ]]; then
echo "failed to launch $command:"
tail -2 "$log" | sed 's/^/ /'
echo "full log in $log"
@@ -164,14 +165,15 @@ case $option in
(stop)
if [ -f $pid ]; then
- if kill -0 `cat $pid` > /dev/null 2>&1; then
- echo stopping $command
- kill `cat $pid`
+ TARGET_ID="$(cat "$pid")"
+ if [[ $(ps -p "$TARGET_ID" -o comm=) =~ "java" ]]; then
+ echo "stopping $command"
+ kill "$TARGET_ID" && rm -f "$pid"
else
- echo no $command to stop
+ echo "no $command to stop"
fi
else
- echo no $command to stop
+ echo "no $command to stop"
fi
;;
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
new file mode 100644
index 0000000000000..15add84878ecf
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/AnalysisException.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Thrown when a query fails to analyze, usually because the query itself is invalid.
+ */
+@DeveloperApi
+class AnalysisException protected[sql] (
+ val message: String,
+ val line: Option[Int] = None,
+ val startPosition: Option[Int] = None)
+ extends Exception with Serializable {
+
+ override def getMessage: String = {
+ val lineAnnotation = line.map(l => s" line $l").getOrElse("")
+ val positionAnnotation = startPosition.map(p => s" pos $p").getOrElse("")
+ s"$message;$lineAnnotation$positionAnnotation"
+ }
+}
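(getMessage above renders as "<message>;<line annotation><position annotation>". From the Python side touched elsewhere in this diff, such failures currently surface wrapped in a Py4J error once an action runs; a rough sketch, assuming a SQLContext named sqlContext with some table t registered.)

    from py4j.protocol import Py4JJavaError

    try:
        sqlContext.sql("SELECT no_such_column FROM t").collect()
    except Py4JJavaError, e:
        # The Java-side message follows AnalysisException.getMessage, e.g.
        # "cannot resolve 'no_such_column' given input columns ...; line 1 pos 7"
        print e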
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 3a70d25534968..d794f034f5578 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import scala.util.hashing.MurmurHash3
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.DateUtils
+import org.apache.spark.sql.types.{StructType, DateUtils}
object Row {
/**
@@ -122,6 +122,11 @@ trait Row extends Serializable {
/** Number of elements in the Row. */
def length: Int
+ /**
+ * Schema for the row.
+ */
+ def schema: StructType = null
+
/**
* Returns the value at position i. If the value is null, null is returned. The following
* is a mapping between Spark SQL types and return types:
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 5d9c331ca5178..d6126c24fc50d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst
import java.sql.Timestamp
import org.apache.spark.util.Utils
-import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.types._
@@ -91,9 +91,9 @@ trait ScalaReflection {
def convertRowToScala(r: Row, schema: StructType): Row = {
// TODO: This is very slow!!!
- new GenericRow(
+ new GenericRowWithSchema(
r.toSeq.zip(schema.fields.map(_.dataType))
- .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray)
+ .map(r_dt => convertToScala(r_dt._1, r_dt._2)).toArray, schema)
}
/** Returns a Sequence of attributes for the given case class type. */
@@ -122,6 +122,21 @@ trait ScalaReflection {
case t if t <:< typeOf[Option[_]] =>
val TypeRef(_, _, Seq(optType)) = t
Schema(schemaFor(optType).dataType, nullable = true)
+ // Need to decide if we actually need a special type here.
+ case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
+ case t if t <:< typeOf[Array[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ case t if t <:< typeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val Schema(dataType, nullable) = schemaFor(elementType)
+ Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
+ case t if t <:< typeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+ Schema(MapType(schemaFor(keyType).dataType,
+ valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[Product] =>
val formalTypeArgs = t.typeSymbol.asClass.typeParams
val TypeRef(_, _, actualTypeArgs) = t
@@ -144,21 +159,6 @@ trait ScalaReflection {
schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs))
StructField(p.name.toString, dataType, nullable)
}), nullable = true)
- // Need to decide if we actually need a special type here.
- case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true)
- case t if t <:< typeOf[Array[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< typeOf[Seq[_]] =>
- val TypeRef(_, _, Seq(elementType)) = t
- val Schema(dataType, nullable) = schemaFor(elementType)
- Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
- case t if t <:< typeOf[Map[_, _]] =>
- val TypeRef(_, _, Seq(keyType, valueType)) = t
- val Schema(valueDataType, valueNullable) = schemaFor(valueType)
- Schema(MapType(schemaFor(keyType).dataType,
- valueDataType, valueContainsNull = valueNullable), nullable = true)
case t if t <:< typeOf[String] => Schema(StringType, nullable = true)
case t if t <:< typeOf[Timestamp] => Schema(TimestampType, nullable = true)
case t if t <:< typeOf[java.sql.Date] => Schema(DateType, nullable = true)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
old mode 100755
new mode 100644
index 124f083669358..c363a5efacde8
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -40,7 +40,7 @@ class SqlParser extends AbstractSparkSQLParser {
def parseExpression(input: String): Expression = {
// Initialize the Keywords.
lexical.initialize(reservedWords)
- phrase(expression)(new lexical.Scanner(input)) match {
+ phrase(projection)(new lexical.Scanner(input)) match {
case Success(plan, _) => plan
case failureOrError => sys.error(failureOrError.toString)
}
@@ -78,6 +78,7 @@ class SqlParser extends AbstractSparkSQLParser {
protected val IF = Keyword("IF")
protected val IN = Keyword("IN")
protected val INNER = Keyword("INNER")
+ protected val INT = Keyword("INT")
protected val INSERT = Keyword("INSERT")
protected val INTERSECT = Keyword("INTERSECT")
protected val INTO = Keyword("INTO")
@@ -394,6 +395,7 @@ class SqlParser extends AbstractSparkSQLParser {
| fixedDecimalType
| DECIMAL ^^^ DecimalType.Unlimited
| DATE ^^^ DateType
+ | INT ^^^ IntegerType
)
protected lazy val fixedDecimalType: Parser[DataType] =
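(With the INT keyword registered above, the primitive-type production now accepts it anywhere the other type keywords are allowed; a quick check from PySpark, assuming a registered table t with a string column s.)

    sqlContext.sql("SELECT CAST(s AS INT) AS s_int FROM t").printSchema()
    # root
    #  |-- s_int: integer (nullable = true)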
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fb2ff014cef07..e4e542562f22d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.util.collection.OpenHashSet
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.types.{ArrayType, StructField, StructType, IntegerType}
+import org.apache.spark.sql.types._
/**
* A trivial [[Analyzer]] with an [[EmptyCatalog]] and [[EmptyFunctionRegistry]]. Used for testing
@@ -49,52 +50,25 @@ class Analyzer(catalog: Catalog,
/**
* Override to provide additional rules for the "Resolution" batch.
*/
- val extendedRules: Seq[Rule[LogicalPlan]] = Nil
+ val extendedResolutionRules: Seq[Rule[LogicalPlan]] = Nil
lazy val batches: Seq[Batch] = Seq(
- Batch("MultiInstanceRelations", Once,
- NewRelationInstances),
Batch("Resolution", fixedPoint,
- ResolveReferences ::
ResolveRelations ::
+ ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
- NewRelationInstances ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
UnresolvedHavingClauseAttributes ::
TrimGroupingAliases ::
typeCoercionRules ++
- extendedRules : _*),
- Batch("Check Analysis", Once,
- CheckResolution ::
- CheckAggregation ::
- Nil: _*),
- Batch("AnalysisOperators", fixedPoint,
- EliminateAnalysisOperators)
+ extendedResolutionRules : _*),
+ Batch("Remove SubQueries", fixedPoint,
+ EliminateSubQueries)
)
- /**
- * Makes sure all attributes and logical plans have been resolved.
- */
- object CheckResolution extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transform {
- case p if p.expressions.exists(!_.resolved) =>
- throw new TreeNodeException(p,
- s"Unresolved attributes: ${p.expressions.filterNot(_.resolved).mkString(",")}")
- case p if !p.resolved && p.childrenResolved =>
- throw new TreeNodeException(p, "Unresolved plan found")
- } match {
- // As a backstop, use the root node to check that the entire plan tree is resolved.
- case p if !p.resolved =>
- throw new TreeNodeException(p, "Unresolved plan in tree")
- case p => p
- }
- }
- }
-
/**
* Removes no-op Alias expressions from the plan.
*/
@@ -193,46 +167,24 @@ class Analyzer(catalog: Catalog,
}
/**
- * Checks for non-aggregated attributes with aggregation
+ * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
*/
- object CheckAggregation extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transform {
- case aggregatePlan @ Aggregate(groupingExprs, aggregateExprs, child) =>
- def isValidAggregateExpression(expr: Expression): Boolean = expr match {
- case _: AggregateExpression => true
- case e: Attribute => groupingExprs.contains(e)
- case e if groupingExprs.contains(e) => true
- case e if e.references.isEmpty => true
- case e => e.children.forall(isValidAggregateExpression)
- }
-
- aggregateExprs.find { e =>
- !isValidAggregateExpression(e.transform {
- // Should trim aliases around `GetField`s. These aliases are introduced while
- // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
- // (Should we just turn `GetField` into a `NamedExpression`?)
- case Alias(g: GetField, _) => g
- })
- }.foreach { e =>
- throw new TreeNodeException(plan, s"Expression not in GROUP BY: $e")
- }
-
- aggregatePlan
+ object ResolveRelations extends Rule[LogicalPlan] {
+ def getTable(u: UnresolvedRelation) = {
+ try {
+ catalog.lookupRelation(u.tableIdentifier, u.alias)
+ } catch {
+ case _: NoSuchTableException =>
+ u.failAnalysis(s"no such table ${u.tableIdentifier}")
}
}
- }
- /**
- * Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
- */
- object ResolveRelations extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case i @ InsertIntoTable(UnresolvedRelation(tableIdentifier, alias), _, _, _) =>
+ case i @ InsertIntoTable(u: UnresolvedRelation, _, _, _) =>
i.copy(
- table = EliminateAnalysisOperators(catalog.lookupRelation(tableIdentifier, alias)))
- case UnresolvedRelation(tableIdentifier, alias) =>
- catalog.lookupRelation(tableIdentifier, alias)
+ table = EliminateSubQueries(getTable(u)))
+ case u: UnresolvedRelation =>
+ getTable(u)
}
}
@@ -282,6 +234,27 @@ class Analyzer(catalog: Catalog,
}
)
+ // Special handling for cases when self-join introduce duplicate expression ids.
+ case j @ Join(left, right, _, _) if left.outputSet.intersect(right.outputSet).nonEmpty =>
+ val conflictingAttributes = left.outputSet.intersect(right.outputSet)
+
+ val (oldRelation, newRelation, attributeRewrites) = right.collect {
+ case oldVersion: MultiInstanceRelation
+ if oldVersion.outputSet.intersect(conflictingAttributes).nonEmpty =>
+ val newVersion = oldVersion.newInstance()
+ val newAttributes = AttributeMap(oldVersion.output.zip(newVersion.output))
+ (oldVersion, newVersion, newAttributes)
+ }.head // Only handle first case found, others will be fixed on the next pass.
+
+ val newRight = right transformUp {
+ case r if r == oldRelation => newRelation
+ case other => other transformExpressions {
+ case a: Attribute => attributeRewrites.get(a).getOrElse(a)
+ }
+ }
+
+ j.copy(right = newRight)
+
case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString}")
q transformExpressionsUp {
@@ -314,10 +287,11 @@ class Analyzer(catalog: Catalog,
val checkField = (f: StructField) => resolver(f.name, fieldName)
val ordinal = fields.indexWhere(checkField)
if (ordinal == -1) {
- sys.error(
+ throw new AnalysisException(
s"No such struct field $fieldName in ${fields.map(_.name).mkString(", ")}")
} else if (fields.indexWhere(checkField, ordinal + 1) != -1) {
- sys.error(s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
+ throw new AnalysisException(
+ s"Ambiguous reference to fields ${fields.filter(checkField).mkString(", ")}")
} else {
ordinal
}
@@ -329,7 +303,8 @@ class Analyzer(catalog: Catalog,
case ArrayType(StructType(fields), containsNull) =>
val ordinal = findField(fields)
ArrayGetField(expr, fields(ordinal), ordinal, containsNull)
- case otherType => sys.error(s"GetField is not valid on fields of type $otherType")
+ case otherType =>
+ throw new AnalysisException(s"GetField is not valid on fields of type $otherType")
}
}
}
@@ -454,7 +429,7 @@ class Analyzer(catalog: Catalog,
* only required to provide scoping information for attributes and can be removed once analysis is
* complete.
*/
-object EliminateAnalysisOperators extends Rule[LogicalPlan] {
+object EliminateSubQueries extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Subquery(_, child) => child
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
index df8d03b86c533..9e6e2912e0622 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala
@@ -21,6 +21,12 @@ import scala.collection.mutable
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery}
+/**
+ * Thrown by a catalog when a table cannot be found. The analyzer will rethrow the exception
+ * as an AnalysisException with the correct position information.
+ */
+class NoSuchTableException extends Exception
+
/**
* An interface for looking up relations by name. Used by an [[Analyzer]].
*/
@@ -34,6 +40,14 @@ trait Catalog {
tableIdentifier: Seq[String],
alias: Option[String] = None): LogicalPlan
+ /**
+ * Returns tuples of (tableName, isTemporary) for all tables in the given database.
+   * isTemporary is a Boolean value indicating whether a table is temporary or not.
+ */
+ def getTables(databaseName: Option[String]): Seq[(String, Boolean)]
+
+ def refreshTable(databaseName: String, tableName: String): Unit
+
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit
def unregisterTable(tableIdentifier: Seq[String]): Unit
@@ -101,6 +115,16 @@ class SimpleCatalog(val caseSensitive: Boolean) extends Catalog {
// properly qualified with this alias.
alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers)
}
+
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ tables.map {
+ case (name, _) => (name, true)
+ }.toSeq
+ }
+
+ override def refreshTable(databaseName: String, tableName: String): Unit = {
+ throw new UnsupportedOperationException
+ }
}
/**
@@ -137,6 +161,27 @@ trait OverrideCatalog extends Catalog {
withAlias.getOrElse(super.lookupRelation(tableIdentifier, alias))
}
+ abstract override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ val dbName = if (!caseSensitive) {
+ if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
+ } else {
+ databaseName
+ }
+
+ val temporaryTables = overrides.filter {
+ // If a temporary table does not have an associated database, we should return its name.
+ case ((None, _), _) => true
+ // If a temporary table does have an associated database, we should return it if the database
+ // matches the given database name.
+ case ((db: Some[String], _), _) if db == dbName => true
+ case _ => false
+ }.map {
+ case ((_, tableName), _) => (tableName, true)
+ }.toSeq
+
+ temporaryTables ++ super.getTables(databaseName)
+ }
+
override def registerTable(
tableIdentifier: Seq[String],
plan: LogicalPlan): Unit = {
@@ -172,6 +217,10 @@ object EmptyCatalog extends Catalog {
throw new UnsupportedOperationException
}
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = {
+ throw new UnsupportedOperationException
+ }
+
def registerTable(tableIdentifier: Seq[String], plan: LogicalPlan): Unit = {
throw new UnsupportedOperationException
}
@@ -181,4 +230,8 @@ object EmptyCatalog extends Catalog {
}
override def unregisterAllTables(): Unit = {}
+
+ override def refreshTable(databaseName: String, tableName: String): Unit = {
+ throw new UnsupportedOperationException
+ }
}
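(getTables backs the user-facing table listing; a sketch of the Python-side counterpart in this release line, assuming a SQLContext named sqlContext; the table name is arbitrary.)

    df = sqlContext.createDataFrame([(1, "a")], ["key", "value"])
    sqlContext.registerDataFrameAsTable(df, "people")

    print sqlContext.tableNames()   # ['people']
    sqlContext.tables().show()      # columns: tableName, isTemporary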
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
new file mode 100644
index 0000000000000..4e8fc892f3eea
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types._
+
+/**
+ * Throws user facing errors when passed invalid queries that fail to analyze.
+ */
+class CheckAnalysis {
+
+ /**
+ * Override to provide additional checks for correct analysis.
+ * These rules will be evaluated after our built-in check rules.
+ */
+ val extendedCheckRules: Seq[LogicalPlan => Unit] = Nil
+
+ def failAnalysis(msg: String) = {
+ throw new AnalysisException(msg)
+ }
+
+ def apply(plan: LogicalPlan): Unit = {
+ // We transform up and order the rules so as to catch the first possible failure instead
+ // of the result of cascading resolution failures.
+ plan.foreachUp {
+ case operator: LogicalPlan =>
+ operator transformExpressionsUp {
+ case a: Attribute if !a.resolved =>
+ val from = operator.inputSet.map(_.name).mkString(", ")
+ a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from")
+
+ case c: Cast if !c.resolved =>
+ failAnalysis(
+ s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}")
+
+ case b: BinaryExpression if !b.resolved =>
+ failAnalysis(
+ s"invalid expression ${b.prettyString} " +
+ s"between ${b.left.simpleString} and ${b.right.simpleString}")
+ }
+
+ operator match {
+ case f: Filter if f.condition.dataType != BooleanType =>
+ failAnalysis(
+ s"filter expression '${f.condition.prettyString}' " +
+ s"of type ${f.condition.dataType.simpleString} is not a boolean.")
+
+ case aggregatePlan@Aggregate(groupingExprs, aggregateExprs, child) =>
+ def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+ case _: AggregateExpression => // OK
+ case e: Attribute if !groupingExprs.contains(e) =>
+ failAnalysis(
+ s"expression '${e.prettyString}' is neither present in the group by, " +
+ s"nor is it an aggregate function. " +
+ "Add to group by or wrap in first() if you don't care which value you get.")
+ case e if groupingExprs.contains(e) => // OK
+ case e if e.references.isEmpty => // OK
+ case e => e.children.foreach(checkValidAggregateExpression)
+ }
+
+ val cleaned = aggregateExprs.map(_.transform {
+ // Should trim aliases around `GetField`s. These aliases are introduced while
+ // resolving struct field accesses, because `GetField` is not a `NamedExpression`.
+ // (Should we just turn `GetField` into a `NamedExpression`?)
+ case Alias(g, _) => g
+ })
+
+ cleaned.foreach(checkValidAggregateExpression)
+
+ case o if o.children.nonEmpty &&
+ !o.references.filter(_.name != "grouping__id").subsetOf(o.inputSet) =>
+ val missingAttributes = (o.references -- o.inputSet).map(_.prettyString).mkString(",")
+ val input = o.inputSet.map(_.prettyString).mkString(",")
+
+ failAnalysis(s"resolved attributes $missingAttributes missing from $input")
+
+ // Catch all
+ case o if !o.resolved =>
+ failAnalysis(
+ s"unresolved operator ${operator.simpleString}")
+
+ case _ => // Analysis successful!
+ }
+ }
+ extendedCheckRules.foreach(_(plan))
+ }
+}
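(Concretely, the aggregate rule above rejects queries of the first shape below and the error text points the user toward the second; the sketch assumes a registered table t(key, value).)

    # Rejected: 'value' is neither in the GROUP BY nor inside an aggregate.
    sqlContext.sql("SELECT key, value FROM t GROUP BY key")

    # Accepted: wrap the non-grouped column in an aggregate such as first().
    sqlContext.sql("SELECT key, first(value) FROM t GROUP BY key")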
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
index 4c5fb3f45bf49..894c3500cf533 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala
@@ -26,28 +26,9 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
* produced by distinct operators in a query tree as this breaks the guarantee that expression
* ids, which are used to differentiate attributes, are unique.
*
- * Before analysis, all operators that include this trait will be asked to produce a new version
+ * During analysis, operators that include this trait may be asked to produce a new version
* of itself with globally unique expression ids.
*/
trait MultiInstanceRelation {
def newInstance(): this.type
}
-
-/**
- * If any MultiInstanceRelation appears more than once in the query plan then the plan is updated so
- * that each instance has unique expression ids for the attributes produced.
- */
-object NewRelationInstances extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- val localRelations = plan collect { case l: MultiInstanceRelation => l}
- val multiAppearance = localRelations
- .groupBy(identity[MultiInstanceRelation])
- .filter { case (_, ls) => ls.size > 1 }
- .map(_._1)
- .toSet
-
- plan transform {
- case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance()
- }
- }
-}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
index 3f672a3e0fd91..e95f19e69ed43 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/package.scala
@@ -17,6 +17,9 @@
package org.apache.spark.sql.catalyst
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.trees.TreeNode
+
/**
* Provides a logical query plan [[Analyzer]] and supporting classes for performing analysis.
* Analysis consists of translating [[UnresolvedAttribute]]s and [[UnresolvedRelation]]s
@@ -25,11 +28,18 @@ package org.apache.spark.sql.catalyst
package object analysis {
/**
- * Responsible for resolving which identifiers refer to the same entity. For example, by using
- * case insensitive equality.
+ * Resolver should return true if the first string refers to the same entity as the second string.
+ * For example, by using case insensitive equality.
*/
type Resolver = (String, String) => Boolean
val caseInsensitiveResolution = (a: String, b: String) => a.equalsIgnoreCase(b)
val caseSensitiveResolution = (a: String, b: String) => a == b
+
+ implicit class AnalysisErrorAt(t: TreeNode[_]) {
+ /** Fails the analysis at the point where a specific tree node was parsed. */
+ def failAnalysis(msg: String) = {
+ throw new AnalysisException(msg, t.origin.line, t.origin.startPosition)
+ }
+ }
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index f959a50564011..a7cd4124e56f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -152,7 +152,7 @@ case class MultiAlias(child: Expression, names: Seq[String])
override lazy val resolved = false
- override def newInstance = this
+ override def newInstance() = this
override def withNullability(newNullability: Boolean) = this
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
old mode 100755
new mode 100644
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
index 171845ad14e3e..a9ba0be596349 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeSet.scala
@@ -20,7 +20,11 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.Star
protected class AttributeEquals(val a: Attribute) {
- override def hashCode() = a.exprId.hashCode()
+ override def hashCode() = a match {
+ case ar: AttributeReference => ar.exprId.hashCode()
+ case a => a.hashCode()
+ }
+
override def equals(other: Any) = (a, other.asInstanceOf[AttributeEquals].a) match {
case (a1: AttributeReference, a2: AttributeReference) => a1.exprId == a2.exprId
case (a1, a2) => a1 == a2
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index cf14992ef835c..6ad39b8372cfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.trees.TreeNode
@@ -67,203 +68,14 @@ abstract class Expression extends TreeNode[Expression] {
def childrenResolved = !children.exists(!_.resolved)
/**
- * A set of helper functions that return the correct descendant of `scala.math.Numeric[T]` type
- * and do any casting necessary of child evaluation.
+ * Returns a string representation of this expression that does not have developer centric
+ * debugging information like the expression id.
*/
- @inline
- def n1(e: Expression, i: Row, f: ((Numeric[Any], Any) => Any)): Any = {
- val evalE = e.eval(i)
- if (evalE == null) {
- null
- } else {
- e.dataType match {
- case n: NumericType =>
- val castedFunction = f.asInstanceOf[(Numeric[n.JvmType], n.JvmType) => n.JvmType]
- castedFunction(n.numeric, evalE.asInstanceOf[n.JvmType])
- case other => sys.error(s"Type $other does not support numeric operations")
- }
- }
- }
-
- /**
- * Evaluation helper function for 2 Numeric children expressions. Those expressions are supposed
- * to be in the same data type, and also the return type.
- * Either one of the expressions result is null, the evaluation result should be null.
- */
- @inline
- protected final def n2(
- i: Row,
- e1: Expression,
- e2: Expression,
- f: ((Numeric[Any], Any, Any) => Any)): Any = {
-
- if (e1.dataType != e2.dataType) {
- throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
- }
-
- val evalE1 = e1.eval(i)
- if(evalE1 == null) {
- null
- } else {
- val evalE2 = e2.eval(i)
- if (evalE2 == null) {
- null
- } else {
- e1.dataType match {
- case n: NumericType =>
- f.asInstanceOf[(Numeric[n.JvmType], n.JvmType, n.JvmType) => n.JvmType](
- n.numeric, evalE1.asInstanceOf[n.JvmType], evalE2.asInstanceOf[n.JvmType])
- case other => sys.error(s"Type $other does not support numeric operations")
- }
- }
- }
- }
-
- /**
- * Evaluation helper function for 2 Fractional children expressions. Those expressions are
- * supposed to be in the same data type, and also the return type.
- * Either one of the expressions result is null, the evaluation result should be null.
- */
- @inline
- protected final def f2(
- i: Row,
- e1: Expression,
- e2: Expression,
- f: ((Fractional[Any], Any, Any) => Any)): Any = {
- if (e1.dataType != e2.dataType) {
- throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
- }
-
- val evalE1 = e1.eval(i: Row)
- if(evalE1 == null) {
- null
- } else {
- val evalE2 = e2.eval(i: Row)
- if (evalE2 == null) {
- null
- } else {
- e1.dataType match {
- case ft: FractionalType =>
- f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType, ft.JvmType) => ft.JvmType](
- ft.fractional, evalE1.asInstanceOf[ft.JvmType], evalE2.asInstanceOf[ft.JvmType])
- case other => sys.error(s"Type $other does not support fractional operations")
- }
- }
- }
- }
-
- /**
- * Evaluation helper function for 1 Fractional children expression.
- * if the expression result is null, the evaluation result should be null.
- */
- @inline
- protected final def f1(i: Row, e1: Expression, f: ((Fractional[Any], Any) => Any)): Any = {
- val evalE1 = e1.eval(i: Row)
- if(evalE1 == null) {
- null
- } else {
- e1.dataType match {
- case ft: FractionalType =>
- f.asInstanceOf[(Fractional[ft.JvmType], ft.JvmType) => ft.JvmType](
- ft.fractional, evalE1.asInstanceOf[ft.JvmType])
- case other => sys.error(s"Type $other does not support fractional operations")
- }
- }
- }
-
- /**
- * Evaluation helper function for 2 Integral children expressions. Those expressions are
- * supposed to be in the same data type, and also the return type.
- * Either one of the expressions result is null, the evaluation result should be null.
- */
- @inline
- protected final def i2(
- i: Row,
- e1: Expression,
- e2: Expression,
- f: ((Integral[Any], Any, Any) => Any)): Any = {
- if (e1.dataType != e2.dataType) {
- throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
- }
-
- val evalE1 = e1.eval(i)
- if(evalE1 == null) {
- null
- } else {
- val evalE2 = e2.eval(i)
- if (evalE2 == null) {
- null
- } else {
- e1.dataType match {
- case i: IntegralType =>
- f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
- i.integral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
- case i: FractionalType =>
- f.asInstanceOf[(Integral[i.JvmType], i.JvmType, i.JvmType) => i.JvmType](
- i.asIntegral, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
- case other => sys.error(s"Type $other does not support numeric operations")
- }
- }
- }
- }
-
- /**
- * Evaluation helper function for 1 Integral children expression.
- * if the expression result is null, the evaluation result should be null.
- */
- @inline
- protected final def i1(i: Row, e1: Expression, f: ((Integral[Any], Any) => Any)): Any = {
- val evalE1 = e1.eval(i)
- if(evalE1 == null) {
- null
- } else {
- e1.dataType match {
- case i: IntegralType =>
- f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
- i.integral, evalE1.asInstanceOf[i.JvmType])
- case i: FractionalType =>
- f.asInstanceOf[(Integral[i.JvmType], i.JvmType) => i.JvmType](
- i.asIntegral, evalE1.asInstanceOf[i.JvmType])
- case other => sys.error(s"Type $other does not support numeric operations")
- }
- }
- }
-
- /**
- * Evaluation helper function for 2 Comparable children expressions. Those expressions are
- * supposed to be in the same data type, and the return type should be Integer:
- * Negative value: 1st argument less than 2nd argument
- * Zero: 1st argument equals 2nd argument
- * Positive value: 1st argument greater than 2nd argument
- *
- * Either one of the expressions result is null, the evaluation result should be null.
- */
- @inline
- protected final def c2(
- i: Row,
- e1: Expression,
- e2: Expression,
- f: ((Ordering[Any], Any, Any) => Any)): Any = {
- if (e1.dataType != e2.dataType) {
- throw new TreeNodeException(this, s"Types do not match ${e1.dataType} != ${e2.dataType}")
- }
-
- val evalE1 = e1.eval(i)
- if(evalE1 == null) {
- null
- } else {
- val evalE2 = e2.eval(i)
- if (evalE2 == null) {
- null
- } else {
- e1.dataType match {
- case i: NativeType =>
- f.asInstanceOf[(Ordering[i.JvmType], i.JvmType, i.JvmType) => Boolean](
- i.ordering, evalE1.asInstanceOf[i.JvmType], evalE2.asInstanceOf[i.JvmType])
- case other => sys.error(s"Type $other does not support ordered operations")
- }
- }
- }
+ def prettyString: String = {
+ transform {
+ case a: AttributeReference => PrettyAttribute(a.name)
+ case u: UnresolvedAttribute => PrettyAttribute(u.name)
+ }.toString
}
}
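For illustration, a minimal sketch of what `prettyString` changes relative to `toString` (assumes the 1.3-era Catalyst packages; the attribute and expression below are made up):

    // Hypothetical expression: prettyString replaces attributes with PrettyAttribute,
    // so the developer-facing expression id is not printed.
    import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
    import org.apache.spark.sql.types.IntegerType

    val price = AttributeReference("price", IntegerType)()
    val expr = Add(price, Literal(1))

    expr.toString      // e.g. "(price#3 + 1)" -- includes the expression id
    expr.prettyString  // "(price + 1)"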
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index 7434165f654f8..21d714c9a8c3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -220,13 +220,14 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
override def isNullAt(i: Int): Boolean = values(i).isNull
override def copy(): Row = {
- val newValues = new Array[MutableValue](values.length)
+ val newValues = new Array[Any](values.length)
var i = 0
while (i < values.length) {
- newValues(i) = values(i).copy()
+ newValues(i) = values(i).boxed
i += 1
}
- new SpecificMutableRow(newValues)
+
+ new GenericRow(newValues)
}
override def update(ordinal: Int, value: Any): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
old mode 100755
new mode 100644
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 574907f566c0f..00b0d3c683fe2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.types._
case class UnaryMinus(child: Expression) extends UnaryExpression {
@@ -28,8 +29,18 @@ case class UnaryMinus(child: Expression) extends UnaryExpression {
def nullable = child.nullable
override def toString = s"-$child"
+ lazy val numeric = dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
override def eval(input: Row): Any = {
- n1(child, input, _.negate(_))
+ val evalE = child.eval(input)
+ if (evalE == null) {
+ null
+ } else {
+ numeric.negate(evalE)
+ }
}
}
@@ -41,18 +52,19 @@ case class Sqrt(child: Expression) extends UnaryExpression {
def nullable = true
override def toString = s"SQRT($child)"
+ lazy val numeric = child.dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support non-negative numeric operations")
+ }
+
override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
- child.dataType match {
- case n: NumericType =>
- val value = n.numeric.toDouble(evalE.asInstanceOf[n.JvmType])
- if (value < 0) null
- else math.sqrt(value)
- case other => sys.error(s"Type $other does not support non-negative numeric operations")
- }
+ val value = numeric.toDouble(evalE)
+ if (value < 0) null
+ else math.sqrt(value)
}
}
}
@@ -98,19 +110,70 @@ abstract class BinaryArithmetic extends BinaryExpression {
case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "+"
- override def eval(input: Row): Any = n2(input, left, right, _.plus(_, _))
+ lazy val numeric = dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ numeric.plus(evalE1, evalE2)
+ }
+ }
+ }
}
case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "-"
- override def eval(input: Row): Any = n2(input, left, right, _.minus(_, _))
+ lazy val numeric = dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ numeric.minus(evalE1, evalE2)
+ }
+ }
+ }
}
case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "*"
- override def eval(input: Row): Any = n2(input, left, right, _.times(_, _))
+ lazy val numeric = dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ numeric.times(evalE1, evalE2)
+ }
+ }
+ }
}
case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -118,16 +181,25 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override def nullable = true
+ lazy val div: (Any, Any) => Any = dataType match {
+ case ft: FractionalType => ft.fractional.asInstanceOf[Fractional[Any]].div
+ case it: IntegralType => it.integral.asInstanceOf[Integral[Any]].quot
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
override def eval(input: Row): Any = {
val evalE2 = right.eval(input)
- dataType match {
- case _ if evalE2 == null => null
- case _ if evalE2 == 0 => null
- case ft: FractionalType => f1(input, left, _.div(_, evalE2.asInstanceOf[ft.JvmType]))
- case it: IntegralType => i1(input, left, _.quot(_, evalE2.asInstanceOf[it.JvmType]))
+ if (evalE2 == null || evalE2 == 0) {
+ null
+ } else {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ div(evalE1, evalE2)
+ }
}
}
-
}
case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
@@ -135,12 +207,23 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def nullable = true
+ lazy val integral = dataType match {
+ case i: IntegralType => i.integral.asInstanceOf[Integral[Any]]
+ case i: FractionalType => i.asIntegral.asInstanceOf[Integral[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
override def eval(input: Row): Any = {
val evalE2 = right.eval(input)
- dataType match {
- case _ if evalE2 == null => null
- case _ if evalE2 == 0 => null
- case nt: NumericType => i1(input, left, _.rem(_, evalE2.asInstanceOf[nt.JvmType]))
+ if (evalE2 == null || evalE2 == 0) {
+ null
+ } else {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ integral.rem(evalE1, evalE2)
+ }
}
}
}
@@ -151,13 +234,19 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "&"
- override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
- case ByteType => (evalE1.asInstanceOf[Byte] & evalE2.asInstanceOf[Byte]).toByte
- case ShortType => (evalE1.asInstanceOf[Short] & evalE2.asInstanceOf[Short]).toShort
- case IntegerType => evalE1.asInstanceOf[Int] & evalE2.asInstanceOf[Int]
- case LongType => evalE1.asInstanceOf[Long] & evalE2.asInstanceOf[Long]
+ lazy val and: (Any, Any) => Any = dataType match {
+ case ByteType =>
+ ((evalE1: Byte, evalE2: Byte) => (evalE1 & evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+ case ShortType =>
+ ((evalE1: Short, evalE2: Short) => (evalE1 & evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+ case IntegerType =>
+ ((evalE1: Int, evalE2: Int) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
+ case LongType =>
+ ((evalE1: Long, evalE2: Long) => evalE1 & evalE2).asInstanceOf[(Any, Any) => Any]
case other => sys.error(s"Unsupported bitwise & operation on $other")
}
+
+ override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = and(evalE1, evalE2)
}
/**
@@ -166,13 +255,19 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme
case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "|"
- override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
- case ByteType => (evalE1.asInstanceOf[Byte] | evalE2.asInstanceOf[Byte]).toByte
- case ShortType => (evalE1.asInstanceOf[Short] | evalE2.asInstanceOf[Short]).toShort
- case IntegerType => evalE1.asInstanceOf[Int] | evalE2.asInstanceOf[Int]
- case LongType => evalE1.asInstanceOf[Long] | evalE2.asInstanceOf[Long]
+ lazy val or: (Any, Any) => Any = dataType match {
+ case ByteType =>
+ ((evalE1: Byte, evalE2: Byte) => (evalE1 | evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+ case ShortType =>
+ ((evalE1: Short, evalE2: Short) => (evalE1 | evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+ case IntegerType =>
+ ((evalE1: Int, evalE2: Int) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
+ case LongType =>
+ ((evalE1: Long, evalE2: Long) => evalE1 | evalE2).asInstanceOf[(Any, Any) => Any]
case other => sys.error(s"Unsupported bitwise | operation on $other")
}
+
+ override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = or(evalE1, evalE2)
}
/**
@@ -181,13 +276,19 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet
case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic {
def symbol = "^"
- override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = dataType match {
- case ByteType => (evalE1.asInstanceOf[Byte] ^ evalE2.asInstanceOf[Byte]).toByte
- case ShortType => (evalE1.asInstanceOf[Short] ^ evalE2.asInstanceOf[Short]).toShort
- case IntegerType => evalE1.asInstanceOf[Int] ^ evalE2.asInstanceOf[Int]
- case LongType => evalE1.asInstanceOf[Long] ^ evalE2.asInstanceOf[Long]
+ lazy val xor: (Any, Any) => Any = dataType match {
+ case ByteType =>
+ ((evalE1: Byte, evalE2: Byte) => (evalE1 ^ evalE2).toByte).asInstanceOf[(Any, Any) => Any]
+ case ShortType =>
+ ((evalE1: Short, evalE2: Short) => (evalE1 ^ evalE2).toShort).asInstanceOf[(Any, Any) => Any]
+ case IntegerType =>
+ ((evalE1: Int, evalE2: Int) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
+ case LongType =>
+ ((evalE1: Long, evalE2: Long) => evalE1 ^ evalE2).asInstanceOf[(Any, Any) => Any]
case other => sys.error(s"Unsupported bitwise ^ operation on $other")
}
+
+ override def evalInternal(evalE1: EvaluatedType, evalE2: EvaluatedType): Any = xor(evalE1, evalE2)
}
/**
@@ -201,18 +302,24 @@ case class BitwiseNot(child: Expression) extends UnaryExpression {
def nullable = child.nullable
override def toString = s"~$child"
+ lazy val not: (Any) => Any = dataType match {
+ case ByteType =>
+ ((evalE: Byte) => (~evalE).toByte).asInstanceOf[(Any) => Any]
+ case ShortType =>
+ ((evalE: Short) => (~evalE).toShort).asInstanceOf[(Any) => Any]
+ case IntegerType =>
+ ((evalE: Int) => ~evalE).asInstanceOf[(Any) => Any]
+ case LongType =>
+ ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any]
+ case other => sys.error(s"Unsupported bitwise ~ operation on $other")
+ }
+
override def eval(input: Row): Any = {
val evalE = child.eval(input)
if (evalE == null) {
null
} else {
- dataType match {
- case ByteType => (~evalE.asInstanceOf[Byte]).toByte
- case ShortType => (~evalE.asInstanceOf[Short]).toShort
- case IntegerType => ~evalE.asInstanceOf[Int]
- case LongType => ~evalE.asInstanceOf[Long]
- case other => sys.error(s"Unsupported bitwise ~ operation on $other")
- }
+ not(evalE)
}
}
}
@@ -226,21 +333,35 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
override def children = left :: right :: Nil
- override def dataType = left.dataType
+ override lazy val resolved =
+ left.resolved && right.resolved &&
+ left.dataType == right.dataType
+
+ override def dataType = {
+ if (!resolved) {
+ throw new UnresolvedException(this,
+ s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+ }
+ left.dataType
+ }
+
+ lazy val ordering = left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
override def eval(input: Row): Any = {
- val leftEval = left.eval(input)
- val rightEval = right.eval(input)
- if (leftEval == null) {
- rightEval
- } else if (rightEval == null) {
- leftEval
+ val evalE1 = left.eval(input)
+ val evalE2 = right.eval(input)
+ if (evalE1 == null) {
+ evalE2
+ } else if (evalE2 == null) {
+ evalE1
} else {
- val numeric = left.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]]
- if (numeric.compare(leftEval, rightEval) < 0) {
- rightEval
+ if (ordering.compare(evalE1, evalE2) < 0) {
+ evalE2
} else {
- leftEval
+ evalE1
}
}
}
@@ -259,5 +380,17 @@ case class Abs(child: Expression) extends UnaryExpression {
def nullable = child.nullable
override def toString = s"Abs($child)"
- override def eval(input: Row): Any = n1(child, input, _.abs(_))
+ lazy val numeric = dataType match {
+ case n: NumericType => n.numeric.asInstanceOf[Numeric[Any]]
+ case other => sys.error(s"Type $other does not support numeric operations")
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE = child.eval(input)
+ if (evalE == null) {
+ null
+ } else {
+ numeric.abs(evalE)
+ }
+ }
}
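The recurring change in this file is to hoist the per-row `dataType match` into a `lazy val`, so the numeric dispatch happens once per operator instead of once per row; a standalone sketch of that shape in plain Scala (not Spark API):

    // Before: re-dispatch on the type for every value.
    def negateSlow(dataType: String, values: Seq[Any]): Seq[Any] =
      values.map { v =>
        dataType match {
          case "int"    => -v.asInstanceOf[Int]
          case "double" => -v.asInstanceOf[Double]
          case other    => sys.error(s"Type $other does not support numeric operations")
        }
      }

    // After: resolve the operation once, reuse it for every value, and keep the null check.
    def negateFast(dataType: String, values: Seq[Any]): Seq[Any] = {
      lazy val negate: Any => Any = dataType match {
        case "int"    => (v: Any) => -v.asInstanceOf[Int]
        case "double" => (v: Any) => -v.asInstanceOf[Double]
        case other    => sys.error(s"Type $other does not support numeric operations")
      }
      values.map(v => if (v == null) null else negate(v))
    }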
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 43b6482c0171c..0983d274def3f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -73,6 +73,25 @@ abstract class Generator extends Expression {
}
}
+/**
+ * A generator that produces its output using the provided lambda function.
+ */
+case class UserDefinedGenerator(
+ schema: Seq[Attribute],
+ function: Row => TraversableOnce[Row],
+ children: Seq[Expression])
+ extends Generator {
+
+ override protected def makeOutput(): Seq[Attribute] = schema
+
+ override def eval(input: Row): TraversableOnce[Row] = {
+ val inputRow = new InterpretedProjection(children)
+ function(inputRow(input))
+ }
+
+ override def toString = s"UserDefinedGenerator(${children.mkString(",")})"
+}
+
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
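A rough usage sketch for the new UserDefinedGenerator; the word-splitting logic and import locations here are assumptions for illustration, not part of this patch:

    // Hypothetical: turn a single "sentence" column into one output row per word.
    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, UserDefinedGenerator}
    import org.apache.spark.sql.types.StringType

    val sentence = AttributeReference("sentence", StringType)()
    val wordAttr = AttributeReference("word", StringType)()

    val explodeWords = UserDefinedGenerator(
      schema = Seq(wordAttr),
      function = (row: Row) => row.getString(0).split(" ").map(w => Row(w)).toSeq,
      children = Seq(sentence))

    // eval() projects `children` into a Row and hands it to `function`,
    // producing one output Row per word in the input sentence.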
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 97bb96f48e2c7..9ff66563c8164 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -38,6 +38,8 @@ object Literal {
case d: Date => Literal(DateUtils.fromJavaDate(d), DateType)
case a: Array[Byte] => Literal(a, BinaryType)
case null => Literal(null, NullType)
+ case _ =>
+ throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index e6ab1fd8d7939..62c062be6d820 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -40,6 +40,17 @@ abstract class NamedExpression extends Expression {
def name: String
def exprId: ExprId
+
+ /**
+ * All possible qualifiers for the expression.
+ *
+ * For now, since we do not allow using the original table name to qualify a column name once
+ * the table is aliased, this can only be:
+ *
+ * 1. Empty Seq: when an attribute doesn't have a qualifier,
+ *    e.g. top-level attributes aliased in the SELECT clause, or a column from a LocalRelation.
+ * 2. Single element: either the table name or the alias name of the table.
+ */
def qualifiers: Seq[String]
def toAttribute: Attribute
@@ -190,6 +201,26 @@ case class AttributeReference(
override def toString: String = s"$name#${exprId.id}$typeSuffix"
}
+/**
+ * A placeholder used when printing expressions without debugging information such as the
+ * expression id or the unresolved indicator.
+ */
+case class PrettyAttribute(name: String) extends Attribute with trees.LeafNode[Expression] {
+ type EvaluatedType = Any
+
+ override def toString = name
+
+ override def withNullability(newNullability: Boolean): Attribute = ???
+ override def newInstance(): Attribute = ???
+ override def withQualifiers(newQualifiers: Seq[String]): Attribute = ???
+ override def withName(newName: String): Attribute = ???
+ override def qualifiers: Seq[String] = ???
+ override def exprId: ExprId = ???
+ override def eval(input: Row): EvaluatedType = ???
+ override def nullable: Boolean = ???
+ override def dataType: DataType = NullType
+}
+
object VirtualColumn {
val groupingIdName = "grouping__id"
def newGroupingId = AttributeReference(groupingIdName, IntegerType, false)()
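To make the qualifier rule above concrete, a small sketch with made-up names (assumes the 1.3-era constructors):

    import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Literal}
    import org.apache.spark.sql.types.IntegerType

    // A column coming from a table aliased as "t": single-element qualifier.
    val a = AttributeReference("a", IntegerType)(qualifiers = Seq("t"))
    a.qualifiers               // Seq("t")

    // A top-level alias in the SELECT clause: no qualifier.
    val b = Alias(Literal(1), "b")()
    b.toAttribute.qualifiers   // Nil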
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 365b1685a8e71..0024ef92c0452 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -18,8 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
+import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{BinaryType, BooleanType}
+import org.apache.spark.sql.types.{BinaryType, BooleanType, NativeType}
object InterpretedPredicate {
def apply(expression: Expression, inputSchema: Seq[Attribute]): (Row => Boolean) =
@@ -201,22 +202,118 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<"
- override def eval(input: Row): Any = c2(input, left, right, _.lt(_, _))
+
+ lazy val ordering = {
+ if (left.dataType != right.dataType) {
+ throw new TreeNodeException(this,
+ s"Types do not match ${left.dataType} != ${right.dataType}")
+ }
+ left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ ordering.lt(evalE1, evalE2)
+ }
+ }
+ }
}
case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = "<="
- override def eval(input: Row): Any = c2(input, left, right, _.lteq(_, _))
+
+ lazy val ordering = {
+ if (left.dataType != right.dataType) {
+ throw new TreeNodeException(this,
+ s"Types do not match ${left.dataType} != ${right.dataType}")
+ }
+ left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ ordering.lteq(evalE1, evalE2)
+ }
+ }
+ }
}
case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">"
- override def eval(input: Row): Any = c2(input, left, right, _.gt(_, _))
+
+ lazy val ordering = {
+ if (left.dataType != right.dataType) {
+ throw new TreeNodeException(this,
+ s"Types do not match ${left.dataType} != ${right.dataType}")
+ }
+ left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ ordering.gt(evalE1, evalE2)
+ }
+ }
+ }
}
case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
def symbol = ">="
- override def eval(input: Row): Any = c2(input, left, right, _.gteq(_, _))
+
+ lazy val ordering = {
+ if (left.dataType != right.dataType) {
+ throw new TreeNodeException(this,
+ s"Types do not match ${left.dataType} != ${right.dataType}")
+ }
+ left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ }
+
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ if (evalE1 == null) {
+ null
+ } else {
+ val evalE2 = right.eval(input)
+ if (evalE2 == null) {
+ null
+ } else {
+ ordering.gteq(evalE1, evalE2)
+ }
+ }
+ }
}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 73ec7a6d114f5..faa366771824b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.types.NativeType
+import org.apache.spark.sql.types.{StructType, NativeType}
/**
@@ -149,6 +149,10 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
def copy() = this
}
+class GenericRowWithSchema(values: Array[Any], override val schema: StructType)
+ extends GenericRow(values) {
+}
+
class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
/** No-arg constructor for serialization. */
def this() = this(null)
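A small sketch of what GenericRowWithSchema adds over a plain GenericRow, which carries no schema (field names are made up; assumes the 1.3-era types API):

    import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("name", StringType),
      StructField("age", IntegerType)))

    // Same values as a GenericRow, but row.schema now returns the StructType above.
    val row = new GenericRowWithSchema(Array("alice", 30), schema)
    row.getString(0)  // "alice"
    row.schema        // the StructType defined above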
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 0da081ed1a6e2..1a75fcf3545bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -119,6 +119,15 @@ object ColumnPruning extends Rule[LogicalPlan] {
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
a.copy(child = Project(a.references.toSeq, child))
+ case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
+ if (a.outputSet -- p.references).nonEmpty =>
+ Project(
+ projectList,
+ Aggregate(
+ groupingExpressions,
+ aggregateExpressions.filter(e => p.references.contains(e)),
+ child))
+
// Eliminate unneeded attributes from either side of a Join.
case Project(projectList, Join(left, right, joinType, condition)) =>
// Collect the list of all references required either above or to evaluate the condition.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 619f42859cbb8..17a88e07de15f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -152,6 +152,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
/** Prints out the schema in the tree format */
def printSchema(): Unit = println(schemaString)
+ /**
+ * A prefix string used when printing the plan.
+ *
+ * We use "!" to indicate an invalid plan, and "'" to indicate an unresolved plan.
+ */
protected def statePrefix = if (missingInput.nonEmpty && children.nonEmpty) "!" else ""
override def simpleString = statePrefix + super.simpleString
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 7cf4b81274906..8c4f09b58a4f2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -18,41 +18,29 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.Logging
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, Resolver}
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.catalyst.trees.TreeNode
-import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.catalyst.trees
-/**
- * Estimates of various statistics. The default estimation logic simply lazily multiplies the
- * corresponding statistic produced by the children. To override this behavior, override
- * `statistics` and assign it an overridden version of `Statistics`.
- *
- * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
- * performance of the implementations. The reason is that estimations might get triggered in
- * performance-critical processes, such as query plan planning.
- *
- * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in
- * cardinality estimation (e.g. cartesian joins).
- *
- * @param sizeInBytes Physical size in bytes. For leaf operators this defaults to 1, otherwise it
- * defaults to the product of children's `sizeInBytes`.
- */
-private[sql] case class Statistics(sizeInBytes: BigInt)
abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
self: Product =>
+ /**
+ * Computes [[Statistics]] for this plan. The default implementation assumes the output
+ * cardinality is the product of all child plans' cardinalities, i.e. it applies in the case
+ * of cartesian joins.
+ *
+ * [[LeafNode]]s must override this.
+ */
def statistics: Statistics = {
if (children.size == 0) {
throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.")
}
-
- Statistics(
- sizeInBytes = children.map(_.statistics).map(_.sizeInBytes).product)
+ Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product)
}
/**
@@ -128,6 +116,44 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
def resolve(name: String, resolver: Resolver): Option[NamedExpression] =
resolve(name, output, resolver)
+ /**
+ * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
+ *
+ * This assumes `name` has multiple parts, where the 1st part is a qualifier
+ * (i.e. table name, alias, or subquery alias).
+ * See the comment above the `candidates` variable in resolve() for the semantics of the returned data.
+ */
+ private def resolveAsTableColumn(
+ nameParts: Array[String],
+ resolver: Resolver,
+ attribute: Attribute): Option[(Attribute, List[String])] = {
+ assert(nameParts.length > 1)
+ if (attribute.qualifiers.exists(resolver(_, nameParts.head))) {
+ // At least one qualifier matches. See if remaining parts match.
+ val remainingParts = nameParts.tail
+ resolveAsColumn(remainingParts, resolver, attribute)
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Resolve the given `name` string against the given attribute, returning either 0 or 1 match.
+ *
+ * Unlike resolveAsTableColumn, this assumes `name` does NOT start with a qualifier.
+ * See the comment above the `candidates` variable in resolve() for the semantics of the returned data.
+ */
+ private def resolveAsColumn(
+ nameParts: Array[String],
+ resolver: Resolver,
+ attribute: Attribute): Option[(Attribute, List[String])] = {
+ if (resolver(attribute.name, nameParts.head)) {
+ Option((attribute.withName(nameParts.head), nameParts.tail.toList))
+ } else {
+ None
+ }
+ }
+
/** Performs attribute resolution given a name and a sequence of possible attributes. */
protected def resolve(
name: String,
@@ -136,34 +162,44 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
val parts = name.split("\\.")
- // Collect all attributes that are output by this nodes children where either the first part
- // matches the name or where the first part matches the scope and the second part matches the
- // name. Return these matches along with any remaining parts, which represent dotted access to
- // struct fields.
- val options = input.flatMap { option =>
- // If the first part of the desired name matches a qualifier for this possible match, drop it.
- val remainingParts =
- if (option.qualifiers.find(resolver(_, parts.head)).nonEmpty && parts.size > 1) {
- parts.drop(1)
- } else {
- parts
+ // A sequence of possible candidate matches.
+ // Each candidate is a tuple: the first element is a resolved attribute, and the second is
+ // the list of name parts still to be resolved.
+ // For example, consider "a.b.c", where "a" is the table name, "b" is the column name,
+ // and "c" is the struct field name. The resolved attribute will be the one for "a.b",
+ // and the remaining parts will be List("c").
+ var candidates: Seq[(Attribute, List[String])] = {
+ // If the name has 2 or more parts, try to resolve it as `table.column` first.
+ if (parts.length > 1) {
+ input.flatMap { option =>
+ resolveAsTableColumn(parts, resolver, option)
}
-
- if (resolver(option.name, remainingParts.head)) {
- // Preserve the case of the user's attribute reference.
- (option.withName(remainingParts.head), remainingParts.tail.toList) :: Nil
} else {
- Nil
+ Seq.empty
+ }
+ }
+
+ // If none of the attributes match the `table.column` pattern, we try to resolve it as a column.
+ if (candidates.isEmpty) {
+ candidates = input.flatMap { candidate =>
+ resolveAsColumn(parts, resolver, candidate)
}
}
- options.distinct match {
+ candidates.distinct match {
// One match, no nested fields, use it.
case Seq((a, Nil)) => Some(a)
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
- Some(Alias(nestedFields.foldLeft(a: Expression)(UnresolvedGetField), nestedFields.last)())
+ // The foldLeft adds an UnresolvedGetField for each remaining part of the name,
+ // and aliases the result with the last part of the name.
+ // For example, consider name "a.b.c", where "a" is resolved to an existing attribute.
+ // Then this will add UnresolvedGetField("b") and UnresolvedGetField("c"), and alias
+ // the final expression as "c".
+ val fieldExprs = nestedFields.foldLeft(a: Expression)(UnresolvedGetField)
+ val aliasName = nestedFields.last
+ Some(Alias(fieldExprs, aliasName)())
// No matches.
case Seq() =>
@@ -172,8 +208,8 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// More than one match.
case ambiguousReferences =>
- throw new TreeNodeException(
- this, s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
+ throw new AnalysisException(
+ s"Ambiguous references to $name: ${ambiguousReferences.mkString(",")}")
}
}
}
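A rough sketch of the two resolution paths using the Catalyst test DSL (the relation, alias, and resolver below are made up for illustration):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val resolver: (String, String) => Boolean = _.equalsIgnoreCase(_)
    val t = LocalRelation('id.int, 'name.string).subquery('t)

    t.resolve("t.name", resolver)    // Some(...): matched via the table.column path
    t.resolve("name", resolver)      // Some(...): matched as a bare column
    t.resolve("t.missing", resolver) // None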
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
index cfe2c7a39a17c..ccf5291219add 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression}
/**
* Transforms the input by forking and running the specified script.
@@ -32,7 +32,9 @@ case class ScriptTransformation(
script: String,
output: Seq[Attribute],
child: LogicalPlan,
- ioschema: ScriptInputOutputSchema) extends UnaryNode
+ ioschema: ScriptInputOutputSchema) extends UnaryNode {
+ override def references: AttributeSet = AttributeSet(input.flatMap(_.references))
+}
/**
* A placeholder for implementation specific input and output properties when passing data
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
new file mode 100644
index 0000000000000..9ac4c3a2a56c8
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.logical
+
+/**
+ * Estimates of various statistics. The default estimation logic simply lazily multiplies the
+ * corresponding statistic produced by the children. To override this behavior, override
+ * `statistics` and assign it an overridden version of `Statistics`.
+ *
+ * '''NOTE''': concrete and/or overridden versions of statistics fields should pay attention to the
+ * performance of the implementations. The reason is that estimations might get triggered in
+ * performance-critical processes, such as query plan planning.
+ *
+ * Note that we are using a BigInt here since it is easy to overflow a 64-bit integer in
+ * cardinality estimation (e.g. cartesian joins).
+ *
+ * @param sizeInBytes Physical size in bytes. Leaf operators must provide this themselves, while
+ * non-leaf operators default to the product of their children's `sizeInBytes`.
+ */
+private[sql] case class Statistics(sizeInBytes: BigInt)
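A minimal sketch of a leaf operator satisfying this contract, since leaves cannot rely on the default child-product estimate (the relation below is hypothetical):

    package org.apache.spark.sql.catalyst.plans.logical

    import org.apache.spark.sql.catalyst.expressions.Attribute

    // Hypothetical leaf node: supplies its own Statistics, because the default
    // implementation throws for leaves (there are no children to multiply).
    case class FixedSizeRelation(output: Seq[Attribute], bytes: Long) extends LeafNode {
      override def statistics: Statistics = Statistics(sizeInBytes = BigInt(bytes))
    }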
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 9628e93274a11..89544add74430 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -23,6 +23,16 @@ import org.apache.spark.sql.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
def output = projectList.map(_.toAttribute)
+
+ override lazy val resolved: Boolean = {
+ val containsAggregatesOrGenerators = projectList.exists ( _.collect {
+ case agg: AggregateExpression => agg
+ case generator: Generator => generator
+ }.nonEmpty
+ )
+
+ !expressions.exists(!_.resolved) && childrenResolved && !containsAggregatesOrGenerators
+ }
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 2013ae4f7bd13..109671bdca361 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -22,9 +22,42 @@ import org.apache.spark.sql.catalyst.errors._
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
+case class Origin(
+ line: Option[Int] = None,
+ startPosition: Option[Int] = None)
+
+/**
+ * Provides a location for TreeNodes to ask about the context of their origin. For example, which
+ * line of code is currently being parsed.
+ */
+object CurrentOrigin {
+ private val value = new ThreadLocal[Origin]() {
+ override def initialValue: Origin = Origin()
+ }
+
+ def get = value.get()
+ def set(o: Origin) = value.set(o)
+
+ def reset() = value.set(Origin())
+
+ def setPosition(line: Int, start: Int) = {
+ value.set(
+ value.get.copy(line = Some(line), startPosition = Some(start)))
+ }
+
+ def withOrigin[A](o: Origin)(f: => A): A = {
+ set(o)
+ val ret = try f finally { reset() }
+ reset()
+ ret
+ }
+}
+
abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
self: BaseType with Product =>
+ val origin = CurrentOrigin.get
+
/** Returns a Seq of the children of this node */
def children: Seq[BaseType]
@@ -46,6 +79,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
children.foreach(_.foreach(f))
}
+ /**
+ * Runs the given function recursively on [[children]] then on this node.
+ * @param f the function to be applied to each node in the tree.
+ */
+ def foreachUp(f: BaseType => Unit): Unit = {
+ children.foreach(_.foreachUp(f))
+ f(this)
+ }
+
/**
* Returns a Seq containing the result of applying the given function to each
* node in this tree in a preorder traversal.
@@ -141,7 +183,10 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param rule the function used to transform this nodes children
*/
def transformDown(rule: PartialFunction[BaseType, BaseType]): BaseType = {
- val afterRule = rule.applyOrElse(this, identity[BaseType])
+ val afterRule = CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[BaseType])
+ }
+
// Check if unchanged and then possibly return old copy to avoid gc churn.
if (this fastEquals afterRule) {
transformChildrenDown(rule)
@@ -201,9 +246,13 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def transformUp(rule: PartialFunction[BaseType, BaseType]): BaseType = {
val afterRuleOnChildren = transformChildrenUp(rule);
if (this fastEquals afterRuleOnChildren) {
- rule.applyOrElse(this, identity[BaseType])
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(this, identity[BaseType])
+ }
} else {
- rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ CurrentOrigin.withOrigin(origin) {
+ rule.applyOrElse(afterRuleOnChildren, identity[BaseType])
+ }
}
}
@@ -259,12 +308,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
try {
- // Skip no-arg constructors that are just there for kryo.
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
- if (otherCopyArgs.isEmpty) {
- defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
- } else {
- defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ CurrentOrigin.withOrigin(origin) {
+ // Skip no-arg constructors that are just there for kryo.
+ val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
+ if (otherCopyArgs.isEmpty) {
+ defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
+ } else {
+ defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[this.type]
+ }
}
} catch {
case e: java.lang.IllegalArgumentException =>
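A quick sketch of how the new origin tracking is meant to be consumed (assumes the 1.3-era packages; the position values are made up):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin}
    import org.apache.spark.sql.types.IntegerType

    // Any TreeNode constructed inside withOrigin records that origin in its `origin` val;
    // transformDown/transformUp/makeCopy re-establish it so rewritten nodes keep the position.
    val lit = CurrentOrigin.withOrigin(Origin(line = Some(3), startPosition = Some(14))) {
      Literal(1, IntegerType)
    }
    assert(lit.origin == Origin(Some(3), Some(14)))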
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Metadata.scala
old mode 100755
new mode 100644
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index d0f547d187ecb..eee00e3f7ea76 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -61,6 +61,7 @@ case class OptionalData(
case class ComplexData(
arrayField: Seq[Int],
arrayField1: Array[Int],
+ arrayField2: List[Int],
arrayFieldContainsNull: Seq[java.lang.Integer],
mapField: Map[Int, Long],
mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -137,6 +138,10 @@ class ScalaReflectionSuite extends FunSuite {
"arrayField1",
ArrayType(IntegerType, containsNull = false),
nullable = true),
+ StructField(
+ "arrayField2",
+ ArrayType(IntegerType, containsNull = false),
+ nullable = true),
StructField(
"arrayFieldContainsNull",
ArrayType(IntegerType, containsNull = true),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index 60060bf02913b..c1dd5aa913ddc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis
import org.scalatest.{BeforeAndAfter, FunSuite}
-import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types._
@@ -30,11 +30,21 @@ import org.apache.spark.sql.catalyst.dsl.plans._
class AnalysisSuite extends FunSuite with BeforeAndAfter {
val caseSensitiveCatalog = new SimpleCatalog(true)
val caseInsensitiveCatalog = new SimpleCatalog(false)
- val caseSensitiveAnalyze =
+
+ val caseSensitiveAnalyzer =
new Analyzer(caseSensitiveCatalog, EmptyFunctionRegistry, caseSensitive = true)
- val caseInsensitiveAnalyze =
+ val caseInsensitiveAnalyzer =
new Analyzer(caseInsensitiveCatalog, EmptyFunctionRegistry, caseSensitive = false)
+ val checkAnalysis = new CheckAnalysis
+
+
+ def caseSensitiveAnalyze(plan: LogicalPlan) =
+ checkAnalysis(caseSensitiveAnalyzer(plan))
+
+ def caseInsensitiveAnalyze(plan: LogicalPlan) =
+ checkAnalysis(caseInsensitiveAnalyzer(plan))
+
val testRelation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)())
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
@@ -55,35 +65,46 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None)))
}
- assert(caseInsensitiveAnalyze(plan).resolved)
+ assert(caseInsensitiveAnalyzer(plan).resolved)
+ }
+
+ test("check project's resolved") {
+ assert(Project(testRelation.output, testRelation).resolved)
+
+ assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
+
+ val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
+ assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
+
+ assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
}
test("analyze project") {
assert(
- caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
+ caseSensitiveAnalyzer(Project(Seq(UnresolvedAttribute("a")), testRelation)) ===
Project(testRelation.output, testRelation))
assert(
- caseSensitiveAnalyze(
+ caseSensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
- val e = intercept[TreeNodeException[_]] {
+ val e = intercept[AnalysisException] {
caseSensitiveAnalyze(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL"))))
}
- assert(e.getMessage().toLowerCase.contains("unresolved"))
+ assert(e.getMessage().toLowerCase.contains("cannot resolve"))
assert(
- caseInsensitiveAnalyze(
+ caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("TbL.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
assert(
- caseInsensitiveAnalyze(
+ caseInsensitiveAnalyzer(
Project(Seq(UnresolvedAttribute("tBl.a")),
UnresolvedRelation(Seq("TaBlE"), Some("TbL")))) ===
Project(testRelation.output, testRelation))
@@ -96,36 +117,65 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(e.getMessage == "Table Not Found: tAbLe")
assert(
- caseSensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
- testRelation)
+ caseSensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(Seq("tAbLe"), None)) ===
- testRelation)
+ caseInsensitiveAnalyzer(UnresolvedRelation(Seq("tAbLe"), None)) === testRelation)
assert(
- caseInsensitiveAnalyze(UnresolvedRelation(Seq("TaBlE"), None)) ===
- testRelation)
+ caseInsensitiveAnalyzer(UnresolvedRelation(Seq("TaBlE"), None)) === testRelation)
}
- test("throw errors for unresolved attributes during analysis") {
- val e = intercept[TreeNodeException[_]] {
- caseSensitiveAnalyze(Project(Seq(UnresolvedAttribute("abcd")), testRelation))
+ def errorTest(
+ name: String,
+ plan: LogicalPlan,
+ errorMessages: Seq[String],
+ caseSensitive: Boolean = true) = {
+ test(name) {
+ val error = intercept[AnalysisException] {
+ if (caseSensitive) {
+ caseSensitiveAnalyze(plan)
+ } else {
+ caseInsensitiveAnalyze(plan)
+ }
+ }
+
+ errorMessages.foreach(m => assert(error.getMessage contains m))
}
- assert(e.getMessage().toLowerCase.contains("unresolved attribute"))
}
- test("throw errors for unresolved plans during analysis") {
- case class UnresolvedTestPlan() extends LeafNode {
- override lazy val resolved = false
- override def output = Nil
- }
- val e = intercept[TreeNodeException[_]] {
- caseSensitiveAnalyze(UnresolvedTestPlan())
- }
- assert(e.getMessage().toLowerCase.contains("unresolved plan"))
+ errorTest(
+ "unresolved attributes",
+ testRelation.select('abcd),
+ "cannot resolve" :: "abcd" :: Nil)
+
+ errorTest(
+ "bad casts",
+ testRelation.select(Literal(1).cast(BinaryType).as('badCast)),
+ "invalid cast" :: Literal(1).dataType.simpleString :: BinaryType.simpleString :: Nil)
+
+ errorTest(
+ "non-boolean filters",
+ testRelation.where(Literal(1)),
+ "filter" :: "'1'" :: "not a boolean" :: Literal(1).dataType.simpleString :: Nil)
+
+ errorTest(
+ "missing group by",
+ testRelation2.groupBy('a)('b),
+ "'b'" :: "group by" :: Nil
+ )
+
+ case class UnresolvedTestPlan() extends LeafNode {
+ override lazy val resolved = false
+ override def output = Nil
}
+ errorTest(
+ "catch all unresolved plan",
+ UnresolvedTestPlan(),
+ "unresolved" :: Nil)
+
+
test("divide should be casted into fractional types") {
val testRelation2 = LocalRelation(
AttributeReference("a", StringType)(),
@@ -134,18 +184,15 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
AttributeReference("d", DecimalType.Unlimited)(),
AttributeReference("e", ShortType)())
- val expr0 = 'a / 2
- val expr1 = 'a / 'b
- val expr2 = 'a / 'c
- val expr3 = 'a / 'd
- val expr4 = 'e / 'e
- val plan = caseInsensitiveAnalyze(Project(
- Alias(expr0, s"Analyzer($expr0)")() ::
- Alias(expr1, s"Analyzer($expr1)")() ::
- Alias(expr2, s"Analyzer($expr2)")() ::
- Alias(expr3, s"Analyzer($expr3)")() ::
- Alias(expr4, s"Analyzer($expr4)")() :: Nil, testRelation2))
+ val plan = caseInsensitiveAnalyzer(
+ testRelation2.select(
+ 'a / Literal(2) as 'div1,
+ 'a / 'b as 'div2,
+ 'a / 'c as 'div3,
+ 'a / 'd as 'div4,
+ 'e / 'e as 'div5))
val pl = plan.asInstanceOf[Project].projectList
+
assert(pl(0).dataType == DoubleType)
assert(pl(1).dataType == DoubleType)
assert(pl(2).dataType == DoubleType)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
index 264a0eff37d34..72f06e26e05f1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -30,7 +30,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
- EliminateAnalysisOperators) ::
+ EliminateSubQueries) ::
Batch("Constant Folding", FixedPoint(50),
NullPropagation,
ConstantFolding,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
index e22c62505860a..ef10c0aece716 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.optimizer
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateAnalysisOperators}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedGetField, EliminateSubQueries}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -33,7 +33,7 @@ class ConstantFoldingSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
- EliminateAnalysisOperators) ::
+ EliminateSubQueries) ::
Batch("ConstantFolding", Once,
ConstantFolding,
BooleanSimplification) :: Nil
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 1158b5dfc6147..55c6766520a1e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
-import org.apache.spark.sql.catalyst.expressions.Explode
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.expressions.{Count, Explode}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
@@ -32,12 +32,13 @@ class FilterPushdownSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
- EliminateAnalysisOperators) ::
+ EliminateSubQueries) ::
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
PushPredicateThroughJoin,
- PushPredicateThroughGenerate) :: Nil
+ PushPredicateThroughGenerate,
+ ColumnPruning) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -58,6 +59,38 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
+ test("column pruning for group") {
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a, Count('b))
+ .select('a)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a)
+ .select('a).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for group with alias") {
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a as 'c, Count('b))
+ .select('c)
+
+ val optimized = Optimize(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a as 'c)
+ .select('c).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
// After this line is unimplemented.
test("simple push down") {
val originalQuery =
@@ -351,7 +384,7 @@ class FilterPushdownSuite extends PlanTest {
}
val optimized = Optimize(originalQuery.analyze)
- comparePlans(analysis.EliminateAnalysisOperators(originalQuery.analyze), optimized)
+ comparePlans(analysis.EliminateSubQueries(originalQuery.analyze), optimized)
}
test("joins: conjunctive predicates") {
@@ -370,7 +403,7 @@ class FilterPushdownSuite extends PlanTest {
left.join(right, condition = Some("x.b".attr === "y.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
+ comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
}
test("joins: conjunctive predicates #2") {
@@ -389,7 +422,7 @@ class FilterPushdownSuite extends PlanTest {
left.join(right, condition = Some("x.b".attr === "y.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
+ comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
}
test("joins: conjunctive predicates #3") {
@@ -412,7 +445,7 @@ class FilterPushdownSuite extends PlanTest {
condition = Some("z.a".attr === "x.b".attr))
.analyze
- comparePlans(optimized, analysis.EliminateAnalysisOperators(correctAnswer))
+ comparePlans(optimized, analysis.EliminateSubQueries(correctAnswer))
}
val testRelationWithArrayType = LocalRelation('a.int, 'b.int, 'c_arr.array(IntegerType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index da912ab382179..233e329cb2038 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import scala.collection.immutable.HashSet
-import org.apache.spark.sql.catalyst.analysis.{EliminateAnalysisOperators, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.PlanTest
@@ -34,7 +34,7 @@ class OptimizeInSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("AnalysisNodes", Once,
- EliminateAnalysisOperators) ::
+ EliminateSubQueries) ::
Batch("ConstantFolding", Once,
ConstantFolding,
BooleanSimplification,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
index dfef87bd9133d..a54751dfa9a12 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
-import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.{PlanTest, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.rules._
@@ -29,7 +29,7 @@ class UnionPushdownSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
- EliminateAnalysisOperators) ::
+ EliminateSubQueries) ::
Batch("Union Pushdown", Once,
UnionPushdown) :: Nil
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index c4a1f899d8a13..7d609b91389c6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -33,11 +33,9 @@ class PlanTest extends FunSuite {
* we must normalize them to check if two different queries are identical.
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
- val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id))
- val minId = if (list.isEmpty) 0 else list.min
plan transformAllExpressions {
case a: AttributeReference =>
- AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+ AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
}
}
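
Some context on why zeroing out every exprId is sufficient here: each run of the analyzer mints fresh expression ids, so two logically identical plans never agree on ids and a naive structural comparison would fail. A minimal sketch of the situation, reusing the testRelation/DSL pattern from the optimizer suites above (the required dsl imports are assumed):

    // Two independent analysis runs of the same query yield AttributeReferences
    // with different exprIds ...
    val plan1 = testRelation.select('a).analyze
    val plan2 = testRelation.select('a).analyze
    // ... but once normalizeExprIds rewrites every AttributeReference to carry
    // ExprId(0), the two trees become structurally identical and comparePlans
    // can match them without tracking a per-plan minimum id.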
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index cdb843f959704..e7ce92a2160b6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -104,4 +104,18 @@ class TreeNodeSuite extends FunSuite {
assert(actual === Dummy(None))
}
+ test("preserves origin") {
+ CurrentOrigin.setPosition(1,1)
+ val add = Add(Literal(1), Literal(1))
+ CurrentOrigin.reset()
+
+ val transformed = add transform {
+ case Literal(1, _) => Literal(2)
+ }
+
+ assert(transformed.origin.line.isDefined)
+ assert(transformed.origin.startPosition.isDefined)
+ }
+
+
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala
old mode 100755
new mode 100644
diff --git a/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java
new file mode 100644
index 0000000000000..a40be526d0d11
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/SaveMode.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql;
+
+/**
+ * SaveMode is used to specify the expected behavior of saving a DataFrame to a data source.
+ */
+public enum SaveMode {
+ /**
+ * Append mode means that when saving a DataFrame to a data source, if data/table already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ */
+ Append,
+ /**
+ * Overwrite mode means that when saving a DataFrame to a data source,
+ * if data/table already exists, existing data is expected to be overwritten by the contents of
+ * the DataFrame.
+ */
+ Overwrite,
+ /**
+ * ErrorIfExists mode means that when saving a DataFrame to a data source, if data already exists,
+ * an exception is expected to be thrown.
+ */
+ ErrorIfExists,
+ /**
+ * Ignore mode means that when saving a DataFrame to a data source, if data already exists,
+ * the save operation is expected to not save the contents of the DataFrame and to not
+ * change the existing data.
+ */
+ Ignore
+}
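
The enum above only documents expected behavior; each data source is responsible for honoring it. A minimal sketch (not part of this patch; the method name and return strings are hypothetical) of how an implementation might interpret a mode when the target already exists:

    import org.apache.spark.sql.SaveMode

    def resolveWrite(mode: SaveMode, targetExists: Boolean): String = mode match {
      case SaveMode.ErrorIfExists if targetExists =>
        throw new RuntimeException("target already exists")      // fail fast
      case SaveMode.Ignore if targetExists    => "no-op, keep existing data"
      case SaveMode.Overwrite if targetExists => "replace existing data"
      case SaveMode.Append if targetExists    => "append to existing data"
      case _                                  => "write fresh data"
    }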
diff --git a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java b/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
deleted file mode 100644
index aa441b2096f18..0000000000000
--- a/sql/core/src/main/java/org/apache/spark/sql/jdbc/JDBCUtils.java
+++ /dev/null
@@ -1,59 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc;
-
-import org.apache.spark.Partition;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.DataFrame;
-
-public class JDBCUtils {
- /**
- * Construct a DataFrame representing the JDBC table at the database
- * specified by url with table name table.
- */
- public static DataFrame jdbcRDD(SQLContext sql, String url, String table) {
- Partition[] parts = new Partition[1];
- parts[0] = new JDBCPartition(null, 0);
- return sql.baseRelationToDataFrame(
- new JDBCRelation(url, table, parts, sql));
- }
-
- /**
- * Construct a DataFrame representing the JDBC table at the database
- * specified by url with table name table partitioned by parts.
- * Here, parts is an array of expressions suitable for insertion into a WHERE
- * clause; each one defines one partition.
- */
- public static DataFrame jdbcRDD(SQLContext sql, String url, String table, String[] parts) {
- Partition[] partitions = new Partition[parts.length];
- for (int i = 0; i < parts.length; i++)
- partitions[i] = new JDBCPartition(parts[i], i);
- return sql.baseRelationToDataFrame(
- new JDBCRelation(url, table, partitions, sql));
- }
-
- private static JavaJDBCTrampoline trampoline = new JavaJDBCTrampoline();
-
- public static void createJDBCTable(DataFrame rdd, String url, String table, boolean allowExisting) {
- trampoline.createJDBCTable(rdd, url, table, allowExisting);
- }
-
- public static void insertIntoJDBC(DataFrame rdd, String url, String table, boolean overwrite) {
- trampoline.insertIntoJDBC(rdd, url, table, overwrite);
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
index f1949aa5dd74b..ca4a127120b37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala
@@ -71,11 +71,17 @@ private[sql] class CacheManager(sqlContext: SQLContext) extends Logging {
}
}
+ /** Clears all cached tables. */
private[sql] def clearCache(): Unit = writeLock {
cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist())
cachedData.clear()
}
+ /** Checks if the cache is empty. */
+ private[sql] def isEmpty: Boolean = readLock {
+ cachedData.isEmpty
+ }
+
/**
* Caches the data produced by the logical representation of the given schema rdd. Unlike
* `RDD.cache()`, the default storage level is set to be `MEMORY_AND_DISK` because recomputing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 1011bf0bb5ef4..a2cc9a9b93eb8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -17,91 +17,47 @@
package org.apache.spark.sql
-import scala.annotation.tailrec
import scala.language.implicitConversions
-import org.apache.spark.sql.Dsl.lit
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{Subquery, Project, LogicalPlan}
-import org.apache.spark.sql.catalyst.analysis.UnresolvedGetField
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar, UnresolvedGetField}
import org.apache.spark.sql.types._
private[sql] object Column {
- def apply(colName: String): Column = new IncomputableColumn(colName)
+ def apply(colName: String): Column = new Column(colName)
- def apply(expr: Expression): Column = new IncomputableColumn(expr)
-
- def apply(sqlContext: SQLContext, plan: LogicalPlan, expr: Expression): Column = {
- new ComputableColumn(sqlContext, plan, expr)
- }
+ def apply(expr: Expression): Column = new Column(expr)
def unapply(col: Column): Option[Expression] = Some(col.expr)
}
/**
+ * :: Experimental ::
* A column in a [[DataFrame]].
*
- * `Column` instances can be created by:
- * {{{
- * // 1. Select a column out of a DataFrame
- * df("colName")
- *
- * // 2. Create a literal expression
- * Literal(1)
- *
- * // 3. Create new columns from
- * }}}
- *
+ * @groupname java_expr_ops Java-specific expression operators.
+ * @groupname expr_ops Expression operators.
+ * @groupname df_ops DataFrame functions.
+ * @groupname Ungrouped Support functions for DataFrames.
*/
-// TODO: Improve documentation.
-trait Column extends DataFrame {
-
- protected[sql] def expr: Expression
-
- /**
- * Returns true iff the [[Column]] is computable.
- */
- def isComputable: Boolean
-
- private def computableCol(baseCol: ComputableColumn, expr: Expression) = {
- val plan = Project(Seq(expr match {
- case named: NamedExpression => named
- case unnamed: Expression => Alias(unnamed, "col")()
- }), baseCol.plan)
- Column(baseCol.sqlContext, plan, expr)
- }
-
- private def constructColumn(otherValue: Any)(newExpr: Column => Expression): Column = {
- // Removes all the top level projection and subquery so we can get to the underlying plan.
- @tailrec def stripProject(p: LogicalPlan): LogicalPlan = p match {
- case Project(_, child) => stripProject(child)
- case Subquery(_, child) => stripProject(child)
- case _ => p
- }
+@Experimental
+class Column(protected[sql] val expr: Expression) {
- (this, lit(otherValue)) match {
- case (left: ComputableColumn, right: ComputableColumn) =>
- if (stripProject(left.plan).sameResult(stripProject(right.plan))) {
- computableCol(right, newExpr(right))
- } else {
- Column(newExpr(right))
- }
- case (left: ComputableColumn, right) => computableCol(left, newExpr(right))
- case (_, right: ComputableColumn) => computableCol(right, newExpr(right))
- case (_, right) => Column(newExpr(right))
- }
- }
+ def this(name: String) = this(name match {
+ case "*" => UnresolvedStar(None)
+ case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
+ case _ => UnresolvedAttribute(name)
+ })
/** Creates a column based on the given expression. */
- private def exprToColumn(newExpr: Expression, computable: Boolean = true): Column = {
- this match {
- case c: ComputableColumn if computable => computableCol(c, newExpr)
- case _ => Column(newExpr)
- }
- }
+ implicit private def exprToColumn(newExpr: Expression): Column = new Column(newExpr)
+
+ override def toString: String = expr.prettyString
/**
* Unary minus, i.e. negate the expression.
@@ -110,11 +66,13 @@ trait Column extends DataFrame {
* df.select( -df("amount") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.select( negate(col("amount") );
* }}}
+ *
+ * @group expr_ops
*/
- def unary_- : Column = exprToColumn(UnaryMinus(expr))
+ def unary_- : Column = UnaryMinus(expr)
/**
* Inversion of boolean expression, i.e. NOT.
@@ -123,11 +81,13 @@ trait Column extends DataFrame {
* df.filter( !df("isActive") )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( not(df.col("isActive")) );
* }}
+ *
+ * @group expr_ops
*/
- def unary_! : Column = exprToColumn(Not(expr))
+ def unary_! : Column = Not(expr)
/**
* Equality test.
@@ -136,13 +96,13 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
+ *
+ * @group expr_ops
*/
- def === (other: Any): Column = constructColumn(other) { o =>
- EqualTo(expr, o.expr)
- }
+ def === (other: Any): Column = EqualTo(expr, lit(other).expr)
/**
* Equality test.
@@ -151,9 +111,11 @@ trait Column extends DataFrame {
* df.filter( df("colA") === df("colB") )
*
* // Java
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").equalTo(col("colB")) );
* }}}
+ *
+ * @group expr_ops
*/
def equalTo(other: Any): Column = this === other
@@ -165,13 +127,13 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
+ *
+ * @group expr_ops
*/
- def !== (other: Any): Column = constructColumn(other) { o =>
- Not(EqualTo(expr, o.expr))
- }
+ def !== (other: Any): Column = Not(EqualTo(expr, lit(other).expr))
/**
* Inequality test.
@@ -181,13 +143,13 @@ trait Column extends DataFrame {
* df.select( !(df("colA") === df("colB")) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.filter( col("colA").notEqual(col("colB")) );
* }}}
+ *
+ * @group java_expr_ops
*/
- def notEqual(other: Any): Column = constructColumn(other) { o =>
- Not(EqualTo(expr, o.expr))
- }
+ def notEqual(other: Any): Column = Not(EqualTo(expr, lit(other).expr))
/**
* Greater than.
@@ -196,13 +158,13 @@ trait Column extends DataFrame {
* people.select( people("age") > 21 )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
+ *
+ * @group expr_ops
*/
- def > (other: Any): Column = constructColumn(other) { o =>
- GreaterThan(expr, o.expr)
- }
+ def > (other: Any): Column = GreaterThan(expr, lit(other).expr)
/**
* Greater than.
@@ -211,9 +173,11 @@ trait Column extends DataFrame {
* people.select( people("age") > lit(21) )
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* people.select( people("age").gt(21) );
* }}}
+ *
+ * @group java_expr_ops
*/
def gt(other: Any): Column = this > other
@@ -226,10 +190,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").lt(21) );
* }}}
+ *
+ * @group expr_ops
*/
- def < (other: Any): Column = constructColumn(other) { o =>
- LessThan(expr, o.expr)
- }
+ def < (other: Any): Column = LessThan(expr, lit(other).expr)
/**
* Less than.
@@ -240,6 +204,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").lt(21) );
* }}}
+ *
+ * @group java_expr_ops
*/
def lt(other: Any): Column = this < other
@@ -252,10 +218,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").leq(21) );
* }}}
+ *
+ * @group expr_ops
*/
- def <= (other: Any): Column = constructColumn(other) { o =>
- LessThanOrEqual(expr, o.expr)
- }
+ def <= (other: Any): Column = LessThanOrEqual(expr, lit(other).expr)
/**
* Less than or equal to.
@@ -266,6 +232,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").leq(21) );
* }}}
+ *
+ * @group java_expr_ops
*/
def leq(other: Any): Column = this <= other
@@ -278,10 +246,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").geq(21) )
* }}}
+ *
+ * @group expr_ops
*/
- def >= (other: Any): Column = constructColumn(other) { o =>
- GreaterThanOrEqual(expr, o.expr)
- }
+ def >= (other: Any): Column = GreaterThanOrEqual(expr, lit(other).expr)
/**
* Greater than or equal to an expression.
@@ -292,30 +260,38 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("age").geq(21) )
* }}}
+ *
+ * @group java_expr_ops
*/
def geq(other: Any): Column = this >= other
/**
* Equality test that is safe for null values.
+ *
+ * @group expr_ops
*/
- def <=> (other: Any): Column = constructColumn(other) { o =>
- EqualNullSafe(expr, o.expr)
- }
+ def <=> (other: Any): Column = EqualNullSafe(expr, lit(other).expr)
/**
* Equality test that is safe for null values.
+ *
+ * @group java_expr_ops
*/
def eqNullSafe(other: Any): Column = this <=> other
/**
* True if the current expression is null.
+ *
+ * @group expr_ops
*/
- def isNull: Column = exprToColumn(IsNull(expr))
+ def isNull: Column = IsNull(expr)
/**
* True if the current expression is NOT null.
+ *
+ * @group expr_ops
*/
- def isNotNull: Column = exprToColumn(IsNotNull(expr))
+ def isNotNull: Column = IsNotNull(expr)
/**
* Boolean OR.
@@ -326,10 +302,10 @@ trait Column extends DataFrame {
* // Java:
* people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
+ *
+ * @group expr_ops
*/
- def || (other: Any): Column = constructColumn(other) { o =>
- Or(expr, o.expr)
- }
+ def || (other: Any): Column = Or(expr, lit(other).expr)
/**
* Boolean OR.
@@ -340,6 +316,8 @@ trait Column extends DataFrame {
* // Java:
* people.filter( people("inSchool").or(people("isEmployed")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def or(other: Column): Column = this || other
@@ -352,10 +330,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("inSchool").and(people("isEmployed")) );
* }}}
+ *
+ * @group expr_ops
*/
- def && (other: Any): Column = constructColumn(other) { o =>
- And(expr, o.expr)
- }
+ def && (other: Any): Column = And(expr, lit(other).expr)
/**
* Boolean AND.
@@ -366,6 +344,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("inSchool").and(people("isEmployed")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def and(other: Column): Column = this && other
@@ -378,10 +358,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").plus(people("weight")) );
* }}}
+ *
+ * @group expr_ops
*/
- def + (other: Any): Column = constructColumn(other) { o =>
- Add(expr, o.expr)
- }
+ def + (other: Any): Column = Add(expr, lit(other).expr)
/**
* Sum of this expression and another expression.
@@ -392,6 +372,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").plus(people("weight")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def plus(other: Any): Column = this + other
@@ -404,10 +386,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").minus(people("weight")) );
* }}}
+ *
+ * @group expr_ops
*/
- def - (other: Any): Column = constructColumn(other) { o =>
- Subtract(expr, o.expr)
- }
+ def - (other: Any): Column = Subtract(expr, lit(other).expr)
/**
* Subtraction. Subtract the other expression from this expression.
@@ -418,6 +400,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").minus(people("weight")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def minus(other: Any): Column = this - other
@@ -430,10 +414,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").multiply(people("weight")) );
* }}}
+ *
+ * @group expr_ops
*/
- def * (other: Any): Column = constructColumn(other) { o =>
- Multiply(expr, o.expr)
- }
+ def * (other: Any): Column = Multiply(expr, lit(other).expr)
/**
* Multiplication of this expression and another expression.
@@ -444,6 +428,8 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").multiply(people("weight")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def multiply(other: Any): Column = this * other
@@ -456,10 +442,10 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").divide(people("weight")) );
* }}}
+ *
+ * @group expr_ops
*/
- def / (other: Any): Column = constructColumn(other) { o =>
- Divide(expr, o.expr)
- }
+ def / (other: Any): Column = Divide(expr, lit(other).expr)
/**
* Division this expression by another expression.
@@ -470,74 +456,113 @@ trait Column extends DataFrame {
* // Java:
* people.select( people("height").divide(people("weight")) );
* }}}
+ *
+ * @group java_expr_ops
*/
def divide(other: Any): Column = this / other
/**
* Modulo (a.k.a. remainder) expression.
+ *
+ * @group expr_ops
*/
- def % (other: Any): Column = constructColumn(other) { o =>
- Remainder(expr, o.expr)
- }
+ def % (other: Any): Column = Remainder(expr, lit(other).expr)
/**
* Modulo (a.k.a. remainder) expression.
+ *
+ * @group java_expr_ops
*/
def mod(other: Any): Column = this % other
/**
* A boolean expression that is evaluated to true if the value of this expression is contained
* by the evaluated values of the arguments.
+ *
+ * @group expr_ops
*/
@scala.annotation.varargs
- def in(list: Column*): Column = {
- new IncomputableColumn(In(expr, list.map(_.expr)))
- }
+ def in(list: Column*): Column = In(expr, list.map(_.expr))
- def like(literal: String): Column = exprToColumn(Like(expr, lit(literal).expr))
+ /**
+ * SQL LIKE expression.
+ *
+ * @group expr_ops
+ */
+ def like(literal: String): Column = Like(expr, lit(literal).expr)
- def rlike(literal: String): Column = exprToColumn(RLike(expr, lit(literal).expr))
+ /**
+ * SQL RLIKE expression (LIKE with Regex).
+ *
+ * @group expr_ops
+ */
+ def rlike(literal: String): Column = RLike(expr, lit(literal).expr)
/**
* An expression that gets an item at position `ordinal` out of an array.
+ *
+ * @group expr_ops
*/
- def getItem(ordinal: Int): Column = exprToColumn(GetItem(expr, Literal(ordinal)))
+ def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
/**
* An expression that gets a field by name in a [[StructField]].
+ *
+ * @group expr_ops
*/
- def getField(fieldName: String): Column = exprToColumn(UnresolvedGetField(expr, fieldName))
+ def getField(fieldName: String): Column = UnresolvedGetField(expr, fieldName)
/**
* An expression that returns a substring.
* @param startPos expression for the starting position.
* @param len expression for the length of the substring.
+ *
+ * @group expr_ops
*/
- def substr(startPos: Column, len: Column): Column =
- exprToColumn(Substring(expr, startPos.expr, len.expr), computable = false)
+ def substr(startPos: Column, len: Column): Column = Substring(expr, startPos.expr, len.expr)
/**
* An expression that returns a substring.
* @param startPos starting position.
* @param len length of the substring.
+ *
+ * @group expr_ops
*/
- def substr(startPos: Int, len: Int): Column =
- exprToColumn(Substring(expr, lit(startPos).expr, lit(len).expr))
+ def substr(startPos: Int, len: Int): Column = Substring(expr, lit(startPos).expr, lit(len).expr)
- def contains(other: Any): Column = constructColumn(other) { o =>
- Contains(expr, o.expr)
- }
+ /**
+ * Contains the other element.
+ *
+ * @group expr_ops
+ */
+ def contains(other: Any): Column = Contains(expr, lit(other).expr)
- def startsWith(other: Column): Column = constructColumn(other) { o =>
- StartsWith(expr, o.expr)
- }
+ /**
+ * String starts with.
+ *
+ * @group expr_ops
+ */
+ def startsWith(other: Column): Column = StartsWith(expr, lit(other).expr)
+ /**
+ * String starts with another string literal.
+ *
+ * @group expr_ops
+ */
def startsWith(literal: String): Column = this.startsWith(lit(literal))
- def endsWith(other: Column): Column = constructColumn(other) { o =>
- EndsWith(expr, o.expr)
- }
+ /**
+ * String ends with.
+ *
+ * @group expr_ops
+ */
+ def endsWith(other: Column): Column = EndsWith(expr, lit(other).expr)
+ /**
+ * String ends with another string literal.
+ *
+ * @group expr_ops
+ */
def endsWith(literal: String): Column = this.endsWith(lit(literal))
/**
@@ -546,8 +571,10 @@ trait Column extends DataFrame {
* // Renames colA to colB in select output.
* df.select($"colA".as("colB"))
* }}}
+ *
+ * @group expr_ops
*/
- override def as(alias: String): Column = exprToColumn(Alias(expr, alias)())
+ def as(alias: String): Column = Alias(expr, alias)()
/**
* Gives the column an alias.
@@ -555,8 +582,10 @@ trait Column extends DataFrame {
* // Renames colA to colB in select output.
* df.select($"colA".as('colB))
* }}}
+ *
+ * @group expr_ops
*/
- override def as(alias: Symbol): Column = exprToColumn(Alias(expr, alias.name)())
+ def as(alias: Symbol): Column = Alias(expr, alias.name)()
/**
* Casts the column to a different data type.
@@ -568,8 +597,14 @@ trait Column extends DataFrame {
* // equivalent to
* df.select(df("colA").cast("int"))
* }}}
+ *
+ * @group expr_ops
*/
- def cast(to: DataType): Column = exprToColumn(Cast(expr, to))
+ def cast(to: DataType): Column = expr match {
+ // Lift alias out of cast so we can support col.as("name").cast(IntegerType)
+ case Alias(childExpr, name) => Alias(Cast(childExpr, to), name)()
+ case _ => Cast(expr, to)
+ }
/**
* Casts the column to a different data type, using the canonical string representation
@@ -579,31 +614,73 @@ trait Column extends DataFrame {
* // Casts colA to integer.
* df.select(df("colA").cast("int"))
* }}}
+ *
+ * @group expr_ops
+ */
+ def cast(to: String): Column = cast(to.toLowerCase match {
+ case "string" | "str" => StringType
+ case "boolean" => BooleanType
+ case "byte" => ByteType
+ case "short" => ShortType
+ case "int" => IntegerType
+ case "long" => LongType
+ case "float" => FloatType
+ case "double" => DoubleType
+ case "decimal" => DecimalType.Unlimited
+ case "date" => DateType
+ case "timestamp" => TimestampType
+ case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
+ })
+
+ /**
+ * Returns an ordering used in sorting.
+ * {{{
+ * // Scala: sort a DataFrame by age column in descending order.
+ * df.sort(df("age").desc)
+ *
+ * // Java
+ * df.sort(df.col("age").desc());
+ * }}}
+ *
+ * @group expr_ops
+ */
+ def desc: Column = SortOrder(expr, Descending)
+
+ /**
+ * Returns an ordering used in sorting.
+ * {{{
+ * // Scala: sort a DataFrame by age column in ascending order.
+ * df.sort(df("age").asc)
+ *
+ * // Java
+ * df.sort(df.col("age").asc());
+ * }}}
+ *
+ * @group expr_ops
*/
- def cast(to: String): Column = exprToColumn(
- Cast(expr, to.toLowerCase match {
- case "string" => StringType
- case "boolean" => BooleanType
- case "byte" => ByteType
- case "short" => ShortType
- case "int" => IntegerType
- case "long" => LongType
- case "float" => FloatType
- case "double" => DoubleType
- case "decimal" => DecimalType.Unlimited
- case "date" => DateType
- case "timestamp" => TimestampType
- case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""")
- })
- )
-
- def desc: Column = exprToColumn(SortOrder(expr, Descending), computable = false)
-
- def asc: Column = exprToColumn(SortOrder(expr, Ascending), computable = false)
+ def asc: Column = SortOrder(expr, Ascending)
+
+ /**
+ * Prints the expression to the console for debugging purposes.
+ *
+ * @group df_ops
+ */
+ def explain(extended: Boolean): Unit = {
+ if (extended) {
+ println(expr)
+ } else {
+ println(expr.prettyString)
+ }
+ }
}
-class ColumnName(name: String) extends IncomputableColumn(name) {
+/**
+ * :: Experimental ::
+ * A convenient class used for constructing schemas.
+ */
+@Experimental
+class ColumnName(name: String) extends Column(name) {
/** Creates a new AttributeReference of type boolean */
def boolean: StructField = StructField(name, BooleanType)
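
With Column reduced to a thin wrapper around a Catalyst Expression, every operator above just builds a new expression tree; nothing is tied to a query plan any more. A short sketch pulling together a few of the operators documented in this file (df and its columns are hypothetical):

    // Comparison and boolean operators compose into a single filter expression.
    val adults  = df.filter(df("age") > 21 && df("isActive"))

    // cast() lifts the alias out, per the Alias case above, so the column keeps its name.
    val renamed = df.select(df("age").as("years").cast("int"))

    // desc/asc wrap the expression in a SortOrder for DataFrame.sort.
    val byAge   = df.sort(df("age").desc)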
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 6abfb7853cf1c..060ab5e9a0cfa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -17,20 +17,38 @@
package org.apache.spark.sql
+import java.io.CharArrayWriter
+import java.sql.DriverManager
+
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
+import scala.util.control.NonFatal
+
+import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
+import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
+import org.apache.spark.sql.jdbc.JDBCWriteDetails
+import org.apache.spark.sql.json.JsonRDD
+import org.apache.spark.sql.types.{NumericType, StructType}
+import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
private[sql] object DataFrame {
def apply(sqlContext: SQLContext, logicalPlan: LogicalPlan): DataFrame = {
- new DataFrameImpl(sqlContext, logicalPlan)
+ new DataFrame(sqlContext, logicalPlan)
}
}
@@ -39,19 +57,23 @@ private[sql] object DataFrame {
* :: Experimental ::
* A distributed collection of data organized into named columns.
*
- * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using
- * various functions in [[SQLContext]].
+ * A [[DataFrame]] is equivalent to a relational table in Spark SQL. There are multiple ways
+ * to create a [[DataFrame]]:
* {{{
+ * // Create a DataFrame from Parquet files
* val people = sqlContext.parquetFile("...")
+ *
+ * // Create a DataFrame from data sources
+ * val df =
* }}}
*
* Once created, it can be manipulated using the various domain-specific-language (DSL) functions
- * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL.
+ * defined in: [[DataFrame]] (this class), [[Column]], and [[functions]].
*
- * To select a column from the data frame, use the apply method:
+ * To select a column from the data frame, use `apply` method in Scala and `col` in Java.
* {{{
* val ageCol = people("age") // in Scala
- * Column ageCol = people.apply("age") // in Java
+ * Column ageCol = people.col("age") // in Java
* }}}
*
* Note that the [[Column]] type can also be manipulated through its various functions.
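
Since the data-source example in the scaladoc above is left unfinished, another creation path is worth sketching: converting an RDD of tuples with the toDF method documented later in this file. The implicit conversion that adds toDF is assumed to be in scope, and rdd is a hypothetical RDD[(Int, String)]:

    val people = rdd.toDF("id", "name")   // name the columns explicitly
    people.select(people("name")).show()  // select by Column, then display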
@@ -71,56 +93,196 @@ private[sql] object DataFrame {
* .groupBy(department("name"), "gender")
* .agg(avg(people("salary")), max(people("age")))
* }}}
+ *
+ * @groupname basic Basic DataFrame functions
+ * @groupname dfops Language Integrated Queries
+ * @groupname rdd RDD Operations
+ * @groupname output Output Operations
+ * @groupname action Actions
*/
// TODO: Improve documentation.
@Experimental
-trait DataFrame extends RDDApi[Row] {
+class DataFrame protected[sql](
+ @transient val sqlContext: SQLContext,
+ @DeveloperApi @transient val queryExecution: SQLContext#QueryExecution)
+ extends RDDApi[Row] with Serializable {
+
+ /**
+ * A constructor that automatically analyzes the logical plan.
+ *
+ * This reports errors eagerly as the [[DataFrame]] is constructed, unless
+ * [[SQLConf.dataFrameEagerAnalysis]] is turned off.
+ */
+ def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
+ this(sqlContext, {
+ val qe = sqlContext.executePlan(logicalPlan)
+ if (sqlContext.conf.dataFrameEagerAnalysis) {
+ qe.assertAnalyzed() // This should force analysis and throw errors if there are any
+ }
+ qe
+ })
+ }
+
+ @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
+ // For various commands (like DDL) and queries with side effects, we force query optimization to
+ // happen right away to let these side effects take place eagerly.
+ case _: Command |
+ _: InsertIntoTable |
+ _: CreateTableAsSelect[_] |
+ _: CreateTableUsingAsSelect |
+ _: WriteToFile =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ queryExecution.logical
+ }
+
+ /**
+ * An implicit conversion function internal to this class for us to avoid doing
+ * "new DataFrame(...)" everywhere.
+ */
+ @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
+ new DataFrame(sqlContext, logicalPlan)
+ }
- val sqlContext: SQLContext
+ protected[sql] def resolve(colName: String): NamedExpression = {
+ queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
+ throw new AnalysisException(
+ s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
+ }
+ }
+
+ protected[sql] def numericColumns: Seq[Expression] = {
+ schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
+ queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
+ }
+ }
- @DeveloperApi
- def queryExecution: SQLContext#QueryExecution
+ /**
+ * Internal API for Python
+ * @param numRows Number of rows to show
+ */
+ private[sql] def showString(numRows: Int): String = {
+ val data = take(numRows)
+ val numCols = schema.fieldNames.length
+
+ // For cells longer than 20 characters, replace them with the first 17 characters and "..."
+ val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
+ row.toSeq.map { cell =>
+ val str = if (cell == null) "null" else cell.toString
+ if (str.length > 20) str.substring(0, 17) + "..." else str
+ }: Seq[String]
+ }
+
+ // Compute the width of each column
+ val colWidths = Array.fill(numCols)(0)
+ for (row <- rows) {
+ for ((cell, i) <- row.zipWithIndex) {
+ colWidths(i) = math.max(colWidths(i), cell.length)
+ }
+ }
+
+ // Pad the cells
+ rows.map { row =>
+ row.zipWithIndex.map { case (cell, i) =>
+ String.format(s"%-${colWidths(i)}s", cell)
+ }.mkString(" ")
+ }.mkString("\n")
+ }
- protected[sql] def logicalPlan: LogicalPlan
+ override def toString: String = {
+ try {
+ schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]")
+ } catch {
+ case NonFatal(e) =>
+ s"Invalid tree; ${e.getMessage}:\n$queryExecution"
+ }
+ }
/** Left here for backward compatibility. */
- @deprecated("1.3.0", "use toDataFrame")
+ @deprecated("1.3.0", "use toDF")
def toSchemaRDD: DataFrame = this
/**
- * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala.
+ * Returns the object itself.
+ * @group basic
*/
- def toDataFrame: DataFrame = this
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = this
/**
* Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion
* from a RDD of tuples into a [[DataFrame]] with meaningful names. For example:
* {{{
* val rdd: RDD[(Int, String)] = ...
- * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2
- * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name"
+ * rdd.toDF() // this implicit conversion creates a DataFrame with column name _1 and _2
+ * rdd.toDF("id", "name") // this creates a DataFrame with column name "id" and "name"
* }}}
+ * @group basic
*/
@scala.annotation.varargs
- def toDataFrame(colNames: String*): DataFrame
+ def toDF(colNames: String*): DataFrame = {
+ require(schema.size == colNames.size,
+ "The number of columns doesn't match.\n" +
+ "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+ "New column names: " + colNames.mkString(", "))
+
+ val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
+ apply(oldName).as(newName)
+ }
+ select(newCols :_*)
+ }
- /** Returns the schema of this [[DataFrame]]. */
- def schema: StructType
+ /**
+ * Returns the schema of this [[DataFrame]].
+ * @group basic
+ */
+ def schema: StructType = queryExecution.analyzed.schema
- /** Returns all column names and their data types as an array. */
- def dtypes: Array[(String, String)]
+ /**
+ * Returns all column names and their data types as an array.
+ * @group basic
+ */
+ def dtypes: Array[(String, String)] = schema.fields.map { field =>
+ (field.name, field.dataType.toString)
+ }
- /** Returns all column names as an array. */
+ /**
+ * Returns all column names as an array.
+ * @group basic
+ */
def columns: Array[String] = schema.fields.map(_.name)
- /** Prints the schema to the console in a nice tree format. */
- def printSchema(): Unit
+ /**
+ * Prints the schema to the console in a nice tree format.
+ * @group basic
+ */
+ def printSchema(): Unit = println(schema.treeString)
+
+ /**
+ * Prints the plans (logical and physical) to the console for debugging purposes.
+ * @group basic
+ */
+ def explain(extended: Boolean): Unit = {
+ ExplainCommand(
+ queryExecution.logical,
+ extended = extended).queryExecution.executedPlan.executeCollect().map {
+ r => println(r.getString(0))
+ }
+ }
+
+ /**
+ * Only prints the physical plan to the console for debugging purposes.
+ * @group basic
+ */
+ def explain(): Unit = explain(extended = false)
/**
* Returns true if the `collect` and `take` methods can be run locally
* (without any Spark executors).
+ * @group basic
*/
- def isLocal: Boolean
+ def isLocal: Boolean = logicalPlan.isInstanceOf[LocalRelation]
/**
* Displays the [[DataFrame]] in a tabular form. For example:
@@ -132,8 +294,15 @@ trait DataFrame extends RDDApi[Row] {
* 1983 03 0.410516 0.442194
* 1984 04 0.450090 0.483521
* }}}
+ * @param numRows Number of rows to show
+ * @group basic
+ */
+ def show(numRows: Int): Unit = println(showString(numRows))
+
+ /**
+ * Displays the top 20 rows of the [[DataFrame]] in a tabular form.
*/
- def show(): Unit
+ def show(): Unit = show(20)
/**
* Cartesian join with another [[DataFrame]].
@@ -141,8 +310,11 @@ trait DataFrame extends RDDApi[Row] {
* Note that cartesian joins are very expensive without an extra filter that can be pushed down.
*
* @param right Right side of the join operation.
+ * @group dfops
*/
- def join(right: DataFrame): DataFrame
+ def join(right: DataFrame): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
+ }
/**
* Inner join with another [[DataFrame]], using the given join expression.
@@ -152,8 +324,11 @@ trait DataFrame extends RDDApi[Row] {
* df1.join(df2, $"df1Key" === $"df2Key")
* df1.join(df2).where($"df1Key" === $"df2Key")
* }}}
+ * @group dfops
*/
- def join(right: DataFrame, joinExprs: Column): DataFrame
+ def join(right: DataFrame, joinExprs: Column): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, joinType = Inner, Some(joinExprs.expr))
+ }
/**
* Join with another [[DataFrame]], using the given join expression. The following performs
@@ -161,19 +336,22 @@ trait DataFrame extends RDDApi[Row] {
*
* {{{
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df1.join(df2, "outer", $"df1Key" === $"df2Key")
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df1.join(df2, "outer", col("df1Key") === col("df2Key"));
* }}}
*
* @param right Right side of the join.
* @param joinExprs Join expression.
* @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+ * @group dfops
*/
- def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame
+ def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
+ Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
+ }
/**
* Returns a new [[DataFrame]] sorted by the specified column, all in ascending order.
@@ -183,71 +361,94 @@ trait DataFrame extends RDDApi[Row] {
* df.sort($"sortcol")
* df.sort($"sortcol".asc)
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def sort(sortCol: String, sortCols: String*): DataFrame
+ def sort(sortCol: String, sortCols: String*): DataFrame = {
+ sort((sortCol +: sortCols).map(apply) :_*)
+ }
/**
* Returns a new [[DataFrame]] sorted by the given expressions. For example:
* {{{
* df.sort($"col1", $"col2".desc)
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def sort(sortExprs: Column*): DataFrame
+ def sort(sortExprs: Column*): DataFrame = {
+ val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
+ col.expr match {
+ case expr: SortOrder =>
+ expr
+ case expr: Expression =>
+ SortOrder(expr, Ascending)
+ }
+ }
+ Sort(sortOrder, global = true, logicalPlan)
+ }
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
* This is an alias of the `sort` function.
+ * @group dfops
*/
@scala.annotation.varargs
- def orderBy(sortCol: String, sortCols: String*): DataFrame
+ def orderBy(sortCol: String, sortCols: String*): DataFrame = sort(sortCol, sortCols :_*)
/**
* Returns a new [[DataFrame]] sorted by the given expressions.
* This is an alias of the `sort` function.
+ * @group dfops
*/
@scala.annotation.varargs
- def orderBy(sortExprs: Column*): DataFrame
+ def orderBy(sortExprs: Column*): DataFrame = sort(sortExprs :_*)
/**
* Selects column based on the column name and return it as a [[Column]].
+ * @group dfops
*/
def apply(colName: String): Column = col(colName)
/**
* Selects column based on the column name and return it as a [[Column]].
+ * @group dfops
*/
- def col(colName: String): Column
-
- /**
- * Selects a set of expressions, wrapped in a Product.
- * {{{
- * // The following two are equivalent:
- * df.apply(($"colA", $"colB" + 1))
- * df.select($"colA", $"colB" + 1)
- * }}}
- */
- def apply(projection: Product): DataFrame
+ def col(colName: String): Column = colName match {
+ case "*" =>
+ Column(ResolvedStar(schema.fieldNames.map(resolve)))
+ case _ =>
+ val expr = resolve(colName)
+ Column(expr)
+ }
/**
* Returns a new [[DataFrame]] with an alias set.
+ * @group dfops
*/
- def as(alias: String): DataFrame
+ def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
/**
* (Scala-specific) Returns a new [[DataFrame]] with an alias set.
+ * @group dfops
*/
- def as(alias: Symbol): DataFrame
+ def as(alias: Symbol): DataFrame = as(alias.name)
/**
* Selects a set of expressions.
* {{{
* df.select($"colA", $"colB" + 1)
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def select(cols: Column*): DataFrame
+ def select(cols: Column*): DataFrame = {
+ val namedExpressions = cols.map {
+ case Column(expr: NamedExpression) => expr
+ case Column(expr: Expression) => Alias(expr, expr.prettyString)()
+ }
+ Project(namedExpressions.toSeq, logicalPlan)
+ }
/**
* Selects a set of columns. This is a variant of `select` that can only select
@@ -258,9 +459,10 @@ trait DataFrame extends RDDApi[Row] {
* df.select("colA", "colB")
* df.select($"colA", $"colB")
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def select(col: String, cols: String*): DataFrame
+ def select(col: String, cols: String*): DataFrame = select((col +: cols).map(Column(_)) :_*)
/**
* Selects a set of SQL expressions. This is a variant of `select` that accepts
@@ -269,9 +471,14 @@ trait DataFrame extends RDDApi[Row] {
* {{{
* df.selectExpr("colA", "colB as newName", "abs(colC)")
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def selectExpr(exprs: String*): DataFrame
+ def selectExpr(exprs: String*): DataFrame = {
+ select(exprs.map { expr =>
+ Column(new SqlParser().parseExpression(expr))
+ }: _*)
+ }
/**
* Filters rows using the given condition.
@@ -281,16 +488,20 @@ trait DataFrame extends RDDApi[Row] {
* peopleDf.where($"age" > 15)
* peopleDf($"age" > 15)
* }}}
+ * @group dfops
*/
- def filter(condition: Column): DataFrame
+ def filter(condition: Column): DataFrame = Filter(condition.expr, logicalPlan)
/**
* Filters rows using the given SQL expression.
* {{{
* peopleDf.filter("age > 15")
* }}}
+ * @group dfops
*/
- def filter(conditionExpr: String): DataFrame
+ def filter(conditionExpr: String): DataFrame = {
+ filter(Column(new SqlParser().parseExpression(conditionExpr)))
+ }
/**
* Filters rows using the given condition. This is an alias for `filter`.
@@ -300,19 +511,9 @@ trait DataFrame extends RDDApi[Row] {
* peopleDf.where($"age" > 15)
* peopleDf($"age" > 15)
* }}}
+ * @group dfops
*/
- def where(condition: Column): DataFrame
-
- /**
- * Filters rows using the given condition. This is a shorthand meant for Scala.
- * {{{
- * // The following are equivalent:
- * peopleDf.filter($"age" > 15)
- * peopleDf.where($"age" > 15)
- * peopleDf($"age" > 15)
- * }}}
- */
- def apply(condition: Column): DataFrame
+ def where(condition: Column): DataFrame = filter(condition)
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -328,9 +529,10 @@ trait DataFrame extends RDDApi[Row] {
* "age" -> "max"
* ))
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def groupBy(cols: Column*): GroupedData
+ def groupBy(cols: Column*): GroupedData = new GroupedData(this, cols.map(_.expr))
/**
* Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them.
@@ -349,9 +551,13 @@ trait DataFrame extends RDDApi[Row] {
* "age" -> "max"
* ))
* }}}
+ * @group dfops
*/
@scala.annotation.varargs
- def groupBy(col1: String, cols: String*): GroupedData
+ def groupBy(col1: String, cols: String*): GroupedData = {
+ val colNames: Seq[String] = col1 +: cols
+ new GroupedData(this, colNames.map(colName => resolve(colName)))
+ }
/**
* (Scala-specific) Compute aggregates by specifying a map from column name to
@@ -365,6 +571,7 @@ trait DataFrame extends RDDApi[Row] {
* "expense" -> "sum"
* )
* }}}
+ * @group dfops
*/
def agg(aggExpr: (String, String), aggExprs: (String, String)*): DataFrame = {
groupBy().agg(aggExpr, aggExprs :_*)
@@ -377,6 +584,7 @@ trait DataFrame extends RDDApi[Row] {
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}
+ * @group dfops
*/
def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs)
@@ -387,6 +595,7 @@ trait DataFrame extends RDDApi[Row] {
* df.agg(Map("age" -> "max", "salary" -> "avg"))
* df.groupBy().agg(Map("age" -> "max", "salary" -> "avg"))
* }}
+ * @group dfops
*/
def agg(exprs: java.util.Map[String, String]): DataFrame = groupBy().agg(exprs)
@@ -397,6 +606,7 @@ trait DataFrame extends RDDApi[Row] {
* df.agg(max($"age"), avg($"salary"))
* df.groupBy().agg(max($"age"), avg($"salary"))
* }}
+ * @group dfops
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*)
@@ -404,26 +614,30 @@ trait DataFrame extends RDDApi[Row] {
/**
* Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function
* and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]].
+ * @group dfops
*/
- def limit(n: Int): DataFrame
+ def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan)
/**
* Returns a new [[DataFrame]] containing union of rows in this frame and another frame.
* This is equivalent to `UNION ALL` in SQL.
+ * @group dfops
*/
- def unionAll(other: DataFrame): DataFrame
+ def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] containing rows only in both this frame and another frame.
* This is equivalent to `INTERSECT` in SQL.
+ * @group dfops
*/
- def intersect(other: DataFrame): DataFrame
+ def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] containing rows in this frame but not in another frame.
* This is equivalent to `EXCEPT` in SQL.
+ * @group dfops
*/
- def except(other: DataFrame): DataFrame
+ def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan)
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows.
@@ -431,104 +645,207 @@ trait DataFrame extends RDDApi[Row] {
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
* @param seed Seed for sampling.
+ * @group dfops
*/
- def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame
+ def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
+ Sample(fraction, withReplacement, seed, logicalPlan)
+ }
/**
* Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed.
*
* @param withReplacement Sample with replacement or not.
* @param fraction Fraction of rows to generate.
+ * @group dfops
*/
def sample(withReplacement: Boolean, fraction: Double): DataFrame = {
sample(withReplacement, fraction, Utils.random.nextLong)
}
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] where each row has been expanded to zero or more
+ * rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. The columns of
+ * the input row are implicitly joined with each row that is output by the function.
+ *
+ * The following example uses this function to count the number of books which contain
+ * a given word:
+ *
+ * {{{
+ * case class Book(title: String, words: String)
+ * val df: RDD[Book]
+ *
+ * case class Word(word: String)
+ * val allWords = df.explode('words) {
+ * case Row(words: String) => words.split(" ").map(Word(_))
+ * }
+ *
+ * val bookCountPerWord = allWords.groupBy("word").agg(countDistinct("title"))
+ * }}}
+ * @group dfops
+ */
+ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributes = schema.toAttributes
+ val rowFunction =
+ f.andThen(_.map(ScalaReflection.convertToCatalyst(_, schema).asInstanceOf[Row]))
+ val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
+
+ /**
+ * (Scala-specific) Returns a new [[DataFrame]] where a single column has been expanded to zero
+ * or more rows by the provided function. This is similar to a `LATERAL VIEW` in HiveQL. All
+ * columns of the input row are implicitly joined with each value that is output by the function.
+ *
+ * {{{
+ * df.explode("words", "word") { words: String => words.split(" ") }
+ * }}}
+ * @group dfops
+ */
+ def explode[A, B : TypeTag](inputColumn: String, outputColumn: String)(f: A => TraversableOnce[B])
+ : DataFrame = {
+ val dataType = ScalaReflection.schemaFor[B].dataType
+ val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+ def rowFunction(row: Row) = {
+ f(row(0).asInstanceOf[A]).map(o => Row(ScalaReflection.convertToCatalyst(o, dataType)))
+ }
+ val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+
+ Generate(generator, join = true, outer = false, None, logicalPlan)
+ }
+
/////////////////////////////////////////////////////////////////////////////
/**
* Returns a new [[DataFrame]] by adding a column.
+ * @group dfops
*/
- def addColumn(colName: String, col: Column): DataFrame
+ def withColumn(colName: String, col: Column): DataFrame = select(Column("*"), col.as(colName))
/**
* Returns a new [[DataFrame]] with a column renamed.
+ * @group dfops
*/
- def renameColumn(existingName: String, newName: String): DataFrame
+ def withColumnRenamed(existingName: String, newName: String): DataFrame = {
+ val resolver = sqlContext.analyzer.resolver
+ val colNames = schema.map { field =>
+ val name = field.name
+ if (resolver(name, existingName)) Column(name).as(newName) else Column(name)
+ }
+ select(colNames :_*)
+ }
/**
* Returns the first `n` rows.
*/
- def head(n: Int): Array[Row]
+ def head(n: Int): Array[Row] = limit(n).collect()
/**
* Returns the first row.
*/
- def head(): Row
+ def head(): Row = head(1).head
/**
* Returns the first row. Alias for head().
*/
- override def first(): Row
+ override def first(): Row = head()
/**
* Returns a new RDD by applying a function to all rows of this DataFrame.
+ * @group rdd
*/
- override def map[R: ClassTag](f: Row => R): RDD[R]
+ override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
/**
* Returns a new RDD by first applying a function to all rows of this [[DataFrame]],
* and then flattening the results.
+ * @group rdd
*/
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R]
+ override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
/**
* Returns a new RDD by applying a function to each partition of this DataFrame.
+ * @group rdd
*/
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R]
+ override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
+ rdd.mapPartitions(f)
+ }
+
/**
* Applies a function `f` to all rows.
+ * @group rdd
*/
- override def foreach(f: Row => Unit): Unit
+ override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
/**
* Applies a function f to each partition of this [[DataFrame]].
+ * @group rdd
*/
- override def foreachPartition(f: Iterator[Row] => Unit): Unit
+ override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
/**
* Returns the first `n` rows in the [[DataFrame]].
+ * @group action
*/
- override def take(n: Int): Array[Row]
+ override def take(n: Int): Array[Row] = head(n)
/**
* Returns an array that contains all of [[Row]]s in this [[DataFrame]].
+ * @group action
*/
- override def collect(): Array[Row]
+ override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
/**
* Returns a Java list that contains all of [[Row]]s in this [[DataFrame]].
+ * @group action
*/
- override def collectAsList(): java.util.List[Row]
+ override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
/**
* Returns the number of rows in the [[DataFrame]].
+ * @group action
*/
- override def count(): Long
+ override def count(): Long = groupBy().count().collect().head.getLong(0)
/**
* Returns a new [[DataFrame]] that has exactly `numPartitions` partitions.
+ * @group rdd
*/
- override def repartition(numPartitions: Int): DataFrame
+ override def repartition(numPartitions: Int): DataFrame = {
+ sqlContext.createDataFrame(
+ queryExecution.toRdd.map(_.copy()).repartition(numPartitions), schema)
+ }
- /** Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]]. */
- override def distinct: DataFrame
+ /**
+ * Returns a new [[DataFrame]] that contains only the unique rows from this [[DataFrame]].
+ * @group dfops
+ */
+ override def distinct: DataFrame = Distinct(logicalPlan)
- override def persist(): this.type
+ /**
+ * @group basic
+ */
+ override def persist(): this.type = {
+ sqlContext.cacheManager.cacheQuery(this)
+ this
+ }
- override def persist(newLevel: StorageLevel): this.type
+ /**
+ * @group basic
+ */
+ override def persist(newLevel: StorageLevel): this.type = {
+ sqlContext.cacheManager.cacheQuery(this, None, newLevel)
+ this
+ }
- override def unpersist(blocking: Boolean): this.type
+ /**
+ * @group basic
+ */
+ override def unpersist(blocking: Boolean): this.type = {
+ sqlContext.cacheManager.tryUncacheQuery(this, blocking)
+ this
+ }
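A sketch of the caching calls above, assuming an existing DataFrame `df`.

```scala
import org.apache.spark.storage.StorageLevel

// Cache the query result, reuse it, then release the storage.
df.persist(StorageLevel.MEMORY_AND_DISK)
val total = df.count()
df.unpersist(blocking = false)
```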
/////////////////////////////////////////////////////////////////////////////
// I/O
@@ -536,16 +853,23 @@ trait DataFrame extends RDDApi[Row] {
/**
* Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
+ * @group rdd
*/
- def rdd: RDD[Row]
+ def rdd: RDD[Row] = {
+ // use a local variable to make sure the map closure doesn't capture the whole DataFrame
+ val schema = this.schema
+ queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
+ }
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
+ * @group rdd
*/
def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD()
/**
* Returns the content of the [[DataFrame]] as a [[JavaRDD]] of [[Row]]s.
+ * @group rdd
*/
def javaRDD: JavaRDD[Row] = toJavaRDD
@@ -553,96 +877,326 @@ trait DataFrame extends RDDApi[Row] {
* Registers this RDD as a temporary table using the given name. The lifetime of this temporary
* table is tied to the [[SQLContext]] that was used to create this DataFrame.
*
- * @group schema
+ * @group basic
*/
- def registerTempTable(tableName: String): Unit
+ def registerTempTable(tableName: String): Unit = {
+ sqlContext.registerDataFrameAsTable(this, tableName)
+ }
/**
* Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema.
* Files that are written out using this method can be read back in as a [[DataFrame]]
* using the `parquetFile` function in [[SQLContext]].
+ * @group output
*/
- def saveAsParquetFile(path: String): Unit
+ def saveAsParquetFile(path: String): Unit = {
+ if (sqlContext.conf.parquetUseDataSourceApi) {
+ save("org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> path))
+ } else {
+ sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
+ }
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame. This will fail if the table already
- * exists.
+   * Creates a table from the contents of this DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ * This will fail if the table already exists.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ * @group output
+ */
+ @Experimental
+ def saveAsTable(tableName: String): Unit = {
+ saveAsTable(tableName, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+   * Creates a table from the contents of this DataFrame, using the default data source
+ * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
+ * @group output
*/
@Experimental
- def saveAsTable(tableName: String): Unit
+ def saveAsTable(tableName: String, mode: SaveMode): Unit = {
+ if (sqlContext.catalog.tableExists(Seq(tableName)) && mode == SaveMode.Append) {
+ // If table already exists and the save mode is Append,
+ // we will just call insertInto to append the contents of this DataFrame.
+ insertInto(tableName, overwrite = false)
+ } else {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ saveAsTable(tableName, dataSourceName, mode)
+ }
+ }
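A hedged example of the `saveAsTable` overloads; the table name is hypothetical and, as the docs note, a HiveContext-backed catalog is required.

```scala
import org.apache.spark.sql.SaveMode

// Appends to "events" if it already exists, otherwise creates it with the default data source.
df.saveAsTable("events", SaveMode.Append)
```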
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+   * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source and a set of options,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
+ * @group output
+ */
+ @Experimental
+ def saveAsTable(tableName: String, source: String): Unit = {
+ saveAsTable(tableName, source, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+   * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ * @group output
+ */
+ @Experimental
+ def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = {
+ saveAsTable(tableName, source, mode, Map.empty[String, String])
+ }
+
+ /**
+ * :: Experimental ::
+   * Creates a table at the given path from the contents of this DataFrame
+ * based on a given data source, [[SaveMode]] specified by mode, and a set of options.
+ *
+ * Note that this currently only works with DataFrames that are created from a HiveContext as
+ * there is no notion of a persisted catalog in a standard SQL context. Instead you can write
+ * an RDD out to a parquet file, and then register that file as a table. This "table" can then
+ * be the target of an `insertInto`.
+ * @group output
*/
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ saveAsTable(tableName, source, mode, options.toMap)
+ }
/**
* :: Experimental ::
- * Creates a table from the the contents of this DataFrame based on a given data source and
- * a set of options. This will fail if the table already exists.
+ * (Scala-specific)
+   * Creates a table from the contents of this DataFrame based on a given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
*
* Note that this currently only works with DataFrames that are created from a HiveContext as
* there is no notion of a persisted catalog in a standard SQL context. Instead you can write
* an RDD out to a parquet file, and then register that file as a table. This "table" can then
* be the target of an `insertInto`.
+ * @group output
*/
@Experimental
def saveAsTable(
tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
+ val cmd =
+ CreateTableUsingAsSelect(
+ tableName,
+ source,
+ temporary = false,
+ mode,
+ options,
+ logicalPlan)
+
+ sqlContext.executePlan(cmd).toRdd
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path,
+ * using the default data source configured by spark.sql.sources.default and
+ * [[SaveMode.ErrorIfExists]] as the save mode.
+ * @group output
+ */
+ @Experimental
+ def save(path: String): Unit = {
+ save(path, SaveMode.ErrorIfExists)
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode,
+ * using the default data source configured by spark.sql.sources.default.
+ * @group output
+ */
@Experimental
- def save(path: String): Unit
+ def save(path: String, mode: SaveMode): Unit = {
+ val dataSourceName = sqlContext.conf.defaultDataSourceName
+ save(path, dataSourceName, mode)
+ }
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source,
+ * using [[SaveMode.ErrorIfExists]] as the save mode.
+ * @group output
+ */
+ @Experimental
+ def save(path: String, source: String): Unit = {
+ save(source, SaveMode.ErrorIfExists, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame to the given path based on the given data source and
+ * [[SaveMode]] specified by mode.
+ * @group output
+ */
+ @Experimental
+ def save(path: String, source: String, mode: SaveMode): Unit = {
+ save(source, mode, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Saves the contents of this DataFrame based on the given data source,
+ * [[SaveMode]] specified by mode, and a set of options.
+ * @group output
+ */
@Experimental
def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit
+ source: String,
+ mode: SaveMode,
+ options: java.util.Map[String, String]): Unit = {
+ save(source, mode, options.toMap)
+ }
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Saves the contents of this DataFrame based on the given data source,
+   * [[SaveMode]] specified by mode, and a set of options.
+ * @group output
+ */
@Experimental
def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit
+ source: String,
+ mode: SaveMode,
+ options: Map[String, String]): Unit = {
+ ResolvedDataSource(sqlContext, source, mode, options, this)
+ }
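A sketch of the `save` variants above; the paths are placeholders and the parquet source name matches the built-in default mentioned earlier.

```scala
import org.apache.spark.sql.SaveMode

// Default data source, SaveMode.ErrorIfExists.
df.save("/tmp/events")

// Explicit source, mode, and options.
df.save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> "/tmp/events_parquet"))
```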
/**
* :: Experimental ::
* Adds the rows from this RDD to the specified table, optionally overwriting the existing data.
+ * @group output
*/
@Experimental
- def insertInto(tableName: String, overwrite: Boolean): Unit
+ def insertInto(tableName: String, overwrite: Boolean): Unit = {
+ sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
+ Map.empty, logicalPlan, overwrite)).toRdd
+ }
/**
* :: Experimental ::
* Adds the rows from this RDD to the specified table.
* Throws an exception if the table already exists.
+ * @group output
*/
@Experimental
def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false)
/**
* Returns the content of the [[DataFrame]] as a RDD of JSON strings.
+ * @group rdd
*/
- def toJSON: RDD[String]
+ def toJSON: RDD[String] = {
+ val rowSchema = this.schema
+ this.mapPartitions { iter =>
+ val writer = new CharArrayWriter()
+ // create the Generator without separator inserted between 2 records
+ val gen = new JsonFactory().createGenerator(writer).setRootValueSeparator(null)
+
+ new Iterator[String] {
+ override def hasNext = iter.hasNext
+ override def next(): String = {
+ JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
+ gen.flush()
+
+ val json = writer.toString
+ if (hasNext) {
+ writer.reset()
+ } else {
+ gen.close()
+ }
+
+ json
+ }
+ }
+ }
+ }
+
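Usage sketch for `toJSON`, assuming a DataFrame `df`.

```scala
// Each row is serialized to one JSON document; print a small sample.
df.toJSON.take(3).foreach(println)
```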
+ ////////////////////////////////////////////////////////////////////////////
+ // JDBC Write Support
+ ////////////////////////////////////////////////////////////////////////////
+
+ /**
+   * Save this DataFrame to a JDBC database at `url` under the table name `table`.
+   * This will run a `CREATE TABLE` and a series of `INSERT INTO` statements.
+   * If you pass `true` for `allowExisting`, it will first drop any existing table
+   * with the given name; if you pass `false`, the call will fail if the table
+   * already exists.
+ * @group output
+ */
+ def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = {
+ val conn = DriverManager.getConnection(url)
+ try {
+ if (allowExisting) {
+ val sql = s"DROP TABLE IF EXISTS $table"
+ conn.prepareStatement(sql).executeUpdate()
+ }
+ val schema = JDBCWriteDetails.schemaString(this, url)
+ val sql = s"CREATE TABLE $table ($schema)"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ JDBCWriteDetails.saveTable(this, url, table)
+ }
+
+ /**
+   * Save this DataFrame to a JDBC database at `url` under the table name `table`.
+ * Assumes the table already exists and has a compatible schema. If you
+ * pass `true` for `overwrite`, it will `TRUNCATE` the table before
+ * performing the `INSERT`s.
+ *
+   * The table must already exist in the database. It must have a schema
+   * that is compatible with the schema of this DataFrame; inserting the rows of
+   * the DataFrame in order via the simple statement
+   * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
+ * @group output
+ */
+ def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = {
+ if (overwrite) {
+ val conn = DriverManager.getConnection(url)
+ try {
+ val sql = s"TRUNCATE TABLE $table"
+ conn.prepareStatement(sql).executeUpdate()
+ } finally {
+ conn.close()
+ }
+ }
+ JDBCWriteDetails.saveTable(this, url, table)
+ }
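A sketch of the two JDBC write helpers; the connection URL and table name are hypothetical and the JDBC driver must be on the classpath.

```scala
val url = "jdbc:postgresql://localhost/test?user=fred&password=secret"  // placeholder

df.createJDBCTable(url, "people_copy", allowExisting = false)  // CREATE TABLE + INSERTs
df.insertIntoJDBC(url, "people_copy", overwrite = true)        // TRUNCATE, then INSERTs
```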
////////////////////////////////////////////////////////////////////////////
// for Python API
@@ -651,5 +1205,9 @@ trait DataFrame extends RDDApi[Row] {
/**
* Converts a JavaRDD to a PythonRDD.
*/
- protected[sql] def javaToPython: JavaRDD[Array[Byte]]
+ protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
+ val fieldTypes = schema.fields.map(_.dataType)
+ val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ SerDeUtil.javaToPython(jrdd)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
similarity index 65%
rename from sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
index ac479b26a7c6a..a3187fe3230fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ComputableColumn.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameHolder.scala
@@ -17,17 +17,14 @@
package org.apache.spark.sql
-import scala.language.implicitConversions
+/**
+ * A container for a [[DataFrame]], used for implicit conversions.
+ */
+private[sql] case class DataFrameHolder(df: DataFrame) {
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+ // This is declared with parentheses to prevent the Scala compiler from treating
+ // `rdd.toDF("1")` as invoking this toDF and then apply on the returned DataFrame.
+ def toDF(): DataFrame = df
-
-private[sql] class ComputableColumn protected[sql](
- sqlContext: SQLContext,
- protected[sql] val plan: LogicalPlan,
- protected[sql] val expr: Expression)
- extends DataFrameImpl(sqlContext, plan) with Column {
-
- override def isComputable: Boolean = true
+ def toDF(colNames: String*): DataFrame = df.toDF(colNames :_*)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
deleted file mode 100644
index 73393295ab0a5..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameImpl.scala
+++ /dev/null
@@ -1,431 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import scala.language.implicitConversions
-import scala.reflect.ClassTag
-import scala.collection.JavaConversions._
-
-import com.fasterxml.jackson.core.JsonFactory
-
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.api.python.SerDeUtil
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython}
-import org.apache.spark.sql.json.JsonRDD
-import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsLogicalPlan}
-import org.apache.spark.sql.types.{NumericType, StructType}
-
-
-/**
- * Internal implementation of [[DataFrame]]. Users of the API should use [[DataFrame]] directly.
- */
-private[sql] class DataFrameImpl protected[sql](
- override val sqlContext: SQLContext,
- val queryExecution: SQLContext#QueryExecution)
- extends DataFrame {
-
- /**
- * A constructor that automatically analyzes the logical plan. This reports error eagerly
- * as the [[DataFrame]] is constructed.
- */
- def this(sqlContext: SQLContext, logicalPlan: LogicalPlan) = {
- this(sqlContext, {
- val qe = sqlContext.executePlan(logicalPlan)
- if (sqlContext.conf.dataFrameEagerAnalysis) {
- qe.analyzed // This should force analysis and throw errors if there are any
- }
- qe
- })
- }
-
- @transient protected[sql] override val logicalPlan: LogicalPlan = queryExecution.logical match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- queryExecution.logical
- }
-
- /**
- * An implicit conversion function internal to this class for us to avoid doing
- * "new DataFrameImpl(...)" everywhere.
- */
- @inline private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = {
- new DataFrameImpl(sqlContext, logicalPlan)
- }
-
- protected[sql] def resolve(colName: String): NamedExpression = {
- queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse {
- throw new RuntimeException(
- s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""")
- }
- }
-
- protected[sql] def numericColumns: Seq[Expression] = {
- schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n =>
- queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get
- }
- }
-
- override def toDataFrame(colNames: String*): DataFrame = {
- require(schema.size == colNames.size,
- "The number of columns doesn't match.\n" +
- "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
- "New column names: " + colNames.mkString(", "))
-
- val newCols = schema.fieldNames.zip(colNames).map { case (oldName, newName) =>
- apply(oldName).as(newName)
- }
- select(newCols :_*)
- }
-
- override def schema: StructType = queryExecution.analyzed.schema
-
- override def dtypes: Array[(String, String)] = schema.fields.map { field =>
- (field.name, field.dataType.toString)
- }
-
- override def columns: Array[String] = schema.fields.map(_.name)
-
- override def printSchema(): Unit = println(schema.treeString)
-
- override def isLocal: Boolean = {
- logicalPlan.isInstanceOf[LocalRelation]
- }
-
- override def show(): Unit = {
- val data = take(20)
- val numCols = schema.fieldNames.length
-
- // For cells that are beyond 20 characters, replace it with the first 17 and "..."
- val rows: Seq[Seq[String]] = schema.fieldNames.toSeq +: data.map { row =>
- row.toSeq.map { cell =>
- val str = if (cell == null) "null" else cell.toString
- if (str.length > 20) str.substring(0, 17) + "..." else str
- } : Seq[String]
- }
-
- // Compute the width of each column
- val colWidths = Array.fill(numCols)(0)
- for (row <- rows) {
- for ((cell, i) <- row.zipWithIndex) {
- colWidths(i) = math.max(colWidths(i), cell.length)
- }
- }
-
- // Pad the cells and print them
- println(rows.map { row =>
- row.zipWithIndex.map { case (cell, i) =>
- String.format(s"%-${colWidths(i)}s", cell)
- }.mkString(" ")
- }.mkString("\n"))
- }
-
- override def join(right: DataFrame): DataFrame = {
- Join(logicalPlan, right.logicalPlan, joinType = Inner, None)
- }
-
- override def join(right: DataFrame, joinExprs: Column): DataFrame = {
- Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr))
- }
-
- override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = {
- Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr))
- }
-
- override def sort(sortCol: String, sortCols: String*): DataFrame = {
- sort((sortCol +: sortCols).map(apply) :_*)
- }
-
- override def sort(sortExprs: Column*): DataFrame = {
- val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
- col.expr match {
- case expr: SortOrder =>
- expr
- case expr: Expression =>
- SortOrder(expr, Ascending)
- }
- }
- Sort(sortOrder, global = true, logicalPlan)
- }
-
- override def orderBy(sortCol: String, sortCols: String*): DataFrame = {
- sort(sortCol, sortCols :_*)
- }
-
- override def orderBy(sortExprs: Column*): DataFrame = {
- sort(sortExprs :_*)
- }
-
- override def col(colName: String): Column = colName match {
- case "*" =>
- Column(ResolvedStar(schema.fieldNames.map(resolve)))
- case _ =>
- val expr = resolve(colName)
- Column(sqlContext, Project(Seq(expr), logicalPlan), expr)
- }
-
- override def apply(projection: Product): DataFrame = {
- require(projection.productArity >= 1)
- select(projection.productIterator.map {
- case c: Column => c
- case o: Any => Column(Literal(o))
- }.toSeq :_*)
- }
-
- override def as(alias: String): DataFrame = Subquery(alias, logicalPlan)
-
- override def as(alias: Symbol): DataFrame = Subquery(alias.name, logicalPlan)
-
- override def select(cols: Column*): DataFrame = {
- val exprs = cols.zipWithIndex.map {
- case (Column(expr: NamedExpression), _) =>
- expr
- case (Column(expr: Expression), _) =>
- Alias(expr, expr.toString)()
- }
- Project(exprs.toSeq, logicalPlan)
- }
-
- override def select(col: String, cols: String*): DataFrame = {
- select((col +: cols).map(Column(_)) :_*)
- }
-
- override def selectExpr(exprs: String*): DataFrame = {
- select(exprs.map { expr =>
- Column(new SqlParser().parseExpression(expr))
- }: _*)
- }
-
- override def addColumn(colName: String, col: Column): DataFrame = {
- select(Column("*"), col.as(colName))
- }
-
- override def renameColumn(existingName: String, newName: String): DataFrame = {
- val colNames = schema.map { field =>
- val name = field.name
- if (name == existingName) Column(name).as(newName) else Column(name)
- }
- select(colNames :_*)
- }
-
- override def filter(condition: Column): DataFrame = {
- Filter(condition.expr, logicalPlan)
- }
-
- override def filter(conditionExpr: String): DataFrame = {
- filter(Column(new SqlParser().parseExpression(conditionExpr)))
- }
-
- override def where(condition: Column): DataFrame = {
- filter(condition)
- }
-
- override def apply(condition: Column): DataFrame = {
- filter(condition)
- }
-
- override def groupBy(cols: Column*): GroupedData = {
- new GroupedData(this, cols.map(_.expr))
- }
-
- override def groupBy(col1: String, cols: String*): GroupedData = {
- val colNames: Seq[String] = col1 +: cols
- new GroupedData(this, colNames.map(colName => resolve(colName)))
- }
-
- override def limit(n: Int): DataFrame = {
- Limit(Literal(n), logicalPlan)
- }
-
- override def unionAll(other: DataFrame): DataFrame = {
- Union(logicalPlan, other.logicalPlan)
- }
-
- override def intersect(other: DataFrame): DataFrame = {
- Intersect(logicalPlan, other.logicalPlan)
- }
-
- override def except(other: DataFrame): DataFrame = {
- Except(logicalPlan, other.logicalPlan)
- }
-
- override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = {
- Sample(fraction, withReplacement, seed, logicalPlan)
- }
-
- /////////////////////////////////////////////////////////////////////////////
- // RDD API
- /////////////////////////////////////////////////////////////////////////////
-
- override def head(n: Int): Array[Row] = limit(n).collect()
-
- override def head(): Row = head(1).head
-
- override def first(): Row = head()
-
- override def map[R: ClassTag](f: Row => R): RDD[R] = rdd.map(f)
-
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f)
-
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = {
- rdd.mapPartitions(f)
- }
-
- override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
-
- override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f)
-
- override def take(n: Int): Array[Row] = head(n)
-
- override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect()
-
- override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*)
-
- override def count(): Long = groupBy().count().rdd.collect().head.getLong(0)
-
- override def repartition(numPartitions: Int): DataFrame = {
- sqlContext.applySchema(rdd.repartition(numPartitions), schema)
- }
-
- override def distinct: DataFrame = Distinct(logicalPlan)
-
- override def persist(): this.type = {
- sqlContext.cacheManager.cacheQuery(this)
- this
- }
-
- override def persist(newLevel: StorageLevel): this.type = {
- sqlContext.cacheManager.cacheQuery(this, None, newLevel)
- this
- }
-
- override def unpersist(blocking: Boolean): this.type = {
- sqlContext.cacheManager.tryUncacheQuery(this, blocking)
- this
- }
-
- /////////////////////////////////////////////////////////////////////////////
- // I/O
- /////////////////////////////////////////////////////////////////////////////
-
- override def rdd: RDD[Row] = {
- val schema = this.schema
- queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema))
- }
-
- override def registerTempTable(tableName: String): Unit = {
- sqlContext.registerRDDAsTable(this, tableName)
- }
-
- override def saveAsParquetFile(path: String): Unit = {
- if (sqlContext.conf.parquetUseDataSourceApi) {
- save("org.apache.spark.sql.parquet", "path" -> path)
- } else {
- sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd
- }
- }
-
- override def saveAsTable(tableName: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- val cmd =
- CreateTableUsingAsLogicalPlan(
- tableName,
- dataSourceName,
- temporary = false,
- Map.empty,
- allowExisting = false,
- logicalPlan)
-
- sqlContext.executePlan(cmd).toRdd
- }
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsingAsLogicalPlan(
- tableName,
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting = false,
- logicalPlan)
-
- sqlContext.executePlan(cmd).toRdd
- }
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- saveAsTable(tableName, dataSourceName, opts.head, opts.tail:_*)
- }
-
- override def save(path: String): Unit = {
- val dataSourceName = sqlContext.conf.defaultDataSourceName
- save(dataSourceName, "path" -> path)
- }
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = {
- ResolvedDataSource(sqlContext, dataSourceName, (option +: options).toMap, this)
- }
-
- override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- save(dataSourceName, opts.head, opts.tail:_*)
- }
-
- override def insertInto(tableName: String, overwrite: Boolean): Unit = {
- sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)),
- Map.empty, logicalPlan, overwrite)).toRdd
- }
-
- override def toJSON: RDD[String] = {
- val rowSchema = this.schema
- this.mapPartitions { iter =>
- val jsonFactory = new JsonFactory()
- iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory))
- }
- }
-
- ////////////////////////////////////////////////////////////////////////////
- // for Python API
- ////////////////////////////////////////////////////////////////////////////
- protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
- SerDeUtil.javaToPython(jrdd)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
index f0e6a8f332188..d5d7e35a6b35d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala
@@ -20,8 +20,13 @@ package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
/**
+ * :: Experimental ::
* Holder for experimental methods for the bravest. We make NO guarantee about the stability
* regarding binary compatibility and source compatibility of methods here.
+ *
+ * {{{
+ * sqlContext.experimental.extraStrategies += ...
+ * }}}
*/
@Experimental
class ExperimentalMethods protected[sql](sqlContext: SQLContext) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 3c20676355c9d..d00175265924c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -17,20 +17,24 @@
package org.apache.spark.sql
-import scala.language.implicitConversions
import scala.collection.JavaConversions._
+import scala.language.implicitConversions
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
+import org.apache.spark.sql.types.NumericType
/**
+ * :: Experimental ::
* A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]].
*/
-class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expression]) {
+@Experimental
+class GroupedData protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) {
- private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = {
+ private[this] implicit def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
val namedGroupingExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.toString)()
@@ -39,8 +43,25 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
df.sqlContext, Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan))
}
- private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = {
- df.numericColumns.map { c =>
+ private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => Expression)
+ : Seq[NamedExpression] = {
+
+ val columnExprs = if (colNames.isEmpty) {
+ // No columns specified. Use all numeric columns.
+ df.numericColumns
+ } else {
+ // Make sure all specified columns are numeric.
+ colNames.map { colName =>
+ val namedExpr = df.resolve(colName)
+ if (!namedExpr.dataType.isInstanceOf[NumericType]) {
+ throw new AnalysisException(
+ s""""$colName" is not a numeric column. """ +
+ "Aggregation function can only be applied on a numeric column.")
+ }
+ namedExpr
+ }
+ }
+ columnExprs.map { c =>
val a = f(c)
Alias(a, a.toString)()
}
@@ -52,7 +73,12 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
case "max" => Max
case "min" => Min
case "sum" => Sum
- case "count" | "size" => Count
+ case "count" | "size" =>
+ // Turn count(*) into count(1)
+ (inputExpr: Expression) => inputExpr match {
+ case s: Star => Count(Literal(1))
+ case _ => Count(inputExpr)
+ }
}
}
@@ -115,17 +141,17 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this
* class, the resulting [[DataFrame]] won't automatically include the grouping columns.
*
- * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]].
+ * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
*
* {{{
* // Selects the age of the oldest employee and the aggregate expense for each department
*
* // Scala:
- * import org.apache.spark.sql.dsl._
+ * import org.apache.spark.sql.functions._
* df.groupBy("department").agg($"department", max($"age"), sum($"expense"))
*
* // Java:
- * import static org.apache.spark.sql.Dsl.*;
+ * import static org.apache.spark.sql.functions.*;
* df.groupBy("department").agg(col("department"), max(col("age")), sum(col("expense")));
* }}}
*/
@@ -142,35 +168,55 @@ class GroupedData protected[sql](df: DataFrameImpl, groupingExprs: Seq[Expressio
* Count the number of rows for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
*/
- def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")())
+ def count(): DataFrame = Seq(Alias(Count(Literal(1)), "count")())
/**
* Compute the average value for each numeric columns for each group. This is an alias for `avg`.
* The resulting [[DataFrame]] will also contain the grouping columns.
+   * When columns are specified, only the average values for those columns are computed.
*/
- def mean(): DataFrame = aggregateNumericColumns(Average)
-
+ @scala.annotation.varargs
+ def mean(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
+
/**
* Compute the max value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+   * When columns are specified, only the max values for those columns are computed.
*/
- def max(): DataFrame = aggregateNumericColumns(Max)
+ @scala.annotation.varargs
+ def max(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Max)
+ }
/**
* Compute the mean value for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+   * When columns are specified, only the mean values for those columns are computed.
*/
- def avg(): DataFrame = aggregateNumericColumns(Average)
+ @scala.annotation.varargs
+ def avg(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Average)
+ }
/**
* Compute the min value for each numeric column for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+   * When columns are specified, only the min values for those columns are computed.
*/
- def min(): DataFrame = aggregateNumericColumns(Min)
+ @scala.annotation.varargs
+ def min(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Min)
+ }
/**
* Compute the sum for each numeric columns for each group.
* The resulting [[DataFrame]] will also contain the grouping columns.
+   * When columns are specified, only the sums for those columns are computed.
*/
- def sum(): DataFrame = aggregateNumericColumns(Sum)
+ @scala.annotation.varargs
+ def sum(colNames: String*): DataFrame = {
+ aggregateNumericColumns(colNames:_*)(Sum)
+ }
}
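A short sketch of the per-column aggregations introduced above; `df` and its column names are assumptions.

```scala
import org.apache.spark.sql.functions._

df.groupBy("department").count()                    // count per group
df.groupBy("department").mean("expense", "bonus")   // only the named numeric columns
df.groupBy("department").agg(max(col("age")), sum(col("expense")))
```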
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala b/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
deleted file mode 100644
index 0600dcc226b4d..0000000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/IncomputableColumn.scala
+++ /dev/null
@@ -1,186 +0,0 @@
-/*
-* Licensed to the Apache Software Foundation (ASF) under one or more
-* contributor license agreements. See the NOTICE file distributed with
-* this work for additional information regarding copyright ownership.
-* The ASF licenses this file to You under the Apache License, Version 2.0
-* (the "License"); you may not use this file except in compliance with
-* the License. You may obtain a copy of the License at
-*
-* http://www.apache.org/licenses/LICENSE-2.0
-*
-* Unless required by applicable law or agreed to in writing, software
-* distributed under the License is distributed on an "AS IS" BASIS,
-* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-* See the License for the specific language governing permissions and
-* limitations under the License.
-*/
-
-package org.apache.spark.sql
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedStar}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.storage.StorageLevel
-import org.apache.spark.sql.types.StructType
-
-
-private[sql] class IncomputableColumn(protected[sql] val expr: Expression) extends Column {
-
- def this(name: String) = this(name match {
- case "*" => UnresolvedStar(None)
- case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2)))
- case _ => UnresolvedAttribute(name)
- })
-
- private def err[T](): T = {
- throw new UnsupportedOperationException("Cannot run this method on an UncomputableColumn")
- }
-
- override def isComputable: Boolean = false
-
- override val sqlContext: SQLContext = null
-
- override def queryExecution = err()
-
- protected[sql] override def logicalPlan: LogicalPlan = err()
-
- override def toDataFrame(colNames: String*): DataFrame = err()
-
- override def schema: StructType = err()
-
- override def dtypes: Array[(String, String)] = err()
-
- override def columns: Array[String] = err()
-
- override def printSchema(): Unit = err()
-
- override def show(): Unit = err()
-
- override def isLocal: Boolean = false
-
- override def join(right: DataFrame): DataFrame = err()
-
- override def join(right: DataFrame, joinExprs: Column): DataFrame = err()
-
- override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = err()
-
- override def sort(sortCol: String, sortCols: String*): DataFrame = err()
-
- override def sort(sortExprs: Column*): DataFrame = err()
-
- override def orderBy(sortCol: String, sortCols: String*): DataFrame = err()
-
- override def orderBy(sortExprs: Column*): DataFrame = err()
-
- override def col(colName: String): Column = err()
-
- override def apply(projection: Product): DataFrame = err()
-
- override def select(cols: Column*): DataFrame = err()
-
- override def select(col: String, cols: String*): DataFrame = err()
-
- override def selectExpr(exprs: String*): DataFrame = err()
-
- override def addColumn(colName: String, col: Column): DataFrame = err()
-
- override def renameColumn(existingName: String, newName: String): DataFrame = err()
-
- override def filter(condition: Column): DataFrame = err()
-
- override def filter(conditionExpr: String): DataFrame = err()
-
- override def where(condition: Column): DataFrame = err()
-
- override def apply(condition: Column): DataFrame = err()
-
- override def groupBy(cols: Column*): GroupedData = err()
-
- override def groupBy(col1: String, cols: String*): GroupedData = err()
-
- override def limit(n: Int): DataFrame = err()
-
- override def unionAll(other: DataFrame): DataFrame = err()
-
- override def intersect(other: DataFrame): DataFrame = err()
-
- override def except(other: DataFrame): DataFrame = err()
-
- override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = err()
-
- /////////////////////////////////////////////////////////////////////////////
-
- override def head(n: Int): Array[Row] = err()
-
- override def head(): Row = err()
-
- override def first(): Row = err()
-
- override def map[R: ClassTag](f: Row => R): RDD[R] = err()
-
- override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = err()
-
- override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = err()
-
- override def foreach(f: Row => Unit): Unit = err()
-
- override def foreachPartition(f: Iterator[Row] => Unit): Unit = err()
-
- override def take(n: Int): Array[Row] = err()
-
- override def collect(): Array[Row] = err()
-
- override def collectAsList(): java.util.List[Row] = err()
-
- override def count(): Long = err()
-
- override def repartition(numPartitions: Int): DataFrame = err()
-
- override def distinct: DataFrame = err()
-
- override def persist(): this.type = err()
-
- override def persist(newLevel: StorageLevel): this.type = err()
-
- override def unpersist(blocking: Boolean): this.type = err()
-
- override def rdd: RDD[Row] = err()
-
- override def registerTempTable(tableName: String): Unit = err()
-
- override def saveAsParquetFile(path: String): Unit = err()
-
- override def saveAsTable(tableName: String): Unit = err()
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
-
- override def saveAsTable(
- tableName: String,
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
-
- override def save(path: String): Unit = err()
-
- override def save(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): Unit = err()
-
- override def save(
- dataSourceName: String,
- options: java.util.Map[String, String]): Unit = err()
-
- override def insertInto(tableName: String, overwrite: Boolean): Unit = err()
-
- override def toJSON: RDD[String] = err()
-
- protected[sql] override def javaToPython: JavaRDD[Array[Byte]] = err()
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index 180f5e765fb91..4815620c6fe57 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -50,10 +50,16 @@ private[spark] object SQLConf {
val THRIFTSERVER_POOL = "spark.sql.thriftserver.scheduler.pool"
// This is used to set the default data source
- val DEFAULT_DATA_SOURCE_NAME = "spark.sql.default.datasource"
+ val DEFAULT_DATA_SOURCE_NAME = "spark.sql.sources.default"
+  // This is used to control when we split a schema's JSON string into multiple pieces
+  // in order to fit the JSON string in the metastore's table property (by default, the value
+  // has a length restriction of 4000 characters). We will split the JSON string of a schema
+  // if its length exceeds the threshold.
+ val SCHEMA_STRING_LENGTH_THRESHOLD = "spark.sql.sources.schemaStringLengthThreshold"
- // Whether to perform eager analysis on a DataFrame.
- val DATAFRAME_EAGER_ANALYSIS = "spark.sql.dataframe.eagerAnalysis"
+ // Whether to perform eager analysis when constructing a dataframe.
+ // Set to false when debugging requires the ability to look at invalid query plans.
+ val DATAFRAME_EAGER_ANALYSIS = "spark.sql.eagerAnalysis"
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -176,6 +182,11 @@ private[sql] class SQLConf extends Serializable {
private[spark] def defaultDataSourceName: String =
getConf(DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.parquet")
+ // Do not use a value larger than 4000 as the default value of this property.
+ // See the comments of SCHEMA_STRING_LENGTH_THRESHOLD above for more information.
+ private[spark] def schemaStringLengthThreshold: Int =
+ getConf(SCHEMA_STRING_LENGTH_THRESHOLD, "4000").toInt
+
private[spark] def dataFrameEagerAnalysis: Boolean =
getConf(DATAFRAME_EAGER_ANALYSIS, "true").toBoolean
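The renamed keys can be set at runtime through `setConf`; the values below are illustrative.

```scala
sqlContext.setConf("spark.sql.sources.default", "org.apache.spark.sql.json")
sqlContext.setConf("spark.sql.eagerAnalysis", "false")  // keep invalid plans inspectable while debugging
```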
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 97e3777f933e4..ce800e0754559 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,32 +25,36 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.annotation.{AlphaComponent, DeveloperApi, Experimental}
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, NoRelation}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.{ScalaReflection, expressions}
+import org.apache.spark.sql.execution.{Filter, _}
import org.apache.spark.sql.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation}
import org.apache.spark.sql.json._
-import org.apache.spark.sql.sources.{BaseRelation, DDLParser, DataSourceStrategy, LogicalRelation, _}
+import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.{Partition, SparkContext}
/**
- * :: AlphaComponent ::
- * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]]
- * objects and the execution of SQL queries.
+ * The entry point for working with structured data (rows and columns) in Spark. Allows the
+ * creation of [[DataFrame]] objects as well as the execution of SQL queries.
*
- * @groupname userf Spark SQL Functions
+ * @groupname basic Basic Operations
+ * @groupname ddl_ops Persistent Catalog DDL
+ * @groupname cachemgmt Cached Table Management
+ * @groupname genericdata Generic Data Sources
+ * @groupname specificdata Specific Data Sources
+ * @groupname config Configuration
+ * @groupname dataframes Custom DataFrame Creation
* @groupname Ungrouped Support functions for language integrated queries.
*/
-@AlphaComponent
class SQLContext(@transient val sparkContext: SparkContext)
extends org.apache.spark.Logging
with Serializable {
@@ -62,24 +66,40 @@ class SQLContext(@transient val sparkContext: SparkContext)
// Note that this is a lazy val so we can override the default value in subclasses.
protected[sql] lazy val conf: SQLConf = new SQLConf
- /** Set Spark SQL configuration properties. */
+ /**
+ * Set Spark SQL configuration properties.
+ *
+ * @group config
+ */
def setConf(props: Properties): Unit = conf.setConf(props)
- /** Set the given Spark SQL configuration property. */
+ /**
+ * Set the given Spark SQL configuration property.
+ *
+ * @group config
+ */
def setConf(key: String, value: String): Unit = conf.setConf(key, value)
- /** Return the value of Spark SQL configuration property for the given key. */
+ /**
+ * Return the value of Spark SQL configuration property for the given key.
+ *
+ * @group config
+ */
def getConf(key: String): String = conf.getConf(key)
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
* yet, return `defaultValue`.
+ *
+ * @group config
*/
def getConf(key: String, defaultValue: String): String = conf.getConf(key, defaultValue)
/**
* Return all the configuration properties that have been set (i.e. not the default).
* This creates a new copy of the config properties in the form of a Map.
+ *
+ * @group config
*/
def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
@@ -92,7 +112,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = true) {
- override val extendedRules =
+ override val extendedResolutionRules =
+ ExtractPythonUdfs ::
sources.PreInsertCastAndRename ::
Nil
}
@@ -101,7 +122,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] lazy val optimizer: Optimizer = DefaultOptimizer
@transient
- protected[sql] val ddlParser = new DDLParser
+ protected[sql] val ddlParser = new DDLParser(sqlParser.apply(_))
@transient
protected[sql] val sqlParser = {
@@ -122,15 +143,28 @@ class SQLContext(@transient val sparkContext: SparkContext)
case _ =>
}
+ @transient
protected[sql] val cacheManager = new CacheManager(this)
/**
+ * :: Experimental ::
* A collection of methods that are considered experimental, but can be used to hook into
- * the query planner for advanced functionalities.
+ * the query planner for advanced functionality.
+ *
+ * @group basic
*/
+ @Experimental
+ @transient
val experimental: ExperimentalMethods = new ExperimentalMethods(this)
- /** Returns a [[DataFrame]] with no rows or columns. */
+ /**
+ * :: Experimental ::
+ * Returns a [[DataFrame]] with no rows or columns.
+ *
+ * @group basic
+ */
+ @Experimental
+ @transient
lazy val emptyDataFrame = DataFrame(this, NoRelation)
/**
@@ -158,48 +192,125 @@ class SQLContext(@transient val sparkContext: SparkContext)
* (Integer arg1, String arg2) -> arg2 + arg1),
* DataTypes.StringType);
* }}}
+ *
+ * @group basic
*/
+ @transient
val udf: UDFRegistration = new UDFRegistration(this)
- /** Returns true if the table is currently cached in-memory. */
+ /**
+ * Returns true if the table is currently cached in-memory.
+ * @group cachemgmt
+ */
def isCached(tableName: String): Boolean = cacheManager.isCached(tableName)
- /** Caches the specified table in-memory. */
+ /**
+ * Caches the specified table in-memory.
+ * @group cachemgmt
+ */
def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName)
- /** Removes the specified table from the in-memory cache. */
+ /**
+ * Removes the specified table from the in-memory cache.
+ * @group cachemgmt
+ */
def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName)
+ /**
+ * Removes all cached tables from the in-memory cache.
+ */
+ def clearCache(): Unit = cacheManager.clearCache()
+
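A sketch of the table-cache lifecycle, including the new `clearCache()`; the DataFrame `df` and the table name are hypothetical.

```scala
df.registerTempTable("people")
sqlContext.cacheTable("people")
assert(sqlContext.isCached("people"))
sqlContext.clearCache()                // drops every cached table at once
```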
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
/**
- * (Scala-specific)
- * Implicit methods available in Scala for converting common Scala objects into [[DataFrame]]s.
+ * :: Experimental ::
+ * (Scala-specific) Implicit methods available in Scala for converting
+ * common Scala objects into [[DataFrame]]s.
+ *
+ * {{{
+ * val sqlContext = new SQLContext
+ * import sqlContext._
+ * }}}
+ *
+ * @group basic
*/
- object implicits {
+ @Experimental
+ object implicits extends Serializable {
// scalastyle:on
- /**
- * Creates a DataFrame from an RDD of case classes.
- *
- * @group userf
- */
- implicit def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- self.createDataFrame(rdd)
+
+    /** Converts $"col name" into a [[Column]]. */
+ implicit class StringToColumn(val sc: StringContext) {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args :_*))
+ }
}
- /**
- * Creates a DataFrame from a local Seq of Product.
- */
- implicit def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- self.createDataFrame(data)
+ /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /** Creates a DataFrame from an RDD of case classes or tuples. */
+ implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+ DataFrameHolder(self.createDataFrame(rdd))
+ }
+
+ /** Creates a DataFrame from a local Seq of Product. */
+ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
+ {
+ DataFrameHolder(self.createDataFrame(data))
+ }
+
+ // Do NOT add more implicit conversions. They are likely to break source compatibility by
+ // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
+ // because of [[DoubleRDDFunctions]].
+
+ /** Creates a single column DataFrame from an RDD[Int]. */
+ implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
+ val dataType = IntegerType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setInt(0, v)
+ row: Row
+ }
+ }
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+ }
+
+ /** Creates a single column DataFrame from an RDD[Long]. */
+ implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
+ val dataType = LongType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setLong(0, v)
+ row: Row
+ }
+ }
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+ }
+
+ /** Creates a single column DataFrame from an RDD[String]. */
+ implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
+ val dataType = StringType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setString(0, v)
+ row: Row
+ }
+ }
+ DataFrameHolder(self.createDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
}
}
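A sketch of the new implicit conversions in action, assuming `sc` and `sqlContext` are in scope; the data and column names are illustrative.

```scala
import sqlContext.implicits._

val people = sc.parallelize(Seq(("Alice", 29), ("Bob", 31))).toDF("name", "age")
people.where($"age" > 30).select('name).show()
```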
/**
+ * :: Experimental ::
* Creates a DataFrame from an RDD of case classes.
*
- * @group userf
+ * @group dataframes
*/
+ @Experimental
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
@@ -209,8 +320,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
/**
+ * :: Experimental ::
* Creates a DataFrame from a local Seq of Product.
+ *
+ * @group dataframes
*/
+ @Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
SparkPlan.currentContext.set(self)
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
@@ -220,6 +335,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
+ *
+ * @group dataframes
*/
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
DataFrame(this, LogicalRelation(baseRelation))
@@ -227,12 +344,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: DeveloperApi ::
- * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s using the given schema.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be runtime exception.
* Example:
* {{{
* import org.apache.spark.sql._
+ * import org.apache.spark.sql.types._
* val sqlContext = new org.apache.spark.sql.SQLContext(sc)
*
* val schema =
@@ -243,7 +361,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
- * val dataFrame = sqlContext. applySchema(people, schema)
+ * val dataFrame = sqlContext.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
@@ -253,19 +371,41 @@ class SQLContext(@transient val sparkContext: SparkContext)
* sqlContext.sql("select name from people").collect.foreach(println)
* }}}
*
- * @group userf
+ * @group dataframes
*/
@DeveloperApi
- def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
// TODO: use MutableProjection when rowRDD is another DataFrame and the applied
// schema differs from the existing schema on any field data type.
val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self)
DataFrame(this, logicalPlan)
}
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s using the given schema.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, there will be a runtime exception.
+ *
+ * @group dataframes
+ */
@DeveloperApi
- def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- applySchema(rowRDD.rdd, schema);
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s by applying
+ * a sequence of column names to this RDD. The data type for each column will
+ * be inferred from the first row.
+ *
+ * @param rowRDD a JavaRDD of Row
+ * @param columns the names of the columns
+ * @return DataFrame
+ * @group dataframes
+ */
+ def createDataFrame(rowRDD: JavaRDD[Row], columns: java.util.List[String]): DataFrame = {
+ createDataFrame(rowRDD.rdd, columns.toSeq)
}
/**
@@ -273,8 +413,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
+ * @group dataframes
*/
- def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
val attributeSeq = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
@@ -295,35 +436,101 @@ class SQLContext(@transient val sparkContext: SparkContext)
DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
}
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ * @group dataframes
+ */
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, there will be a runtime exception.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * import org.apache.spark.sql.types._
+ * val sqlContext = new org.apache.spark.sql.SQLContext(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sqlContext.applySchema(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sqlContext.sql("select name from people").collect.foreach(println)
+ * }}}
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema)
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ */
+ @deprecated("use createDataFrame", "1.3.0")
+ def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd, beanClass)
+ }
+
/**
* Applies a schema to an RDD of Java Beans.
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*/
+ @deprecated("use createDataFrame", "1.3.0")
def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- applySchema(rdd.rdd, beanClass)
+ createDataFrame(rdd, beanClass)
}
/**
- * Loads a Parquet file, returning the result as a [[DataFrame]].
+ * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty
+ * [[DataFrame]] if no paths are passed in.
*
- * @group userf
+ * @group specificdata
*/
@scala.annotation.varargs
- def parquetFile(path: String, paths: String*): DataFrame =
- if (conf.parquetUseDataSourceApi) {
- baseRelationToDataFrame(parquet.ParquetRelation2(path +: paths, Map.empty)(this))
+ def parquetFile(paths: String*): DataFrame = {
+ if (paths.isEmpty) {
+ emptyDataFrame
+ } else if (conf.parquetUseDataSourceApi) {
+ baseRelationToDataFrame(parquet.ParquetRelation2(paths, Map.empty)(this))
} else {
DataFrame(this, parquet.ParquetRelation(
paths.mkString(","), Some(sparkContext.hadoopConfiguration), this))
}
+ }
/**
* Loads a JSON file (one object per line), returning the result as a [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
- * @group userf
+ * @group specificdata
*/
def jsonFile(path: String): DataFrame = jsonFile(path, 1.0)
@@ -332,7 +539,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Loads a JSON file (one object per line) and applies the given schema,
* returning the result as a [[DataFrame]].
*
- * @group userf
+ * @group specificdata
*/
@Experimental
def jsonFile(path: String, schema: StructType): DataFrame = {
@@ -342,6 +549,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
+ * @group specificdata
*/
@Experimental
def jsonFile(path: String, samplingRatio: Double): DataFrame = {
@@ -354,10 +562,18 @@ class SQLContext(@transient val sparkContext: SparkContext)
* [[DataFrame]].
* It goes through the entire dataset once to determine the schema.
*
- * @group userf
+ * @group specificdata
*/
def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0)
+
+ /**
+ * Loads a JavaRDD[String] storing JSON objects (one object per record), returning the result as a
+ * [[DataFrame]].
+ * It goes through the entire dataset once to determine the schema.
+ *
+ * @group specificdata
+ */
def jsonRDD(json: JavaRDD[String]): DataFrame = jsonRDD(json.rdd, 1.0)
/**
@@ -365,7 +581,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema,
* returning the result as a [[DataFrame]].
*
- * @group userf
+ * @group specificdata
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
@@ -375,9 +591,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
+ /**
+ * :: Experimental ::
+ * Loads a JavaRDD[String] storing JSON objects (one object per record) and applies the given
+ * schema, returning the result as a [[DataFrame]].
+ *
+ * @group specificdata
+ */
@Experimental
def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = {
jsonRDD(json.rdd, schema)
@@ -385,6 +608,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
/**
* :: Experimental ::
+ * Loads an RDD[String] storing JSON objects (one object per record) inferring the
+ * schema, returning the result as a [[DataFrame]].
+ *
+ * @group specificdata
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
@@ -393,91 +620,279 @@ class SQLContext(@transient val sparkContext: SparkContext)
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
- applySchema(rowRDD, appliedSchema)
+ createDataFrame(rowRDD, appliedSchema)
}
+ /**
+ * :: Experimental ::
+ * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the
+ * schema, returning the result as a [[DataFrame]].
+ *
+ * @group specificdata
+ */
@Experimental
def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = {
jsonRDD(json.rdd, samplingRatio);
}
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame,
+ * using the default data source configured by spark.sql.sources.default.
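+ *
+ * A minimal sketch (the path is illustrative; the default source is parquet unless
+ * spark.sql.sources.default is overridden):
+ * {{{
+ *   val df = sqlContext.load("examples/src/main/resources/users.parquet")
+ * }}}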
+ *
+ * @group genericdata
+ */
@Experimental
def load(path: String): DataFrame = {
val dataSourceName = conf.defaultDataSourceName
- load(dataSourceName, ("path", path))
+ load(path, dataSourceName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Returns the dataset stored at path as a DataFrame, using the given data source.
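+ *
+ * A minimal sketch (path and source name are illustrative; "json" is one of the built-in sources):
+ * {{{
+ *   val df = sqlContext.load("examples/src/main/resources/people.json", "json")
+ * }}}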
+ *
+ * @group genericdata
+ */
+ @Experimental
+ def load(path: String, source: String): DataFrame = {
+ load(source, Map("path" -> path))
}
+ /**
+ * :: Experimental ::
+ * (Java-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame.
+ *
+ * @group genericdata
+ */
@Experimental
- def load(
- dataSourceName: String,
- option: (String, String),
- options: (String, String)*): DataFrame = {
- val resolved = ResolvedDataSource(this, None, dataSourceName, (option +: options).toMap)
+ def load(source: String, options: java.util.Map[String, String]): DataFrame = {
+ load(source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame.
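+ *
+ * For example, a sketch passing the path through the options map (values are illustrative):
+ * {{{
+ *   val df = sqlContext.load("json", Map("path" -> "examples/src/main/resources/people.json"))
+ * }}}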
+ *
+ * @group genericdata
+ */
+ @Experimental
+ def load(source: String, options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, None, source, options)
DataFrame(this, LogicalRelation(resolved.relation))
}
+ /**
+ * :: Experimental ::
+ * (Java-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
+ *
+ * @group genericdata
+ */
@Experimental
def load(
- dataSourceName: String,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ load(source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific) Returns the dataset specified by the given data source and
+ * a set of options as a DataFrame, using the given schema as the schema of the DataFrame.
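+ *
+ * A sketch with a user-specified schema (path and field names are illustrative):
+ * {{{
+ *   import org.apache.spark.sql.types._
+ *   val schema = StructType(StructField("name", StringType, nullable = true) :: Nil)
+ *   val df = sqlContext.load("json", schema,
+ *     Map("path" -> "examples/src/main/resources/people.json"))
+ * }}}
+ *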
+ * @group genericdata
+ */
+ @Experimental
+ def load(
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val resolved = ResolvedDataSource(this, Some(schema), source, options)
+ DataFrame(this, LogicalRelation(resolved.relation))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path and returns the corresponding DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
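+ *
+ * A sketch (persisting an external table requires a catalog backed by a metastore, e.g. a
+ * HiveContext; the path is illustrative):
+ * {{{
+ *   hiveContext.createExternalTable("users", "examples/src/main/resources/users.parquet")
+ * }}}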
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(tableName: String, path: String): DataFrame = {
+ val dataSourceName = conf.defaultDataSourceName
+ createExternalTable(tableName, path, dataSourceName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source
+ * and returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ path: String,
+ source: String): DataFrame = {
+ createExternalTable(tableName, source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
options: java.util.Map[String, String]): DataFrame = {
- val opts = options.toSeq
- load(dataSourceName, opts.head, opts.tail:_*)
+ createExternalTable(tableName, source, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
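+ *
+ * A sketch (assumes a metastore-backed context such as a HiveContext; values are illustrative):
+ * {{{
+ *   hiveContext.createExternalTable("people", "json",
+ *     Map("path" -> "examples/src/main/resources/people.json"))
+ * }}}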
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = None,
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, schema, options.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val cmd =
+ CreateTableUsing(
+ tableName,
+ userSpecifiedSchema = Some(schema),
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableName)
}
/**
* :: Experimental ::
- * Construct an RDD representing the database table accessible via JDBC URL
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table.
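+ *
+ * A minimal sketch (URL and table name are illustrative; the JDBC driver must be on the
+ * classpath):
+ * {{{
+ *   val df = sqlContext.jdbc("jdbc:postgresql://host/mydb?user=test&password=secret", "people")
+ * }}}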
+ *
+ * @group specificdata
*/
@Experimental
- def jdbcRDD(url: String, table: String): DataFrame = {
- jdbcRDD(url, table, null.asInstanceOf[JDBCPartitioningInfo])
+ def jdbc(url: String, table: String): DataFrame = {
+ jdbc(url, table, JDBCRelation.columnPartition(null))
}
/**
* :: Experimental ::
- * Construct an RDD representing the database table accessible via JDBC URL
- * url named table. The PartitioningInfo parameter
- * gives the name of a column of integral type, a number of partitions, and
- * advisory minimum and maximum values for the column. The RDD is
- * partitioned according to said column.
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
+ * url named table. Partitions of the table will be retrieved in parallel based on the parameters
+ * passed to this function.
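+ * For example, a sketch with illustrative values that splits the `people` table into ten
+ * partitions on its integral `id` column:
+ * {{{
+ *   val people = sqlContext.jdbc("jdbc:postgresql://host/mydb", "people", "id", 0L, 10000L, 10)
+ * }}}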
+ *
+ * @param columnName the name of a column of integral type that will be used for partitioning.
+ * @param lowerBound the minimum value of `columnName` to retrieve
+ * @param upperBound the maximum value of `columnName` to retrieve
+ * @param numPartitions the number of partitions. The range `lowerBound`-`upperBound` will be
+ * split evenly into this many partitions.
+ *
+ * @group specificdata
*/
@Experimental
- def jdbcRDD(url: String, table: String, partitioning: JDBCPartitioningInfo):
- DataFrame = {
+ def jdbc(
+ url: String,
+ table: String,
+ columnName: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): DataFrame = {
+ val partitioning = JDBCPartitioningInfo(columnName, lowerBound, upperBound, numPartitions)
val parts = JDBCRelation.columnPartition(partitioning)
- jdbcRDD(url, table, parts)
+ jdbc(url, table, parts)
}
/**
* :: Experimental ::
- * Construct an RDD representing the database table accessible via JDBC URL
+ * Construct a [[DataFrame]] representing the database table accessible via JDBC URL
* url named table. The theParts parameter gives a list expressions
* suitable for inclusion in WHERE clauses; each one defines one partition
- * of the RDD.
+ * of the [[DataFrame]].
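+ *
+ * For example (illustrative URL; each predicate yields one partition):
+ * {{{
+ *   val parts = Array("age < 21", "age >= 21")
+ *   val df = sqlContext.jdbc("jdbc:postgresql://host/mydb", "people", parts)
+ * }}}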
+ *
+ * @group specificdata
*/
@Experimental
- def jdbcRDD(url: String, table: String, theParts: Array[String]):
- DataFrame = {
- val parts: Array[Partition] = theParts.zipWithIndex.map(
- x => JDBCPartition(x._1, x._2).asInstanceOf[Partition])
- jdbcRDD(url, table, parts)
+ def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = {
+ val parts: Array[Partition] = theParts.zipWithIndex.map { case (part, i) =>
+ JDBCPartition(part, i) : Partition
+ }
+ jdbc(url, table, parts)
}
- private def jdbcRDD(url: String, table: String, parts: Array[Partition]):
- DataFrame = {
+ private def jdbc(url: String, table: String, parts: Array[Partition]): DataFrame = {
val relation = JDBCRelation(url, table, parts)(this)
baseRelationToDataFrame(relation)
}
/**
- * Registers the given RDD as a temporary table in the catalog. Temporary tables exist only
- * during the lifetime of this instance of SQLContext.
- *
- * @group userf
+ * Registers the given [[DataFrame]] as a temporary table in the catalog. Temporary tables exist
+ * only during the lifetime of this instance of SQLContext.
*/
- def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = {
- catalog.registerTable(Seq(tableName), rdd.logicalPlan)
+ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
+ catalog.registerTable(Seq(tableName), df.logicalPlan)
}
/**
@@ -486,7 +901,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
*
* @param tableName the name of the table to be unregistered.
*
- * @group userf
+ * @group basic
*/
def dropTempTable(tableName: String): Unit = {
cacheManager.tryUncacheQuery(table(tableName))
@@ -497,7 +912,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is
* used for SQL parsing can be configured with 'spark.sql.dialect'.
*
- * @group userf
+ * @group basic
*/
def sql(sqlText: String): DataFrame = {
if (conf.dialect == "sql") {
@@ -507,10 +922,58 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
}
- /** Returns the specified table as a [[DataFrame]]. */
+ /**
+ * Returns the specified table as a [[DataFrame]].
+ *
+ * @group ddl_ops
+ */
def table(tableName: String): DataFrame =
DataFrame(this, catalog.lookupRelation(Seq(tableName)))
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the current database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a Boolean
+ * indicating if a table is a temporary one or not).
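+ *
+ * For example, a quick way to check whether a table named `people` is registered:
+ * {{{
+ *   sqlContext.tables().filter("tableName = 'people'").show()
+ * }}}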
+ *
+ * @group ddl_ops
+ */
+ def tables(): DataFrame = {
+ DataFrame(this, ShowTablesCommand(None))
+ }
+
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the given database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a Boolean
+ * indicating if a table is a temporary one or not).
+ *
+ * @group ddl_ops
+ */
+ def tables(databaseName: String): DataFrame = {
+ DataFrame(this, ShowTablesCommand(Some(databaseName)))
+ }
+
+ /**
+ * Returns the names of tables in the current database as an array.
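+ * For example (a sketch):
+ * {{{
+ *   sqlContext.tableNames().foreach(println)
+ * }}}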
+ *
+ * @group ddl_ops
+ */
+ def tableNames(): Array[String] = {
+ catalog.getTables(None).map {
+ case (tableName, _) => tableName
+ }.toArray
+ }
+
+ /**
+ * Returns the names of tables in the given database as an array.
+ *
+ * @group ddl_ops
+ */
+ def tableNames(databaseName: String): Array[String] = {
+ catalog.getTables(Some(databaseName)).map {
+ case (tableName, _) => tableName
+ }.toArray
+ }
+
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext: SparkContext = self.sparkContext
@@ -555,7 +1018,8 @@ class SQLContext(@transient val sparkContext: SparkContext)
val projectSet = AttributeSet(projectList.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
- val filterCondition = prunePushedDownFilters(filterPredicates).reduceLeftOption(And)
+ val filterCondition =
+ prunePushedDownFilters(filterPredicates).reduceLeftOption(expressions.And)
// Right now we still use a projection even if the only evaluation is applying an alias
// to a column. Since this is a no-op, it could be avoided. However, using this
@@ -592,6 +1056,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
Batch("Add exchange", Once, AddExchange(self)) :: Nil
}
+ @transient
+ protected[sql] lazy val checkAnalysis = new CheckAnalysis {
+ override val extendedCheckRules = Seq(
+ sources.PreWriteCheck(catalog)
+ )
+ }
+
/**
* :: DeveloperApi ::
* The primary workflow for executing relational queries using Spark. Designed to allow easy
@@ -599,9 +1070,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@DeveloperApi
protected[sql] class QueryExecution(val logical: LogicalPlan) {
+ def assertAnalyzed(): Unit = checkAnalysis(analyzed)
- lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical))
- lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed)
+ lazy val analyzed: LogicalPlan = analyzer(logical)
+ lazy val withCachedData: LogicalPlan = {
+ assertAnalyzed
+ cacheManager.useCachedData(analyzed)
+ }
lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData)
// TODO: Don't just pick the first one...
@@ -670,6 +1145,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
def needsConversion(dataType: DataType): Boolean = dataType match {
case ByteType => true
case ShortType => true
+ case LongType => true
case FloatType => true
case DateType => true
case TimestampType => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
index f1a4053b79113..5921eaf5e63f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSQLParser.scala
@@ -23,7 +23,7 @@ import scala.util.parsing.combinator.RegexParsers
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.{UncacheTableCommand, CacheTableCommand, SetCommand}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.types.StringType
@@ -57,12 +57,16 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
protected val AS = Keyword("AS")
protected val CACHE = Keyword("CACHE")
+ protected val CLEAR = Keyword("CLEAR")
+ protected val IN = Keyword("IN")
protected val LAZY = Keyword("LAZY")
protected val SET = Keyword("SET")
+ protected val SHOW = Keyword("SHOW")
protected val TABLE = Keyword("TABLE")
+ protected val TABLES = Keyword("TABLES")
protected val UNCACHE = Keyword("UNCACHE")
- override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | others
+ override protected lazy val start: Parser[LogicalPlan] = cache | uncache | set | show | others
private lazy val cache: Parser[LogicalPlan] =
CACHE ~> LAZY.? ~ (TABLE ~> ident) ~ (AS ~> restInput).? ^^ {
@@ -71,15 +75,22 @@ private[sql] class SparkSQLParser(fallback: String => LogicalPlan) extends Abstr
}
private lazy val uncache: Parser[LogicalPlan] =
- UNCACHE ~ TABLE ~> ident ^^ {
- case tableName => UncacheTableCommand(tableName)
- }
+ ( UNCACHE ~ TABLE ~> ident ^^ {
+ case tableName => UncacheTableCommand(tableName)
+ }
+ | CLEAR ~ CACHE ^^^ ClearCacheCommand
+ )
private lazy val set: Parser[LogicalPlan] =
SET ~> restInput ^^ {
case input => SetCommandParser(input)
}
+ private lazy val show: Parser[LogicalPlan] =
+ SHOW ~> TABLES ~ (IN ~> ident).? ^^ {
+ case _ ~ dbName => ShowTablesCommand(dbName)
+ }
+
private lazy val others: Parser[LogicalPlan] =
wholeInput ^^ {
case input => fallback(input)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
index d8b0a3b26dbab..8051df299252c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -32,9 +32,9 @@ import org.apache.spark.sql.types.DataType
/**
- * Functions for registering user-defined functions.
+ * Functions for registering user-defined functions. Use [[SQLContext.udf]] to access this.
*/
-class UDFRegistration(sqlContext: SQLContext) extends Logging {
+class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging {
private val functionRegistry = sqlContext.functionRegistry
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
index c60d4070942a9..295db539adfc4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.PythonUDF
import org.apache.spark.sql.types.DataType
/**
- * A user-defined function. To create one, use the `udf` functions in [[Dsl]].
+ * A user-defined function. To create one, use the `udf` functions in [[functions]].
* As an example:
* {{{
* // Defined a UDF that returns true or false based on some numeric score.
@@ -37,7 +37,7 @@ import org.apache.spark.sql.types.DataType
* df.select( predict(df("score")) )
* }}}
*/
-case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
+case class UserDefinedFunction protected[sql] (f: AnyRef, dataType: DataType) {
def apply(exprs: Column*): Column = {
Column(ScalaUdf(f, dataType, exprs.map(_.expr)))
@@ -45,7 +45,7 @@ case class UserDefinedFunction(f: AnyRef, dataType: DataType) {
}
/**
- * A user-defined Python function. To create one, use the `pythonUDF` functions in [[Dsl]].
+ * A user-defined Python function. To create one, use the `pythonUDF` functions in [[functions]].
* This is used by Python API.
*/
private[sql] case class UserDefinedPythonFunction(
@@ -58,6 +58,7 @@ private[sql] case class UserDefinedPythonFunction(
accumulator: Accumulator[JList[Array[Byte]]],
dataType: DataType) {
+ /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */
def apply(exprs: Column*): Column = {
val udf = PythonUDF(name, command, envVars, pythonIncludes, pythonExec, broadcastVars,
accumulator, dataType, exprs.map(_.expr))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala
new file mode 100644
index 0000000000000..cbbd005228d44
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/package.scala
@@ -0,0 +1,23 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+/**
+ * Contains API classes that are specific to a single language (e.g. Java).
+ */
+package object api
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
old mode 100755
new mode 100644
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
index d6d8258f46a9a..d3a18b37d52b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Attribute
@@ -30,7 +31,9 @@ case class LocalTableScan(output: Seq[Attribute], rows: Seq[Row]) extends LeafNo
override def execute() = rdd
- override def executeCollect() = rows.toArray
+ override def executeCollect() =
+ rows.map(ScalaReflection.convertRowToScala(_, schema)).toArray
- override def executeTake(limit: Int) = rows.take(limit).toArray
+ override def executeTake(limit: Int) =
+ rows.map(ScalaReflection.convertRowToScala(_, schema)).take(limit).toArray
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 81bcf5a6f32dd..5281c7502556a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -309,7 +309,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
object DDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false) =>
+ case CreateTableUsing(tableName, userSpecifiedSchema, provider, true, opts, false, _) =>
ExecutedCommand(
CreateTempTableUsing(
tableName, userSpecifiedSchema, provider, opts)) :: Nil
@@ -318,29 +318,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case c: CreateTableUsing if c.temporary && c.allowExisting =>
sys.error("allowExisting should be set to false when creating a temporary table.")
- case CreateTableUsingAsSelect(tableName, provider, true, opts, false, query) =>
- val logicalPlan = sqlContext.parseSql(query)
+ case CreateTableUsingAsSelect(tableName, provider, true, mode, opts, query) =>
val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, logicalPlan)
+ CreateTempTableUsingAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case c: CreateTableUsingAsSelect if !c.temporary =>
sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsSelect if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
-
- case CreateTableUsingAsLogicalPlan(tableName, provider, true, opts, false, query) =>
- val cmd =
- CreateTempTableUsingAsSelect(tableName, provider, opts, query)
- ExecutedCommand(cmd) :: Nil
- case c: CreateTableUsingAsLogicalPlan if !c.temporary =>
- sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.")
- case c: CreateTableUsingAsLogicalPlan if c.temporary && c.allowExisting =>
- sys.error("allowExisting should be set to false when creating a temporary table.")
-
- case LogicalDescribeCommand(table, isExtended) =>
- val resultPlan = self.sqlContext.executePlan(table).executedPlan
- ExecutedCommand(
- RunnableDescribeCommand(resultPlan, resultPlan.output, isExtended)) :: Nil
case LogicalDescribeCommand(table, isExtended) =>
val resultPlan = self.sqlContext.executePlan(table).executedPlan
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4dc506c21ab9e..710268584cff1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -134,13 +134,15 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
val ord = new RowOrdering(sortOrder, child.output)
+ private def collectData() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+
// TODO: Is this copying for no reason?
- override def executeCollect() = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- .map(ScalaReflection.convertRowToScala(_, this.schema))
+ override def executeCollect() =
+ collectData().map(ScalaReflection.convertRowToScala(_, this.schema))
// TODO: Terminal split should be implemented differently from non-terminal split.
// TODO: Pick num splits based on |limit|.
- override def execute() = sparkContext.makeRDD(executeCollect(), 1)
+ override def execute() = sparkContext.makeRDD(collectData(), 1)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
index 335757087deef..a11232142d0fb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -20,9 +20,10 @@ package org.apache.spark.sql.execution
import org.apache.spark.Logging
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types.{BooleanType, StructField, StructType, StringType}
import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
import org.apache.spark.sql.catalyst.errors.TreeNodeException
-import org.apache.spark.sql.catalyst.expressions.{Row, Attribute}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Row, Attribute}
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import scala.collection.mutable.ArrayBuffer
@@ -116,7 +117,9 @@ case class SetCommand(
@DeveloperApi
case class ExplainCommand(
logicalPlan: LogicalPlan,
- override val output: Seq[Attribute], extended: Boolean = false) extends RunnableCommand {
+ override val output: Seq[Attribute] =
+ Seq(AttributeReference("plan", StringType, nullable = false)()),
+ extended: Boolean = false) extends RunnableCommand {
// Run through the optimizer to generate the physical plan.
override def run(sqlContext: SQLContext) = try {
@@ -141,7 +144,7 @@ case class CacheTableCommand(
override def run(sqlContext: SQLContext) = {
plan.foreach { logicalPlan =>
- sqlContext.registerRDDAsTable(DataFrame(sqlContext, logicalPlan), tableName)
+ sqlContext.registerDataFrameAsTable(DataFrame(sqlContext, logicalPlan), tableName)
}
sqlContext.cacheTable(tableName)
@@ -171,6 +174,21 @@ case class UncacheTableCommand(tableName: String) extends RunnableCommand {
override def output: Seq[Attribute] = Seq.empty
}
+/**
+ * :: DeveloperApi ::
+ * Clear all cached data from the in-memory cache.
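+ *
+ * The syntax of using this command in SQL is:
+ * {{{
+ *   CLEAR CACHE
+ * }}}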
+ */
+@DeveloperApi
+case object ClearCacheCommand extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext) = {
+ sqlContext.clearCache()
+ Seq.empty[Row]
+ }
+
+ override def output: Seq[Attribute] = Seq.empty
+}
+
/**
* :: DeveloperApi ::
*/
@@ -188,3 +206,35 @@ case class DescribeCommand(
}
}
}
+
+/**
+ * A command for users to get tables in the given database.
+ * If a databaseName is not given, the current database will be used.
+ * The syntax of using this command in SQL is:
+ * {{{
+ * SHOW TABLES [IN databaseName]
+ * }}}
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand {
+
+ // The result of SHOW TABLES has two columns, tableName and isTemporary.
+ override val output = {
+ val schema = StructType(
+ StructField("tableName", StringType, false) ::
+ StructField("isTemporary", BooleanType, false) :: Nil)
+
+ schema.toAttributes
+ }
+
+ override def run(sqlContext: SQLContext) = {
+ // Since we need to return a Seq of rows, we will call getTables directly
+ // instead of calling tables in sqlContext.
+ val rows = sqlContext.catalog.getTables(databaseName).map {
+ case (tableName, isTemporary) => Row(tableName, isTemporary)
+ }
+
+ rows
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5cc67cdd13944..ffe388cfa9532 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet
import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{SQLConf, SQLContext, DataFrame, Row}
import org.apache.spark.sql.catalyst.trees.TreeNodeRef
import org.apache.spark.sql.types._
@@ -32,11 +32,22 @@ import org.apache.spark.sql.types._
*
* Usage:
* {{{
- * sql("SELECT key FROM src").debug
+ * import org.apache.spark.sql.execution.debug._
+ * sql("SELECT key FROM src").debug()
+ * dataFrame.typeCheck()
* }}}
*/
package object debug {
+ /**
+ * Augments [[SQLContext]] with debug methods.
+ */
+ implicit class DebugSQLContext(sqlContext: SQLContext) {
+ def debug() = {
+ sqlContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false")
+ }
+ }
+
/**
* :: DeveloperApi ::
* Augments [[DataFrame]]s with debug methods.
@@ -135,11 +146,9 @@ package object debug {
}
/**
- * :: DeveloperApi ::
* Helper functions for checking that runtime types match a given schema.
*/
- @DeveloperApi
- object TypeCheck {
+ private[sql] object TypeCheck {
def typeCheck(data: Any, schema: DataType): Unit = (data, schema) match {
case (null, _) =>
@@ -159,16 +168,15 @@ package object debug {
case (_: Short, ShortType) =>
case (_: Boolean, BooleanType) =>
case (_: Double, DoubleType) =>
+ case (v, udt: UserDefinedType[_]) => typeCheck(v, udt.sqlType)
case (d, t) => sys.error(s"Invalid data found: got $d (${d.getClass}) expected $t")
}
}
/**
- * :: DeveloperApi ::
* Augments [[DataFrame]]s with debug methods.
*/
- @DeveloperApi
private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan {
import TypeCheck._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index 3a2f8d75dac5e..33632b8e82ff9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -186,6 +186,7 @@ object EvaluatePython {
case (c: Int, ShortType) => c.toShort
case (c: Long, ShortType) => c.toShort
case (c: Long, IntegerType) => c.toInt
+ case (c: Int, LongType) => c.toLong
case (c: Double, FloatType) => c.toFloat
case (c, StringType) if !c.isInstanceOf[String] => c.toString
@@ -205,6 +206,9 @@ case class EvaluatePython(
extends logical.UnaryNode {
def output = child.output :+ resultAttribute
+
+ // References should not include the produced attribute.
+ override def references = udf.references
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
similarity index 73%
rename from sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
rename to sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 6bf21dd1bc79b..111e751588a8b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -20,35 +20,41 @@ package org.apache.spark.sql
import scala.language.implicitConversions
import scala.reflect.runtime.universe.{TypeTag, typeTag}
+import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types._
/**
- * Domain specific functions available for [[DataFrame]].
+ * :: Experimental ::
+ * Functions available for [[DataFrame]].
+ *
+ * @groupname udf_funcs UDF functions
+ * @groupname agg_funcs Aggregate functions
+ * @groupname sort_funcs Sorting functions
+ * @groupname normal_funcs Non-aggregate functions
+ * @groupname Ungrouped Support functions for DataFrames.
*/
-object Dsl {
-
- /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
- implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
-
- /** Converts $"col name" into an [[Column]]. */
- implicit class StringToColumn(val sc: StringContext) extends AnyVal {
- def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args :_*))
- }
- }
+@Experimental
+// scalastyle:off
+object functions {
+// scalastyle:on
private[this] implicit def toColumn(expr: Expression): Column = Column(expr)
/**
* Returns a [[Column]] based on the given column name.
+ *
+ * @group normal_funcs
*/
def col(colName: String): Column = Column(colName)
/**
* Returns a [[Column]] based on the given column name. Alias of [[col]].
+ *
+ * @group normal_funcs
*/
def column(colName: String): Column = Column(colName)
@@ -58,6 +64,8 @@ object Dsl {
* The passed in object is returned directly if it is already a [[Column]].
* If the object is a Scala Symbol, it is converted into a [[Column]] also.
* Otherwise, a new [[Column]] is created to represent the literal value.
+ *
+ * @group normal_funcs
*/
def lit(literal: Any): Column = {
literal match {
@@ -66,106 +74,219 @@ object Dsl {
case _ => // continue
}
- val literalExpr = literal match {
- case v: Boolean => Literal(v, BooleanType)
- case v: Byte => Literal(v, ByteType)
- case v: Short => Literal(v, ShortType)
- case v: Int => Literal(v, IntegerType)
- case v: Long => Literal(v, LongType)
- case v: Float => Literal(v, FloatType)
- case v: Double => Literal(v, DoubleType)
- case v: String => Literal(v, StringType)
- case v: BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
- case v: java.math.BigDecimal => Literal(Decimal(v), DecimalType.Unlimited)
- case v: Decimal => Literal(v, DecimalType.Unlimited)
- case v: java.sql.Timestamp => Literal(v, TimestampType)
- case v: java.sql.Date => Literal(v, DateType)
- case v: Array[Byte] => Literal(v, BinaryType)
- case null => Literal(null, NullType)
- case _ =>
- throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal)
- }
+ val literalExpr = Literal(literal)
Column(literalExpr)
}
//////////////////////////////////////////////////////////////////////////////////////////////
+ // Sort functions
+ //////////////////////////////////////////////////////////////////////////////////////////////
+
+ /**
+ * Returns a sort expression based on ascending order of the column.
+ * {{{
+ * // Sort by dept in ascending order, and then age in descending order.
+ * df.sort(asc("dept"), desc("age"))
+ * }}}
+ *
+ * @group sort_funcs
+ */
+ def asc(columnName: String): Column = Column(columnName).asc
+
+ /**
+ * Returns a sort expression based on the descending order of the column.
+ * {{{
+ * // Sort by dept in ascending order, and then age in descending order.
+ * df.sort(asc("dept"), desc("age"))
+ * }}}
+ *
+ * @group sort_funcs
+ */
+ def desc(columnName: String): Column = Column(columnName).desc
+
+ //////////////////////////////////////////////////////////////////////////////////////////////
+ // Aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
- /** Aggregate function: returns the sum of all values in the expression. */
+ /**
+ * Aggregate function: returns the sum of all values in the expression.
+ *
+ * @group agg_funcs
+ */
def sum(e: Column): Column = Sum(e.expr)
- /** Aggregate function: returns the sum of all values in the given column. */
+ /**
+ * Aggregate function: returns the sum of all values in the given column.
+ *
+ * @group agg_funcs
+ */
def sum(columnName: String): Column = sum(Column(columnName))
- /** Aggregate function: returns the sum of distinct values in the expression. */
+ /**
+ * Aggregate function: returns the sum of distinct values in the expression.
+ *
+ * @group agg_funcs
+ */
def sumDistinct(e: Column): Column = SumDistinct(e.expr)
- /** Aggregate function: returns the sum of distinct values in the expression. */
+ /**
+ * Aggregate function: returns the sum of distinct values in the expression.
+ *
+ * @group agg_funcs
+ */
def sumDistinct(columnName: String): Column = sumDistinct(Column(columnName))
- /** Aggregate function: returns the number of items in a group. */
- def count(e: Column): Column = Count(e.expr)
+ /**
+ * Aggregate function: returns the number of items in a group.
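+ *
+ * For example, `count("*")` counts all rows, since a star argument is rewritten to `count(1)`
+ * (a sketch assuming an existing DataFrame `df`):
+ * {{{
+ *   df.agg(count("*"))
+ * }}}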
+ *
+ * @group agg_funcs
+ */
+ def count(e: Column): Column = e.expr match {
+ // Turn count(*) into count(1)
+ case s: Star => Count(Literal(1))
+ case _ => Count(e.expr)
+ }
- /** Aggregate function: returns the number of items in a group. */
+ /**
+ * Aggregate function: returns the number of items in a group.
+ *
+ * @group agg_funcs
+ */
def count(columnName: String): Column = count(Column(columnName))
- /** Aggregate function: returns the number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
@scala.annotation.varargs
def countDistinct(expr: Column, exprs: Column*): Column =
CountDistinct((expr +: exprs).map(_.expr))
- /** Aggregate function: returns the number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
@scala.annotation.varargs
def countDistinct(columnName: String, columnNames: String*): Column =
countDistinct(Column(columnName), columnNames.map(Column.apply) :_*)
- /** Aggregate function: returns the approximate number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
def approxCountDistinct(e: Column): Column = ApproxCountDistinct(e.expr)
- /** Aggregate function: returns the approximate number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
def approxCountDistinct(columnName: String): Column = approxCountDistinct(column(columnName))
- /** Aggregate function: returns the approximate number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
def approxCountDistinct(e: Column, rsd: Double): Column = ApproxCountDistinct(e.expr, rsd)
- /** Aggregate function: returns the approximate number of distinct items in a group. */
+ /**
+ * Aggregate function: returns the approximate number of distinct items in a group.
+ *
+ * @group agg_funcs
+ */
def approxCountDistinct(columnName: String, rsd: Double): Column = {
approxCountDistinct(Column(columnName), rsd)
}
- /** Aggregate function: returns the average of the values in a group. */
+ /**
+ * Aggregate function: returns the average of the values in a group.
+ *
+ * @group agg_funcs
+ */
def avg(e: Column): Column = Average(e.expr)
- /** Aggregate function: returns the average of the values in a group. */
+ /**
+ * Aggregate function: returns the average of the values in a group.
+ *
+ * @group agg_funcs
+ */
def avg(columnName: String): Column = avg(Column(columnName))
- /** Aggregate function: returns the first value in a group. */
+ /**
+ * Aggregate function: returns the first value in a group.
+ *
+ * @group agg_funcs
+ */
def first(e: Column): Column = First(e.expr)
- /** Aggregate function: returns the first value of a column in a group. */
+ /**
+ * Aggregate function: returns the first value of a column in a group.
+ *
+ * @group agg_funcs
+ */
def first(columnName: String): Column = first(Column(columnName))
- /** Aggregate function: returns the last value in a group. */
+ /**
+ * Aggregate function: returns the last value in a group.
+ *
+ * @group agg_funcs
+ */
def last(e: Column): Column = Last(e.expr)
- /** Aggregate function: returns the last value of the column in a group. */
+ /**
+ * Aggregate function: returns the last value of the column in a group.
+ *
+ * @group agg_funcs
+ */
def last(columnName: String): Column = last(Column(columnName))
- /** Aggregate function: returns the minimum value of the expression in a group. */
+ /**
+ * Aggregate function: returns the minimum value of the expression in a group.
+ *
+ * @group agg_funcs
+ */
def min(e: Column): Column = Min(e.expr)
- /** Aggregate function: returns the minimum value of the column in a group. */
+ /**
+ * Aggregate function: returns the minimum value of the column in a group.
+ *
+ * @group agg_funcs
+ */
def min(columnName: String): Column = min(Column(columnName))
- /** Aggregate function: returns the maximum value of the expression in a group. */
+ /**
+ * Aggregate function: returns the maximum value of the expression in a group.
+ *
+ * @group agg_funcs
+ */
def max(e: Column): Column = Max(e.expr)
- /** Aggregate function: returns the maximum value of the column in a group. */
+ /**
+ * Aggregate function: returns the maximum value of the column in a group.
+ *
+ * @group agg_funcs
+ */
def max(columnName: String): Column = max(Column(columnName))
//////////////////////////////////////////////////////////////////////////////////////////////
+ // Non-aggregate functions
//////////////////////////////////////////////////////////////////////////////////////////////
+ /**
+ * Returns the first column that is not null.
+ * {{{
+ * df.select(coalesce(df("a"), df("b")))
+ * }}}
+ *
+ * @group normal_funcs
+ */
+ @scala.annotation.varargs
+ def coalesce(e: Column*): Column = Coalesce(e.map(_.expr))
+
/**
* Unary minus, i.e. negate the expression.
* {{{
@@ -176,6 +297,8 @@ object Dsl {
* // Java:
* df.select( negate(df.col("amount")) );
* }}}
+ *
+ * @group normal_funcs
*/
def negate(e: Column): Column = -e
@@ -188,19 +311,37 @@ object Dsl {
* // Java:
* df.filter( not(df.col("isActive")) );
* }}
+ *
+ * @group normal_funcs
*/
def not(e: Column): Column = !e
- /** Converts a string expression to upper case. */
+ /**
+ * Converts a string expression to upper case.
+ *
+ * @group normal_funcs
+ */
def upper(e: Column): Column = Upper(e.expr)
- /** Converts a string exprsesion to lower case. */
+ /**
+ * Converts a string expression to lower case.
+ *
+ * @group normal_funcs
+ */
def lower(e: Column): Column = Lower(e.expr)
- /** Computes the square root of the specified float value. */
+ /**
+ * Computes the square root of the specified float value.
+ *
+ * @group normal_funcs
+ */
def sqrt(e: Column): Column = Sqrt(e.expr)
- /** Computes the absolutle value. */
+ /**
+ * Computes the absolute value.
+ *
+ * @group normal_funcs
+ */
def abs(e: Column): Column = Abs(e.expr)
//////////////////////////////////////////////////////////////////////////////////////////////
@@ -216,6 +357,8 @@ object Dsl {
/**
* Defines a user-defined function of ${x} arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -230,6 +373,8 @@ object Dsl {
/**
* Call a Scala function of ${x} arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = {
ScalaUdf(f, returnType, Seq($argsInUdf))
@@ -240,6 +385,8 @@ object Dsl {
/**
* Defines a user-defined function of 0 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -248,6 +395,8 @@ object Dsl {
/**
* Defines a user-defined function of 1 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -256,6 +405,8 @@ object Dsl {
/**
* Defines a user-defined function of 2 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -264,6 +415,8 @@ object Dsl {
/**
* Defines a user-defined function of 3 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -272,6 +425,8 @@ object Dsl {
/**
* Defines a user-defined function of 4 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -280,6 +435,8 @@ object Dsl {
/**
* Defines a user-defined function of 5 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -288,6 +445,8 @@ object Dsl {
/**
* Defines a user-defined function of 6 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -296,6 +455,8 @@ object Dsl {
/**
* Defines a user-defined function of 7 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -304,6 +465,8 @@ object Dsl {
/**
* Defines a user-defined function of 8 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -312,6 +475,8 @@ object Dsl {
/**
* Defines a user-defined function of 9 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -320,6 +485,8 @@ object Dsl {
/**
* Defines a user-defined function of 10 arguments as user-defined function (UDF).
* The data types are automatically inferred based on the function's signature.
+ *
+ * @group udf_funcs
*/
def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = {
UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType)
@@ -330,6 +497,8 @@ object Dsl {
/**
* Call a Scala function of 0 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function0[_], returnType: DataType): Column = {
ScalaUdf(f, returnType, Seq())
@@ -338,6 +507,8 @@ object Dsl {
/**
* Call a Scala function of 1 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr))
@@ -346,6 +517,8 @@ object Dsl {
/**
* Call a Scala function of 2 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr))
@@ -354,6 +527,8 @@ object Dsl {
/**
* Call a Scala function of 3 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr))
@@ -362,6 +537,8 @@ object Dsl {
/**
* Call a Scala function of 4 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr))
@@ -370,6 +547,8 @@ object Dsl {
/**
* Call a Scala function of 5 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr))
@@ -378,6 +557,8 @@ object Dsl {
/**
* Call a Scala function of 6 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr))
@@ -386,6 +567,8 @@ object Dsl {
/**
* Call a Scala function of 7 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr))
@@ -394,6 +577,8 @@ object Dsl {
/**
* Call a Scala function of 8 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr))
@@ -402,6 +587,8 @@ object Dsl {
/**
* Call a Scala function of 9 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr))
@@ -410,6 +597,8 @@ object Dsl {
/**
* Call a Scala function of 10 arguments as user-defined function (UDF). This requires
* you to specify the return data type.
+ *
+ * @group udf_funcs
*/
def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = {
ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr))
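For readers unfamiliar with the TypeTag-based inference used by the `udf` overloads above, here is a minimal, self-contained sketch (assuming only scala-reflect on the classpath) of how a context-bound TypeTag lets a generic factory recover the return type at runtime. The SimpleDataType hierarchy and makeUdf are illustrative stand-ins, not Spark API.

import scala.reflect.runtime.universe._

// Hypothetical stand-ins for Spark SQL's DataType hierarchy, used only for illustration.
sealed trait SimpleDataType
case object IntType extends SimpleDataType
case object StrType extends SimpleDataType
case object DoubleType extends SimpleDataType

// Minimal analogue of a UDF wrapper: the return type is recovered from the
// function's type signature via a TypeTag, in the spirit of the `udf` overloads above.
case class SimpleUdf(f: AnyRef, returnType: SimpleDataType)

def inferType(tpe: Type): SimpleDataType =
  if (tpe =:= typeOf[Int]) IntType
  else if (tpe =:= typeOf[String]) StrType
  else if (tpe =:= typeOf[Double]) DoubleType
  else sys.error(s"Unsupported return type: $tpe")

def makeUdf[RT: TypeTag, A1](f: A1 => RT): SimpleUdf =
  SimpleUdf(f, inferType(typeOf[RT]))

object UdfSketch extends App {
  val lengthUdf = makeUdf((s: String) => s.length)  // return type inferred as IntType
  println(lengthUdf.returnType)                     // prints IntType
}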
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
index 0bec32cca1325..87304ce2496b4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRDD.scala
@@ -370,21 +370,21 @@ private[sql] class JDBCRDD(
def close() {
if (closed) return
try {
- if (null != rs && ! rs.isClosed()) {
+ if (null != rs) {
rs.close()
}
} catch {
case e: Exception => logWarning("Exception closing resultset", e)
}
try {
- if (null != stmt && ! stmt.isClosed()) {
+ if (null != stmt) {
stmt.close()
}
} catch {
case e: Exception => logWarning("Exception closing statement", e)
}
try {
- if (null != conn && ! conn.isClosed()) {
+ if (null != conn) {
conn.close()
}
logInfo("closed connection")
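The hunk above drops the isClosed() guards because close() on an already-closed JDBC resource is defined as a no-op and some drivers do not implement isClosed(). A minimal sketch of the same defensive pattern, using only java.sql (closeQuietly is a hypothetical helper, not part of this change):

import java.sql.{Connection, ResultSet, Statement}

// Each resource is closed in its own try/catch so that a failure closing one
// does not prevent closing the others. Null checks are enough here.
def closeQuietly(rs: ResultSet, stmt: Statement, conn: Connection): Unit = {
  try { if (rs != null) rs.close() }     catch { case e: Exception => println(s"Error closing ResultSet: $e") }
  try { if (stmt != null) stmt.close() } catch { case e: Exception => println(s"Error closing Statement: $e") }
  try { if (conn != null) conn.close() } catch { case e: Exception => println(s"Error closing Connection: $e") }
}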
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
index 66ad38eb7c45b..beb76f2c553c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JDBCRelation.scala
@@ -48,11 +48,6 @@ private[sql] object JDBCRelation {
* exactly once. The parameters minValue and maxValue are advisory in that
* incorrect values may cause the partitioning to be poor, but no data
* will fail to be represented.
- *
- * @param column - Column name. Must refer to a column of integral type.
- * @param numPartitions - Number of partitions
- * @param minValue - Smallest value of column. Advisory.
- * @param maxValue - Largest value of column. Advisory.
*/
def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = {
if (partitioning == null) return Array[Partition](JDBCPartition(null, 0))
@@ -68,12 +63,17 @@ private[sql] object JDBCRelation {
var currentValue: Long = partitioning.lowerBound
var ans = new ArrayBuffer[Partition]()
while (i < numPartitions) {
- val lowerBound = (if (i != 0) s"$column >= $currentValue" else null)
+ val lowerBound = if (i != 0) s"$column >= $currentValue" else null
currentValue += stride
- val upperBound = (if (i != numPartitions - 1) s"$column < $currentValue" else null)
- val whereClause = (if (upperBound == null) lowerBound
- else if (lowerBound == null) upperBound
- else s"$lowerBound AND $upperBound")
+ val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null
+ val whereClause =
+ if (upperBound == null) {
+ lowerBound
+ } else if (lowerBound == null) {
+ upperBound
+ } else {
+ s"$lowerBound AND $upperBound"
+ }
ans += JDBCPartition(whereClause, i)
i = i + 1
}
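To make the partitioning logic above easier to follow, here is a standalone sketch of how the stride-based WHERE clauses are generated. partitionWhereClauses is an illustrative helper, not code from this patch, and it returns an empty string (rather than null) when a partition has no predicate.

// The first partition has no lower bound and the last has no upper bound, so every
// value of the column (even values outside [lowerBound, upperBound]) lands in
// exactly one partition.
def partitionWhereClauses(
    column: String,
    lowerBound: Long,
    upperBound: Long,
    numPartitions: Int): Seq[String] = {
  val stride = (upperBound - lowerBound) / numPartitions
  var currentValue = lowerBound
  (0 until numPartitions).map { i =>
    val lower = if (i != 0) s"$column >= $currentValue" else null
    currentValue += stride
    val upper = if (i != numPartitions - 1) s"$column < $currentValue" else null
    Seq(Option(lower), Option(upper)).flatten.mkString(" AND ")
  }
}

// e.g. partitionWhereClauses("id", 0, 100, 4) ==
//   Seq("id < 25", "id >= 25 AND id < 50", "id >= 50 AND id < 75", "id >= 75")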
@@ -96,8 +96,7 @@ private[sql] class DefaultSource extends RelationProvider {
if (driver != null) Class.forName(driver)
- if (
- partitionColumn != null
+ if (partitionColumn != null
&& (lowerBound == null || upperBound == null || numPartitions == null)) {
sys.error("Partitioning incompletely specified")
}
@@ -119,7 +118,8 @@ private[sql] class DefaultSource extends RelationProvider {
private[sql] case class JDBCRelation(
url: String,
table: String,
- parts: Array[Partition])(@transient val sqlContext: SQLContext) extends PrunedFilteredScan {
+ parts: Array[Partition])(@transient val sqlContext: SQLContext)
+ extends PrunedFilteredScan {
override val schema = JDBCRDD.resolveTable(url, table)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
index 34a83f0a5dad8..34f864f5fda7a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/jdbc.scala
@@ -26,11 +26,11 @@ import org.apache.spark.sql.jdbc.{JDBCPartitioningInfo, JDBCRelation, JDBCPartit
import org.apache.spark.sql.types._
package object jdbc {
- object JDBCWriteDetails extends Logging {
+ private[sql] object JDBCWriteDetails extends Logging {
/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
- private def insertStatement(conn: Connection, table: String, rddSchema: StructType):
+ def insertStatement(conn: Connection, table: String, rddSchema: StructType):
PreparedStatement = {
val sql = new StringBuilder(s"INSERT INTO $table VALUES (")
var fieldsLeft = rddSchema.fields.length
@@ -56,7 +56,7 @@ package object jdbc {
* non-Serializable. Instead, we explicitly close over all variables that
* are used.
*/
- private[jdbc] def savePartition(url: String, table: String, iterator: Iterator[Row],
+ def savePartition(url: String, table: String, iterator: Iterator[Row],
rddSchema: StructType, nullTypes: Array[Int]): Iterator[Byte] = {
val conn = DriverManager.getConnection(url)
var committed = false
@@ -117,19 +117,14 @@ package object jdbc {
}
Array[Byte]().iterator
}
- }
- /**
- * Make it so that you can call createJDBCTable and insertIntoJDBC on a DataFrame.
- */
- implicit class JDBCDataFrame(rdd: DataFrame) {
/**
* Compute the schema string for this RDD.
*/
- private def schemaString(url: String): String = {
+ def schemaString(df: DataFrame, url: String): String = {
val sb = new StringBuilder()
val quirks = DriverQuirks.get(url)
- rdd.schema.fields foreach { field => {
+ df.schema.fields foreach { field => {
val name = field.name
var typ: String = quirks.getJDBCType(field.dataType)._1
if (typ == null) typ = field.dataType match {
@@ -156,9 +151,9 @@ package object jdbc {
/**
* Saves the RDD to the database in a single transaction.
*/
- private def saveTable(url: String, table: String) {
+ def saveTable(df: DataFrame, url: String, table: String) {
val quirks = DriverQuirks.get(url)
- var nullTypes: Array[Int] = rdd.schema.fields.map(field => {
+ var nullTypes: Array[Int] = df.schema.fields.map(field => {
var nullType: Option[Int] = quirks.getJDBCType(field.dataType)._2
if (nullType.isEmpty) {
field.dataType match {
@@ -175,61 +170,16 @@ package object jdbc {
case DateType => java.sql.Types.DATE
case DecimalType.Unlimited => java.sql.Types.DECIMAL
case _ => throw new IllegalArgumentException(
- s"Can't translate null value for field $field")
+ s"Can't translate null value for field $field")
}
} else nullType.get
}).toArray
- val rddSchema = rdd.schema
- rdd.mapPartitions(iterator => JDBCWriteDetails.savePartition(
- url, table, iterator, rddSchema, nullTypes)).collect()
- }
-
- /**
- * Save this RDD to a JDBC database at `url` under the table name `table`.
- * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements.
- * If you pass `true` for `allowExisting`, it will drop any table with the
- * given name; if you pass `false`, it will throw if the table already
- * exists.
- */
- def createJDBCTable(url: String, table: String, allowExisting: Boolean) {
- val conn = DriverManager.getConnection(url)
- try {
- if (allowExisting) {
- val sql = s"DROP TABLE IF EXISTS $table"
- conn.prepareStatement(sql).executeUpdate()
- }
- val schema = schemaString(url)
- val sql = s"CREATE TABLE $table ($schema)"
- conn.prepareStatement(sql).executeUpdate()
- } finally {
- conn.close()
+ val rddSchema = df.schema
+ df.foreachPartition { iterator =>
+ JDBCWriteDetails.savePartition(url, table, iterator, rddSchema, nullTypes)
}
- saveTable(url, table)
}
- /**
- * Save this RDD to a JDBC database at `url` under the table name `table`.
- * Assumes the table already exists and has a compatible schema. If you
- * pass `true` for `overwrite`, it will `TRUNCATE` the table before
- * performing the `INSERT`s.
- *
- * The table must already exist on the database. It must have a schema
- * that is compatible with the schema of this RDD; inserting the rows of
- * the RDD in order via the simple statement
- * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail.
- */
- def insertIntoJDBC(url: String, table: String, overwrite: Boolean) {
- if (overwrite) {
- val conn = DriverManager.getConnection(url)
- try {
- val sql = s"TRUNCATE TABLE $table"
- conn.prepareStatement(sql).executeUpdate()
- } finally {
- conn.close()
- }
- }
- saveTable(url, table)
- }
- } // implicit class JDBCDataFrame
+ }
} // package object jdbc
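As a quick illustration of what JDBCWriteDetails.insertStatement builds before the rows are written partition by partition, here is a hedged sketch that produces the same placeholder SQL from a field count (insertSql and the example table name are hypothetical):

// One "?" placeholder per column; the resulting string is what gets handed to
// Connection.prepareStatement before binding each row's values.
def insertSql(table: String, numFields: Int): String = {
  require(numFields > 0, "need at least one column")
  val placeholders = Seq.fill(numFields)("?").mkString(", ")
  s"INSERT INTO $table VALUES ($placeholders)"
}

// insertSql("people", 3) == "INSERT INTO people VALUES (?, ?, ?)"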
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
index c4e14c6c92908..3b68b7c275016 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JSONRelation.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.json
import java.io.IOException
import org.apache.hadoop.fs.Path
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.StructType
@@ -29,11 +29,15 @@ import org.apache.spark.sql.types.StructType
private[sql] class DefaultSource
extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {
+ private def checkPath(parameters: Map[String, String]): String = {
+ parameters.getOrElse("path", sys.error("'path' must be specified for json data."))
+ }
+
/** Returns a new base relation with the parameters. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String]): BaseRelation = {
- val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
JSONRelation(path, samplingRatio, None)(sqlContext)
@@ -44,7 +48,7 @@ private[sql] class DefaultSource
sqlContext: SQLContext,
parameters: Map[String, String],
schema: StructType): BaseRelation = {
- val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val path = checkPath(parameters)
val samplingRatio = parameters.get("samplingRatio").map(_.toDouble).getOrElse(1.0)
JSONRelation(path, samplingRatio, Some(schema))(sqlContext)
@@ -52,15 +56,30 @@ private[sql] class DefaultSource
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
- val path = parameters.getOrElse("path", sys.error("Option 'path' not specified"))
+ val path = checkPath(parameters)
val filesystemPath = new Path(path)
val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
- if (fs.exists(filesystemPath)) {
- sys.error(s"path $path already exists.")
+ val doSave = if (fs.exists(filesystemPath)) {
+ mode match {
+ case SaveMode.Append =>
+ sys.error(s"Append mode is not supported by ${this.getClass.getCanonicalName}")
+ case SaveMode.Overwrite =>
+ fs.delete(filesystemPath, true)
+ true
+ case SaveMode.ErrorIfExists =>
+ sys.error(s"path $path already exists.")
+ case SaveMode.Ignore => false
+ }
+ } else {
+ true
+ }
+ if (doSave) {
+ // Only save data when the save mode is not Ignore.
+ data.toJSON.saveAsTextFile(path)
}
- data.toJSON.saveAsTextFile(path)
createRelation(sqlContext, parameters, data.schema)
}
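The save-mode handling added above can be summarized by the following self-contained sketch; the Mode values are local stand-ins mirroring org.apache.spark.sql.SaveMode, and shouldWrite is an illustrative helper rather than the source's API.

sealed trait Mode
case object Append extends Mode
case object Overwrite extends Mode
case object ErrorIfExists extends Mode
case object Ignore extends Mode

// Returns whether new data should be written, given the mode and whether the
// target path already exists.
def shouldWrite(mode: Mode, pathExists: Boolean): Boolean =
  if (!pathExists) true
  else mode match {
    case Append        => sys.error("Append mode is not supported by this source")
    case Overwrite     => true   // caller deletes the existing path first
    case ErrorIfExists => sys.error("path already exists")
    case Ignore        => false  // silently keep the existing data
  }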
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
index 33ce71b51b213..d83bdc2f7ff9a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.json
-import java.io.StringWriter
-import java.sql.{Date, Timestamp}
+import java.sql.Timestamp
import scala.collection.Map
import scala.collection.convert.Wrappers.{JMapWrapper, JListWrapper}
-import com.fasterxml.jackson.core.JsonProcessingException
-import com.fasterxml.jackson.core.JsonFactory
+import com.fasterxml.jackson.core.{JsonGenerator, JsonProcessingException}
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.rdd.RDD
@@ -179,7 +177,12 @@ private[sql] object JsonRDD extends Logging {
}
private def typeOfPrimitiveValue: PartialFunction[Any, DataType] = {
- ScalaReflection.typeOfObject orElse {
+ // For Integer values, use LongType by default.
+ val useLongType: PartialFunction[Any, DataType] = {
+ case value: IntegerType.JvmType => LongType
+ }
+
+ useLongType orElse ScalaReflection.typeOfObject orElse {
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
case value: java.math.BigInteger => DecimalType.Unlimited
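The orElse chaining above is worth spelling out: the integer-widening rule is consulted before the generic reflection-based inference, so JSON integers come out as LongType. A tiny standalone illustration, with plain strings standing in for Spark data types:

// The more specific rule is placed first in the orElse chain, so it wins
// whenever it is defined.
val useLongType: PartialFunction[Any, String] = {
  case _: Int => "LongType"
}

val fallback: PartialFunction[Any, String] = {
  case _: String => "StringType"
  case _: Double => "DoubleType"
}

val inferType = useLongType orElse fallback

// inferType(1)   == "LongType"   (widened, not "IntegerType")
// inferType("a") == "StringType"
// inferType(1.5) == "DoubleType"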
@@ -303,6 +306,10 @@ private[sql] object JsonRDD extends Logging {
val parsed = mapper.readValue(record, classOf[Object]) match {
case map: java.util.Map[_, _] => scalafy(map).asInstanceOf[Map[String, Any]] :: Nil
case list: java.util.List[_] => scalafy(list).asInstanceOf[Seq[Map[String, Any]]]
+ case _ =>
+ sys.error(
+ s"Failed to parse record $record. Please make sure that each line of the file " +
+ "(or each string in the RDD) is a valid JSON object or an array of JSON objects.")
}
parsed
@@ -409,6 +416,9 @@ private[sql] object JsonRDD extends Logging {
case NullType => null
case ArrayType(elementType, _) =>
value.asInstanceOf[Seq[Any]].map(enforceCorrectType(_, elementType))
+ case MapType(StringType, valueType, _) =>
+ val map = value.asInstanceOf[Map[String, Any]]
+ map.mapValues(enforceCorrectType(_, valueType)).map(identity)
case struct: StructType => asRow(value.asInstanceOf[Map[String, Any]], struct)
case DateType => toDate(value)
case TimestampType => toTimestamp(value)
@@ -430,14 +440,11 @@ private[sql] object JsonRDD extends Logging {
/** Transforms a single Row to JSON using Jackson
*
- * @param jsonFactory a JsonFactory object to construct a JsonGenerator
* @param rowSchema the schema object used for conversion
+ * @param gen a JsonGenerator object
* @param row The row to convert
*/
- private[sql] def rowToJSON(rowSchema: StructType, jsonFactory: JsonFactory)(row: Row): String = {
- val writer = new StringWriter()
- val gen = jsonFactory.createGenerator(writer)
-
+ private[sql] def rowToJSON(rowSchema: StructType, gen: JsonGenerator)(row: Row) = {
def valWriter: (DataType, Any) => Unit = {
case (_, null) | (NullType, _) => gen.writeNull()
case (StringType, v: String) => gen.writeString(v)
@@ -479,8 +486,5 @@ private[sql] object JsonRDD extends Logging {
}
valWriter(rowSchema, row)
- gen.close()
- writer.toString
}
-
}
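The rowToJSON change above moves generator creation out of the per-row path: the caller creates one JsonGenerator and reuses it for every row. A minimal sketch of that reuse pattern, assuming only jackson-core on the classpath; the record fields are illustrative.

import java.io.StringWriter
import com.fasterxml.jackson.core.JsonFactory

object JsonGenSketch extends App {
  val writer = new StringWriter()
  val gen = new JsonFactory().createGenerator(writer)

  // One call per record, all sharing the same generator.
  def writeRecord(name: String, age: Int): Unit = {
    gen.writeStartObject()
    gen.writeFieldName("name")
    gen.writeString(name)
    gen.writeFieldName("age")
    gen.writeNumber(age)
    gen.writeEndObject()
  }

  writeRecord("alice", 30)
  writeRecord("bob", 25)
  gen.close()  // closed once by the caller, after all records are written

  println(writer.toString)  // the two JSON objects, written through one generator
}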
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index b0db9943a506c..a0d1005c0cae3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.parquet
import java.io.IOException
+import java.util.logging.Level
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.permission.FsAction
-import parquet.hadoop.ParquetOutputFormat
+import parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat}
import parquet.hadoop.metadata.CompressionCodecName
import parquet.schema.MessageType
@@ -91,7 +92,7 @@ private[sql] object ParquetRelation {
// checks first to see if there's any handlers already set
// and if not it creates them. If this method executes prior
// to that class being loaded then:
- // 1) there's no handlers installed so there's none to
+ // 1) there are no handlers installed so there are none to
// remove. But when it IS finally loaded the desired effect
// of removing them is circumvented.
// 2) The parquet.Log static initializer calls setUseParentHanders(false)
@@ -99,7 +100,7 @@ private[sql] object ParquetRelation {
//
// Therefore we need to force the class to be loaded.
// This should really be resolved by Parquet.
- Class.forName(classOf[parquet.Log].getName())
+ Class.forName(classOf[parquet.Log].getName)
// Note: Logger.getLogger("parquet") has a default logger
// that appends to Console which needs to be cleared.
@@ -108,6 +109,11 @@ private[sql] object ParquetRelation {
// TODO(witgo): Need to set the log level ?
// if(parquetLogger.getLevel != null) parquetLogger.setLevel(null)
if (!parquetLogger.getUseParentHandlers) parquetLogger.setUseParentHandlers(true)
+
+ // Disables WARN log message in ParquetOutputCommitter.
+ // See https://issues.apache.org/jira/browse/SPARK-5968 for details
+ Class.forName(classOf[ParquetOutputCommitter].getName)
+ java.util.logging.Logger.getLogger(classOf[ParquetOutputCommitter].getName).setLevel(Level.OFF)
}
// The element type for the RDDs that this relation maps to.
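The logging workaround above relies only on java.util.logging. A standalone sketch with a hypothetical logger name (not Parquet's); in the actual change, Class.forName is called first so the offending class's static initializer runs before the level is overridden.

import java.util.logging.{Level, Logger}

object SilenceNoisyLogger extends App {
  val noisyLoggerName = "com.example.NoisyCommitter"  // hypothetical name

  // Look the logger up by name and turn it off.
  Logger.getLogger(noisyLoggerName).setLevel(Level.OFF)
  println(s"Logger '$noisyLoggerName' level: " + Logger.getLogger(noisyLoggerName).getLevel)
}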
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 28cd17fde46ab..225ec6db7d553 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -48,6 +48,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row, _}
import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.{Logging, SerializableWritable, TaskContext}
/**
@@ -55,7 +56,7 @@ import org.apache.spark.{Logging, SerializableWritable, TaskContext}
* Parquet table scan operator. Imports the file that backs the given
* [[org.apache.spark.sql.parquet.ParquetRelation]] as a ``RDD[Row]``.
*/
-case class ParquetTableScan(
+private[sql] case class ParquetTableScan(
attributes: Seq[Attribute],
relation: ParquetRelation,
columnPruningPred: Seq[Expression])
@@ -125,6 +126,13 @@ case class ParquetTableScan(
conf)
if (requestedPartitionOrdinals.nonEmpty) {
+ // This check is based on CatalystConverter.createRootConverter.
+ val primitiveRow = output.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))
+
+ // Uses a temporary variable to avoid capturing the whole `ParquetTableScan` object into
+ // the `mapPartitionsWithInputSplit` closure below.
+ val outputSize = output.size
+
baseRDD.mapPartitionsWithInputSplit { case (split, iter) =>
val partValue = "([^=]+)=([^=]+)".r
val partValues =
@@ -142,19 +150,47 @@ case class ParquetTableScan(
relation.partitioningAttributes
.map(a => Cast(Literal(partValues(a.name)), a.dataType).eval(EmptyRow))
- new Iterator[Row] {
- def hasNext = iter.hasNext
- def next() = {
- val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
-
- // Parquet will leave partitioning columns empty, so we fill them in here.
- var i = 0
- while (i < requestedPartitionOrdinals.size) {
- row(requestedPartitionOrdinals(i)._2) =
- partitionRowValues(requestedPartitionOrdinals(i)._1)
- i += 1
+ if (primitiveRow) {
+ new Iterator[Row] {
+ def hasNext = iter.hasNext
+ def next() = {
+ // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
+ val row = iter.next()._2.asInstanceOf[SpecificMutableRow]
+
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ var i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ row(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ row
+ }
+ }
+ } else {
+ // Create a mutable row since we need to fill in values from partition columns.
+ val mutableRow = new GenericMutableRow(outputSize)
+ new Iterator[Row] {
+ def hasNext = iter.hasNext
+ def next() = {
+ // We are using CatalystGroupConverter and it returns a GenericRow.
+ // Since GenericRow is not mutable, we just cast it to a Row.
+ val row = iter.next()._2.asInstanceOf[Row]
+
+ var i = 0
+ while (i < row.size) {
+ mutableRow(i) = row(i)
+ i += 1
+ }
+ // Parquet will leave partitioning columns empty, so we fill them in here.
+ i = 0
+ while (i < requestedPartitionOrdinals.size) {
+ mutableRow(requestedPartitionOrdinals(i)._2) =
+ partitionRowValues(requestedPartitionOrdinals(i)._1)
+ i += 1
+ }
+ mutableRow
}
- row
}
}
}
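The two iterator branches above both end with the same fill step: copy the values produced by the reader, then write the partition-column values into the ordinals the reader left empty. A plain-Scala sketch of that step, with Array[Any] standing in for Spark's mutable row classes:

def fillPartitionColumns(
    readValues: Seq[Any],                // values produced by the record reader
    partitionOrdinals: Seq[(Int, Int)],  // (index into partitionValues, ordinal in the output row)
    partitionValues: Seq[Any],           // values decoded from the directory name
    outputSize: Int): Array[Any] = {
  val row = new Array[Any](outputSize)
  // Copy what the reader produced.
  var i = 0
  while (i < readValues.size) {
    row(i) = readValues(i)
    i += 1
  }
  // Fill in the partition columns the reader left empty.
  i = 0
  while (i < partitionOrdinals.size) {
    val (valueIndex, ordinal) = partitionOrdinals(i)
    row(ordinal) = partitionValues(valueIndex)
    i += 1
  }
  row
}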
@@ -210,7 +246,7 @@ case class ParquetTableScan(
* (only detected via filename pattern so will not catch all cases).
*/
@DeveloperApi
-case class InsertIntoParquetTable(
+private[sql] case class InsertIntoParquetTable(
relation: ParquetRelation,
child: SparkPlan,
overwrite: Boolean = false)
@@ -373,8 +409,6 @@ private[parquet] class AppendingParquetOutputFormat(offset: Int)
private[parquet] class FilteringParquetRowInputFormat
extends parquet.hadoop.ParquetInputFormat[Row] with Logging {
- private var footers: JList[Footer] = _
-
private var fileStatuses = Map.empty[Path, FileStatus]
override def createRecordReader(
@@ -395,46 +429,15 @@ private[parquet] class FilteringParquetRowInputFormat
}
}
- override def getFooters(jobContext: JobContext): JList[Footer] = {
- import org.apache.spark.sql.parquet.FilteringParquetRowInputFormat.footerCache
-
- if (footers eq null) {
- val conf = ContextUtil.getConfiguration(jobContext)
- val cacheMetadata = conf.getBoolean(SQLConf.PARQUET_CACHE_METADATA, true)
- val statuses = listStatus(jobContext)
- fileStatuses = statuses.map(file => file.getPath -> file).toMap
- if (statuses.isEmpty) {
- footers = Collections.emptyList[Footer]
- } else if (!cacheMetadata) {
- // Read the footers from HDFS
- footers = getFooters(conf, statuses)
- } else {
- // Read only the footers that are not in the footerCache
- val foundFooters = footerCache.getAllPresent(statuses)
- val toFetch = new ArrayList[FileStatus]
- for (s <- statuses) {
- if (!foundFooters.containsKey(s)) {
- toFetch.add(s)
- }
- }
- val newFooters = new mutable.HashMap[FileStatus, Footer]
- if (toFetch.size > 0) {
- val startFetch = System.currentTimeMillis
- val fetched = getFooters(conf, toFetch)
- logInfo(s"Fetched $toFetch footers in ${System.currentTimeMillis - startFetch} ms")
- for ((status, i) <- toFetch.zipWithIndex) {
- newFooters(status) = fetched.get(i)
- }
- footerCache.putAll(newFooters)
- }
- footers = new ArrayList[Footer](statuses.size)
- for (status <- statuses) {
- footers.add(newFooters.getOrElse(status, foundFooters.get(status)))
- }
- }
- }
+ // This is only a temporary solution since we need to use fileStatuses in
+ // both getClientSideSplits and getTaskSideSplits. It can be removed once we get rid of these
+ // two methods.
+ override def getSplits(jobContext: JobContext): JList[InputSplit] = {
+ // First set fileStatuses.
+ val statuses = listStatus(jobContext)
+ fileStatuses = statuses.map(file => file.getPath -> file).toMap
- footers
+ super.getSplits(jobContext)
}
// TODO Remove this method and related code once PARQUET-16 is fixed
@@ -459,13 +462,21 @@ private[parquet] class FilteringParquetRowInputFormat
val getGlobalMetaData =
classOf[ParquetFileWriter].getDeclaredMethod("getGlobalMetaData", classOf[JList[Footer]])
getGlobalMetaData.setAccessible(true)
- val globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData]
+ var globalMetaData = getGlobalMetaData.invoke(null, footers).asInstanceOf[GlobalMetaData]
if (globalMetaData == null) {
val splits = mutable.ArrayBuffer.empty[ParquetInputSplit]
return splits
}
+ val metadata = configuration.get(RowWriteSupport.SPARK_ROW_SCHEMA)
+ val mergedMetadata = globalMetaData
+ .getKeyValueMetaData
+ .updated(RowReadSupport.SPARK_METADATA_KEY, setAsJavaSet(Set(metadata)))
+
+ globalMetaData = new GlobalMetaData(globalMetaData.getSchema,
+ mergedMetadata, globalMetaData.getCreatedBy)
+
val readContext = getReadSupport(configuration).init(
new InitContext(configuration,
globalMetaData.getKeyValueMetaData,
@@ -647,6 +658,6 @@ private[parquet] object FileSystemHelper {
sys.error("ERROR: attempting to append to set of Parquet files and found file" +
s"that does not match name pattern: $other")
case _ => 0
- }.reduceLeft((a, b) => if (a < b) b else a)
+ }.reduceOption(_ max _).getOrElse(0)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
index 8d3e094e3344d..d6ea6679c5966 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala
@@ -23,8 +23,8 @@ import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
import scala.util.Try
-import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.util.Utils
/**
@@ -34,10 +34,11 @@ import org.apache.spark.util.Utils
* convenient to use tuples rather than special case classes when writing test cases/suites.
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
-trait ParquetTest {
+private[sql] trait ParquetTest {
val sqlContext: SQLContext
- import sqlContext._
+ import sqlContext.implicits.{localSeqToDataFrameHolder, rddToDataFrameHolder}
+ import sqlContext.{conf, sparkContext}
protected def configuration = sparkContext.hadoopConfiguration
@@ -49,11 +50,11 @@ trait ParquetTest {
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(getConf(key)).toOption)
- (keys, values).zipped.foreach(setConf)
+ val currentValues = keys.map(key => Try(conf.getConf(key)).toOption)
+ (keys, values).zipped.foreach(conf.setConf)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => setConf(key, value)
+ case (key, Some(value)) => conf.setConf(key, value)
case (key, None) => conf.unsetConf(key)
}
}
@@ -88,9 +89,8 @@ trait ParquetTest {
protected def withParquetFile[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: String => Unit): Unit = {
- import sqlContext.implicits._
withTempPath { file =>
- sparkContext.parallelize(data).saveAsParquetFile(file.getCanonicalPath)
+ sparkContext.parallelize(data).toDF().saveAsParquetFile(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -99,17 +99,17 @@ trait ParquetTest {
* Writes `data` to a Parquet file and reads it back as a [[DataFrame]],
* which is then passed to `f`. The Parquet file will be deleted after `f` returns.
*/
- protected def withParquetRDD[T <: Product: ClassTag: TypeTag]
+ protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withParquetFile(data)(path => f(parquetFile(path)))
+ withParquetFile(data)(path => f(sqlContext.parquetFile(path)))
}
/**
* Drops temporary table `tableName` after calling `f`.
*/
protected def withTempTable(tableName: String)(f: => Unit): Unit = {
- try f finally dropTempTable(tableName)
+ try f finally sqlContext.dropTempTable(tableName)
}
/**
@@ -120,9 +120,36 @@ trait ParquetTest {
protected def withParquetTable[T <: Product: ClassTag: TypeTag]
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
- withParquetRDD(data) { rdd =>
- sqlContext.registerRDDAsTable(rdd, tableName)
+ withParquetDataFrame(data) { df =>
+ sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
+
+ protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
+ data: Seq[T], path: File): Unit = {
+ data.toDF().save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite)
+ }
+
+ protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
+ df: DataFrame, path: File): Unit = {
+ df.save(path.getCanonicalPath, "org.apache.spark.sql.parquet", SaveMode.Overwrite)
+ }
+
+ protected def makePartitionDir(
+ basePath: File,
+ defaultPartitionName: String,
+ partitionCols: (String, Any)*): File = {
+ val partNames = partitionCols.map { case (k, v) =>
+ val valueString = if (v == null || v == "") defaultPartitionName else v.toString
+ s"$k=$valueString"
+ }
+
+ val partDir = partNames.foldLeft(basePath) { (parent, child) =>
+ new File(parent, child)
+ }
+
+ assert(partDir.mkdirs(), s"Couldn't create directory $partDir")
+ partDir
+ }
}
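makePartitionDir above builds Hive-style partition directories. A standalone sketch of the same path construction (partitionPath and the default partition name are illustrative):

import java.io.File

// Each (column, value) pair becomes one `column=value` path segment, with a default
// name substituted for null or empty values, nested in order under the base path.
def partitionPath(base: File, defaultName: String, cols: (String, Any)*): File =
  cols.foldLeft(base) { case (parent, (k, v)) =>
    val value = if (v == null || v == "") defaultName else v.toString
    new File(parent, s"$k=$value")
  }

// partitionPath(new File("/tmp/t"), "__NULL__", "year" -> 2015, "month" -> null)
//   ==> /tmp/t/year=2015/month=__NULL__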
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
index 49d46334b6525..6d56be3ab8dd4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala
@@ -20,7 +20,7 @@ import java.io.IOException
import java.lang.{Double => JDouble, Float => JFloat, Long => JLong}
import java.math.{BigDecimal => JBigDecimal}
import java.text.SimpleDateFormat
-import java.util.{List => JList, Date}
+import java.util.{Date, List => JList}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
@@ -32,34 +32,51 @@ import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat
import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext}
+
import parquet.filter2.predicate.FilterApi
import parquet.format.converter.ParquetMetadataConverter
-import parquet.hadoop.{ParquetInputFormat, _}
+import parquet.hadoop.metadata.CompressionCodecName
import parquet.hadoop.util.ContextUtil
+import parquet.hadoop.{ParquetInputFormat, _}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.{NewHadoopPartition, NewHadoopRDD, RDD}
+import org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.parquet.ParquetTypesConverter._
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{IntegerType, StructField, StructType, _}
-import org.apache.spark.sql.types.StructType._
-import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext}
-import org.apache.spark.{Partition => SparkPartition, TaskContext, SerializableWritable, Logging, SparkException}
-
+import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext, SaveMode}
+import org.apache.spark.{Logging, Partition => SparkPartition, SerializableWritable, SparkException, TaskContext}
/**
- * Allows creation of parquet based tables using the syntax
- * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option
- * required is `path`, which should be the location of a collection of, optionally partitioned,
- * parquet files.
+ * Allows creation of Parquet based tables using the syntax:
+ * {{{
+ * CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet OPTIONS (...)
+ * }}}
+ *
+ * Supported options include:
+ *
+ * - `path`: Required. When reading Parquet files, `path` should point to the location of the
+ * Parquet file(s). It can be either a single raw Parquet file, or a directory of Parquet files.
+ * In the latter case, this data source tries to discover partitioning information if the
+ * directory is structured in the same style as Hive partitioned tables. When writing Parquet
+ * files, `path` should point to the destination folder.
+ *
+ * - `mergeSchema`: Optional. Indicates whether we should merge potentially different (but
+ * compatible) schemas stored in all Parquet part-files.
+ *
+ * - `partition.defaultName`: Optional. Partition name used when a value of a partition column is
+ * null or empty string. This is similar to the `hive.exec.default.partition.name` configuration
+ * in Hive.
*/
-class DefaultSource
+private[sql] class DefaultSource
extends RelationProvider
with SchemaRelationProvider
with CreatableRelationProvider {
+
private def checkPath(parameters: Map[String, String]): String = {
parameters.getOrElse("path", sys.error("'path' must be specified for parquet tables."))
}
@@ -71,6 +88,7 @@ class DefaultSource
ParquetRelation2(Seq(checkPath(parameters)), parameters, None)(sqlContext)
}
+ /** Returns a new base relation with the given parameters and schema. */
override def createRelation(
sqlContext: SQLContext,
parameters: Map[String, String],
@@ -78,55 +96,63 @@ class DefaultSource
ParquetRelation2(Seq(checkPath(parameters)), parameters, Some(schema))(sqlContext)
}
+ /** Returns a new base relation with the given parameters and saves the given data into it. */
override def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val path = checkPath(parameters)
- ParquetRelation.createEmpty(
- path,
- data.schema.toAttributes,
- false,
- sqlContext.sparkContext.hadoopConfiguration,
- sqlContext)
-
- val relation = createRelation(sqlContext, parameters, data.schema)
- relation.asInstanceOf[ParquetRelation2].insert(data, true)
+ val filesystemPath = new Path(path)
+ val fs = filesystemPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)
+ val doInsertion = (mode, fs.exists(filesystemPath)) match {
+ case (SaveMode.ErrorIfExists, true) =>
+ sys.error(s"path $path already exists.")
+ case (SaveMode.Append, _) | (SaveMode.Overwrite, _) | (SaveMode.ErrorIfExists, false) =>
+ true
+ case (SaveMode.Ignore, exists) =>
+ !exists
+ }
+
+ val relation = if (doInsertion) {
+ val createdRelation =
+ createRelation(sqlContext, parameters, data.schema).asInstanceOf[ParquetRelation2]
+ createdRelation.insert(data, overwrite = mode == SaveMode.Overwrite)
+ createdRelation
+ } else {
+ // If the save mode is Ignore, we will just create the relation based on existing data.
+ createRelation(sqlContext, parameters)
+ }
+
relation
}
}
-private[parquet] case class Partition(values: Row, path: String)
+private[sql] case class Partition(values: Row, path: String)
-private[parquet] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
+private[sql] case class PartitionSpec(partitionColumns: StructType, partitions: Seq[Partition])
/**
* An alternative to [[ParquetRelation]] that plugs in using the data sources API. This class is
- * currently not intended as a full replacement of the parquet support in Spark SQL though it is
- * likely that it will eventually subsume the existing physical plan implementation.
+ * intended as a full replacement of the Parquet support in Spark SQL. The old implementation will
+ * be deprecated and eventually removed once this version proves to be stable enough.
*
- * Compared with the current implementation, this class has the following notable differences:
+ * Compared with the old implementation, this class has the following notable differences:
*
- * Partitioning: Partitions are auto discovered and must be in the form of directories `key=value/`
- * located at `path`. Currently only a single partitioning column is supported and it must
- * be an integer. This class supports both fully self-describing data, which contains the partition
- * key, and data where the partition key is only present in the folder structure. The presence
- * of the partitioning key in the data is also auto-detected. The `null` partition is not yet
- * supported.
- *
- * Metadata: The metadata is automatically discovered by reading the first parquet file present.
- * There is currently no support for working with files that have different schema. Additionally,
- * when parquet metadata caching is turned on, the FileStatus objects for all data will be cached
- * to improve the speed of interactive querying. When data is added to a table it must be dropped
- * and recreated to pick up any changes.
- *
- * Statistics: Statistics for the size of the table are automatically populated during metadata
- * discovery.
+ * - Partitioning discovery: Hive-style multi-level partitions are automatically discovered.
+ * - Metadata discovery: Parquet is a format that comes with schema evolution support. This data source
+ * can detect and merge schemas from all Parquet part-files as long as they are compatible.
+ * Also, metadata and [[FileStatus]]es are cached for better performance.
+ * - Statistics: Statistics for the size of the table are automatically populated during schema
+ * discovery.
*/
@DeveloperApi
-case class ParquetRelation2
- (paths: Seq[String], parameters: Map[String, String], maybeSchema: Option[StructType] = None)
- (@transient val sqlContext: SQLContext)
+private[sql] case class ParquetRelation2(
+ paths: Seq[String],
+ parameters: Map[String, String],
+ maybeSchema: Option[StructType] = None,
+ maybePartitionSpec: Option[PartitionSpec] = None)(
+ @transient val sqlContext: SQLContext)
extends CatalystScan
with InsertableRelation
with SparkHadoopMapReduceUtil
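For the partition discovery described in the class comment above, here is a minimal sketch of how `column=value` path segments can be parsed back into (column, value) pairs; parsePartitionPath is illustrative and skips the type inference and null handling done by the real code.

// Every path segment of the form "column=value" contributes one pair; other
// segments (scheme, table directory, data file name) are ignored.
def parsePartitionPath(path: String): Seq[(String, String)] =
  path.split("/").toSeq.collect {
    case segment if segment.contains("=") =>
      val Array(k, v) = segment.split("=", 2)
      k -> v
  }

// parsePartitionPath("hdfs://host/table/year=2015/month=2/part-00000.parquet")
//   ==> Seq("year" -> "2015", "month" -> "2")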
@@ -149,42 +175,90 @@ case class ParquetRelation2
override def equals(other: Any) = other match {
case relation: ParquetRelation2 =>
+ // If schema merging is required, we don't compare the actual schemas since they may evolve.
+ val schemaEquality = if (shouldMergeSchemas) {
+ shouldMergeSchemas == relation.shouldMergeSchemas
+ } else {
+ schema == relation.schema
+ }
+
paths.toSet == relation.paths.toSet &&
+ schemaEquality &&
maybeMetastoreSchema == relation.maybeMetastoreSchema &&
- (shouldMergeSchemas == relation.shouldMergeSchemas || schema == relation.schema)
+ maybePartitionSpec == relation.maybePartitionSpec
+
+ case _ => false
}
private[sql] def sparkContext = sqlContext.sparkContext
- @transient private val fs = FileSystem.get(sparkContext.hadoopConfiguration)
-
private class MetadataCache {
+ // `FileStatus` objects of all "_metadata" files.
private var metadataStatuses: Array[FileStatus] = _
+
+ // `FileStatus` objects of all "_common_metadata" files.
private var commonMetadataStatuses: Array[FileStatus] = _
- private var footers: Map[FileStatus, Footer] = _
- private var parquetSchema: StructType = _
+ // Parquet footer cache.
+ var footers: Map[FileStatus, Footer] = _
+
+ // `FileStatus` objects of all data files (Parquet part-files).
var dataStatuses: Array[FileStatus] = _
+
+ // Partition spec of this table, including names, data types, and values of each partition
+ // column, and paths of each partition.
var partitionSpec: PartitionSpec = _
+
+ // Schema of the actual Parquet files, without partition columns discovered from partition
+ // directory paths.
+ var parquetSchema: StructType = _
+
+ // Schema of the whole table, including partition columns.
var schema: StructType = _
- var dataSchemaIncludesPartitionKeys: Boolean = _
+ // Indicates whether partition columns are also included in Parquet data file schema. If not,
+ // we need to fill in partition column values into read rows when scanning the table.
+ var partitionKeysIncludedInParquetSchema: Boolean = _
+
+ def prepareMetadata(path: Path, schema: StructType, conf: Configuration): Unit = {
+ conf.set(
+ ParquetOutputFormat.COMPRESSION,
+ ParquetRelation
+ .shortParquetCompressionCodecNames
+ .getOrElse(
+ sqlContext.conf.parquetCompressionCodec.toUpperCase,
+ CompressionCodecName.UNCOMPRESSED).name())
+
+ ParquetRelation.enableLogForwarding()
+ ParquetTypesConverter.writeMetaData(schema.toAttributes, path, conf)
+ }
+
+ /**
+ * Refreshes `FileStatus`es, footers, partition spec, and table schema.
+ */
def refresh(): Unit = {
- val baseStatuses = {
- val statuses = paths.distinct.map(p => fs.getFileStatus(fs.makeQualified(new Path(p))))
- // Support either reading a collection of raw Parquet part-files, or a collection of folders
- // containing Parquet files (e.g. partitioned Parquet table).
- assert(statuses.forall(!_.isDir) || statuses.forall(_.isDir))
- statuses.toArray
- }
+ val fs = FileSystem.get(sparkContext.hadoopConfiguration)
+ // Support either reading a collection of raw Parquet part-files, or a collection of folders
+ // containing Parquet files (e.g. partitioned Parquet table).
+ val baseStatuses = paths.distinct.map { p =>
+ val qualified = fs.makeQualified(new Path(p))
+
+ if (!fs.exists(qualified) && maybeSchema.isDefined) {
+ fs.mkdirs(qualified)
+ prepareMetadata(qualified, maybeSchema.get, sparkContext.hadoopConfiguration)
+ }
+
+ fs.getFileStatus(qualified)
+ }.toArray
+ assert(baseStatuses.forall(!_.isDir) || baseStatuses.forall(_.isDir))
+
+ // Lists `FileStatus`es of all leaf nodes (files) under all base directories.
val leaves = baseStatuses.flatMap { f =>
- val statuses = SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f =>
+ SparkHadoopUtil.get.listLeafStatuses(fs, f.getPath).filter { f =>
isSummaryFile(f.getPath) ||
!(f.getPath.getName.startsWith("_") || f.getPath.getName.startsWith("."))
}
- assert(statuses.nonEmpty, s"${f.getPath} is an empty folder.")
- statuses
}
dataStatuses = leaves.filterNot(f => isSummaryFile(f.getPath))
@@ -198,13 +272,14 @@ case class ParquetRelation2
f -> new Footer(f.getPath, parquetMetadata)
}.seq.toMap
- partitionSpec = {
- val partitionDirs = dataStatuses
+ partitionSpec = maybePartitionSpec.getOrElse {
+ val partitionDirs = leaves
.filterNot(baseStatuses.contains)
.map(_.getPath.getParent)
.distinct
if (partitionDirs.nonEmpty) {
+ // Parses names and values of partition columns, and infer their data types.
ParquetRelation2.parsePartitions(partitionDirs, defaultPartitionName)
} else {
// No partition directories found, makes an empty specification
@@ -212,26 +287,37 @@ case class ParquetRelation2
}
}
- parquetSchema = maybeSchema.getOrElse(readSchema())
-
- dataSchemaIncludesPartitionKeys =
+ // To get the schema, we first try the schema defined in maybeSchema.
+ // If maybeSchema is not defined, we will try to get the schema from existing parquet data
+ // (through readSchema). If data does not exist, we will try to get the schema defined in
+ // maybeMetastoreSchema (defined in the options of the data source).
+ // Finally, if we still cannot get the schema, we throw an error.
+ parquetSchema =
+ maybeSchema
+ .orElse(readSchema())
+ .orElse(maybeMetastoreSchema)
+ .getOrElse(sys.error("Failed to get the schema."))
+
+ partitionKeysIncludedInParquetSchema =
isPartitioned &&
- partitionColumns.forall(f => metadataCache.parquetSchema.fieldNames.contains(f.name))
+ partitionColumns.forall(f => parquetSchema.fieldNames.contains(f.name))
schema = {
- val fullParquetSchema = if (dataSchemaIncludesPartitionKeys) {
- metadataCache.parquetSchema
+ val fullRelationSchema = if (partitionKeysIncludedInParquetSchema) {
+ parquetSchema
} else {
- StructType(metadataCache.parquetSchema.fields ++ partitionColumns.fields)
+ StructType(parquetSchema.fields ++ partitionColumns.fields)
}
+ // If this Parquet relation is converted from a Hive Metastore table, we must reconcile case
+ // insensitivity issue and possible schema mismatch.
maybeMetastoreSchema
- .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullParquetSchema))
- .getOrElse(fullParquetSchema)
+ .map(ParquetRelation2.mergeMetastoreParquetSchema(_, fullRelationSchema))
+ .getOrElse(fullRelationSchema)
}
}
- private def readSchema(): StructType = {
+ private def readSchema(): Option[StructType] = {
// Sees which file(s) we need to touch in order to figure out the schema.
val filesToTouch =
// Always tries the summary files first if users don't require a merged schema. In this case,
@@ -276,13 +362,17 @@ case class ParquetRelation2
@transient private val metadataCache = new MetadataCache
metadataCache.refresh()
- private def partitionColumns = metadataCache.partitionSpec.partitionColumns
+ def partitionSpec = metadataCache.partitionSpec
- private def partitions = metadataCache.partitionSpec.partitions
+ def partitionColumns = metadataCache.partitionSpec.partitionColumns
- private def isPartitioned = partitionColumns.nonEmpty
+ def partitions = metadataCache.partitionSpec.partitions
- private def dataSchemaIncludesPartitionKeys = metadataCache.dataSchemaIncludesPartitionKeys
+ def isPartitioned = partitionColumns.nonEmpty
+
+ private def partitionKeysIncludedInDataSchema = metadataCache.partitionKeysIncludedInParquetSchema
+
+ private def parquetSchema = metadataCache.parquetSchema
override def schema = metadataCache.schema
@@ -310,6 +400,7 @@ case class ParquetRelation2
} else {
metadataCache.dataStatuses.toSeq
}
+ val selectedFooters = selectedFiles.map(metadataCache.footers)
// FileInputFormat cannot handle empty lists.
if (selectedFiles.nonEmpty) {
@@ -357,11 +448,16 @@ case class ParquetRelation2
@transient
val cachedStatus = selectedFiles
+ @transient
+ val cachedFooters = selectedFooters
+
// Overridden so we can inject our own cached files statuses.
override def getPartitions: Array[SparkPartition] = {
val inputFormat = if (cacheMetadata) {
new FilteringParquetRowInputFormat {
override def listStatus(jobContext: JobContext): JList[FileStatus] = cachedStatus
+
+ override def getFooters(jobContext: JobContext): JList[Footer] = cachedFooters
}
} else {
new FilteringParquetRowInputFormat
@@ -385,21 +481,54 @@ case class ParquetRelation2
// When the data does not include the key and the key is requested then we must fill it in
// based on information from the input split.
- if (!dataSchemaIncludesPartitionKeys && partitionKeyLocations.nonEmpty) {
+ if (!partitionKeysIncludedInDataSchema && partitionKeyLocations.nonEmpty) {
+ // This check is based on CatalystConverter.createRootConverter.
+ val primitiveRow =
+ requestedSchema.forall(a => ParquetTypesConverter.isPrimitiveType(a.dataType))
+
baseRDD.mapPartitionsWithInputSplit { case (split: ParquetInputSplit, iterator) =>
val partValues = selectedPartitions.collectFirst {
case p if split.getPath.getParent.toString == p.path => p.values
}.get
- iterator.map { pair =>
- val row = pair._2.asInstanceOf[SpecificMutableRow]
- var i = 0
- while (i < partValues.size) {
- // TODO Avoids boxing cost here!
- row.update(partitionKeyLocations(i), partValues(i))
- i += 1
+ val requiredPartOrdinal = partitionKeyLocations.keys.toSeq
+
+ if (primitiveRow) {
+ iterator.map { pair =>
+ // We are using CatalystPrimitiveRowConverter and it returns a SpecificMutableRow.
+ val row = pair._2.asInstanceOf[SpecificMutableRow]
+ var i = 0
+ while (i < requiredPartOrdinal.size) {
+ // TODO Avoids boxing cost here!
+ val partOrdinal = requiredPartOrdinal(i)
+ row.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
+ i += 1
+ }
+ row
+ }
+ } else {
+ // Create a mutable row since we need to fill in values from partition columns.
+ val mutableRow = new GenericMutableRow(requestedSchema.size)
+ iterator.map { pair =>
+ // We are using CatalystGroupConverter and it returns a GenericRow.
+ // Since GenericRow is not mutable, we just cast it to a Row.
+ val row = pair._2.asInstanceOf[Row]
+ var i = 0
+ while (i < row.size) {
+ // TODO Avoids boxing cost here!
+ mutableRow(i) = row(i)
+ i += 1
+ }
+
+ i = 0
+ while (i < requiredPartOrdinal.size) {
+ // TODO Avoids boxing cost here!
+ val partOrdinal = requiredPartOrdinal(i)
+ mutableRow.update(partitionKeyLocations(partOrdinal), partValues(partOrdinal))
+ i += 1
+ }
+ mutableRow
}
- row
}
}
} else {
@@ -415,7 +544,8 @@ case class ParquetRelation2
_.references.map(_.name).toSet.subsetOf(partitionColumnNames)
}
- val rawPredicate = partitionPruningPredicates.reduceOption(And).getOrElse(Literal(true))
+ val rawPredicate =
+ partitionPruningPredicates.reduceOption(expressions.And).getOrElse(Literal(true))
val boundPredicate = InterpretedPredicate(rawPredicate transform {
case a: AttributeReference =>
val index = partitionColumns.indexWhere(a.name == _.name)
@@ -430,6 +560,8 @@ case class ParquetRelation2
}
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+ assert(paths.size == 1, s"Can't write to multiple destinations: ${paths.mkString(",")}")
+
// TODO: currently we do not check whether the "schema"s are compatible
// That means if one first creates a table and then INSERTs data with
// an incompatible schema the execution will fail. It would be nice
@@ -437,7 +569,7 @@ case class ParquetRelation2
// before calling execute().
val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
- val writeSupport = if (schema.map(_.dataType).forall(_.isPrimitive)) {
+ val writeSupport = if (parquetSchema.map(_.dataType).forall(_.isPrimitive)) {
log.debug("Initializing MutableRowWriteSupport")
classOf[MutableRowWriteSupport]
} else {
@@ -447,7 +579,7 @@ case class ParquetRelation2
ParquetOutputFormat.setWriteSupportClass(job, writeSupport)
val conf = ContextUtil.getConfiguration(job)
- RowWriteSupport.setSchema(schema.toAttributes, conf)
+ RowWriteSupport.setSchema(data.schema.toAttributes, conf)
val destinationPath = new Path(paths.head)
@@ -513,20 +645,19 @@ case class ParquetRelation2
}
}
-object ParquetRelation2 {
+private[sql] object ParquetRelation2 {
// Whether we should merge schemas collected from all Parquet part-files.
val MERGE_SCHEMA = "mergeSchema"
- // Hive Metastore schema, passed in when the Parquet relation is converted from Metastore
- val METASTORE_SCHEMA = "metastoreSchema"
-
- // Default partition name to use when the partition column value is null or empty string
+ // Default partition name to use when the partition column value is null or empty string.
val DEFAULT_PARTITION_NAME = "partition.defaultName"
- // When true, the Parquet data source caches Parquet metadata for performance
- val CACHE_METADATA = "cacheMetadata"
+ // Hive Metastore schema, used when converting Metastore Parquet tables. This option is only used
+ // internally.
+ private[sql] val METASTORE_SCHEMA = "metastoreSchema"
- private[parquet] def readSchema(footers: Seq[Footer], sqlContext: SQLContext): StructType = {
+ private[parquet] def readSchema(
+ footers: Seq[Footer], sqlContext: SQLContext): Option[StructType] = {
footers.map { footer =>
val metadata = footer.getParquetMetadata.getFileMetaData
val parquetSchema = metadata.getSchema
@@ -545,13 +676,22 @@ object ParquetRelation2 {
sqlContext.conf.isParquetBinaryAsString,
sqlContext.conf.isParquetINT96AsTimestamp))
}
- }.reduce { (left, right) =>
+ }.reduceOption { (left, right) =>
try left.merge(right) catch { case e: Throwable =>
throw new SparkException(s"Failed to merge incompatible schemas $left and $right", e)
}
}
}
+ /**
+ * Reconciles Hive Metastore case insensitivity issue and data type conflicts between Metastore
+ * schema and Parquet schema.
+ *
+ * Hive doesn't retain case information, while Parquet is case sensitive. On the other hand, the
+ * schema read from Parquet files may be incomplete (e.g. older versions of Parquet don't
+ * distinguish binary and string). This method generates a correct schema by merging Metastore
+ * schema data types and Parquet schema field names.
+ */
private[parquet] def mergeMetastoreParquetSchema(
metastoreSchema: StructType,
parquetSchema: StructType): StructType = {
@@ -692,16 +832,15 @@ object ParquetRelation2 {
* }}}
*/
private[parquet] def resolvePartitions(values: Seq[PartitionValues]): Seq[PartitionValues] = {
- val distinctColNamesOfPartitions = values.map(_.columnNames).distinct
- val columnCount = values.head.columnNames.size
-
// Column names of all partitions must match
- assert(distinctColNamesOfPartitions.size == 1, {
- val list = distinctColNamesOfPartitions.mkString("\t", "\n", "")
+ val distinctPartitionsColNames = values.map(_.columnNames).distinct
+ assert(distinctPartitionsColNames.size == 1, {
+ val list = distinctPartitionsColNames.mkString("\t", "\n", "")
s"Conflicting partition column names detected:\n$list"
})
// Resolves possible type conflicts for each column
+ val columnCount = values.head.columnNames.size
val resolvedValues = (0 until columnCount).map { i =>
resolveTypeConflicts(values.map(_.literals(i)))
}
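resolvePartitions above first checks that all partition directories agree on column names and then widens each column's values to a common type. A simplified, self-contained sketch of that per-column resolution; the three-type hierarchy here is a stand-in for the real one.

// Pick the wider of two inferred types, falling back to string when they disagree.
def widest(a: String, b: String): String = (a, b) match {
  case (x, y) if x == y => x
  case ("LongType", "DoubleType") | ("DoubleType", "LongType") => "DoubleType"
  case _ => "StringType"
}

// perPartitionTypes holds one row of inferred column types per partition directory.
def resolveColumnTypes(perPartitionTypes: Seq[Seq[String]]): Seq[String] = {
  require(perPartitionTypes.map(_.size).distinct.size == 1,
    "all partitions must have the same number of columns")
  perPartitionTypes.transpose.map(_.reduce(widest))
}

// resolveColumnTypes(Seq(Seq("LongType", "LongType"), Seq("DoubleType", "StringType")))
//   ==> Seq("DoubleType", "StringType")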
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
index 887161684429f..e24475292ceaf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/timestamp/NanoTime.scala
@@ -53,7 +53,7 @@ private[parquet] class NanoTime extends Serializable {
"NanoTime{julianDay=" + julianDay + ", timeOfDayNanos=" + timeOfDayNanos + "}"
}
-object NanoTime {
+private[sql] object NanoTime {
def fromBinary(bytes: Binary): NanoTime = {
Preconditions.checkArgument(bytes.length() == 12, "Must be 12 bytes")
val buf = bytes.toByteBuffer
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
index 624369afe87b5..67f3507c61ab6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.{Row, Strategy, execution}
+import org.apache.spark.sql.{Row, Strategy, execution, sources}
/**
* A Strategy for planning scans over data sources defined using the sources API.
@@ -55,10 +55,7 @@ private[sql] object DataSourceStrategy extends Strategy {
execution.PhysicalRDD(l.output, t.buildScan()) :: Nil
case i @ logical.InsertIntoTable(
- l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) =>
- if (partition.nonEmpty) {
- sys.error(s"Insert into a partition is not allowed because $l is not partitioned.")
- }
+ l @ LogicalRelation(t: InsertableRelation), part, query, overwrite) if part.isEmpty =>
execution.ExecutedCommand(InsertIntoDataSource(l, query, overwrite)) :: Nil
case _ => Nil
@@ -88,7 +85,7 @@ private[sql] object DataSourceStrategy extends Strategy {
val projectSet = AttributeSet(projectList.flatMap(_.references))
val filterSet = AttributeSet(filterPredicates.flatMap(_.references))
- val filterCondition = filterPredicates.reduceLeftOption(And)
+ val filterCondition = filterPredicates.reduceLeftOption(expressions.And)
val pushedFilters = filterPredicates.map { _ transform {
case a: AttributeReference => relation.attributeMap(a) // Match original case of attributes.
@@ -118,27 +115,60 @@ private[sql] object DataSourceStrategy extends Strategy {
}
}
- /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */
- protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect {
- case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v)
- case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v)
-
- case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v)
- case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v)
-
- case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v)
- case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v)
-
- case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
- GreaterThanOrEqual(a.name, v)
- case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
- LessThanOrEqual(a.name, v)
-
- case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) =>
- LessThanOrEqual(a.name, v)
- case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) =>
- GreaterThanOrEqual(a.name, v)
+ /**
+ * Selects Catalyst predicate [[Expression]]s which are convertible into data source [[Filter]]s,
+ * and convert them.
+ */
+ protected[sql] def selectFilters(filters: Seq[Expression]) = {
+ def translate(predicate: Expression): Option[Filter] = predicate match {
+ case expressions.EqualTo(a: Attribute, Literal(v, _)) =>
+ Some(sources.EqualTo(a.name, v))
+ case expressions.EqualTo(Literal(v, _), a: Attribute) =>
+ Some(sources.EqualTo(a.name, v))
+
+ case expressions.GreaterThan(a: Attribute, Literal(v, _)) =>
+ Some(sources.GreaterThan(a.name, v))
+ case expressions.GreaterThan(Literal(v, _), a: Attribute) =>
+ Some(sources.LessThan(a.name, v))
+
+ case expressions.LessThan(a: Attribute, Literal(v, _)) =>
+ Some(sources.LessThan(a.name, v))
+ case expressions.LessThan(Literal(v, _), a: Attribute) =>
+ Some(sources.GreaterThan(a.name, v))
+
+ case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) =>
+ Some(sources.GreaterThanOrEqual(a.name, v))
+ case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) =>
+ Some(sources.LessThanOrEqual(a.name, v))
+
+ case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) =>
+ Some(sources.LessThanOrEqual(a.name, v))
+ case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) =>
+ Some(sources.GreaterThanOrEqual(a.name, v))
+
+ case expressions.InSet(a: Attribute, set) =>
+ Some(sources.In(a.name, set.toArray))
+
+ case expressions.IsNull(a: Attribute) =>
+ Some(sources.IsNull(a.name))
+ case expressions.IsNotNull(a: Attribute) =>
+ Some(sources.IsNotNull(a.name))
+
+ case expressions.And(left, right) =>
+ (translate(left) ++ translate(right)).reduceOption(sources.And)
+
+ case expressions.Or(left, right) =>
+ for {
+ leftFilter <- translate(left)
+ rightFilter <- translate(right)
+ } yield sources.Or(leftFilter, rightFilter)
+
+ case expressions.Not(child) =>
+ translate(child).map(sources.Not)
+
+ case _ => None
+ }
- case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray)
+ filters.flatMap(translate)
}
}
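
To make the meaning of the translated source Filters concrete, here is a hedged sketch (not part of this patch) that evaluates a subset of the new filters against a row modeled as a Map. Comparison filters such as GreaterThan are omitted because they need type-aware ordering, and unhandled filters default to true, which is typically safe because the strategy still applies the original predicates on top of the scan.

import org.apache.spark.sql.sources._

// Hypothetical evaluator for illustration only.
def evaluate(filter: Filter, row: Map[String, Any]): Boolean = filter match {
  case EqualTo(attr, value) => row.get(attr).exists(_ == value)
  case In(attr, values)     => row.get(attr).exists(v => values.contains(v))
  case IsNull(attr)         => row.get(attr).forall(_ == null)
  case IsNotNull(attr)      => row.get(attr).exists(_ != null)
  case And(left, right)     => evaluate(left, row) && evaluate(right, row)
  case Or(left, right)      => evaluate(left, row) || evaluate(right, row)
  case Not(child)           => !evaluate(child, row)
  case _                    => true // unhandled filters are not pruned here
}
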
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
index a692ef51b31ed..5020689f7a105 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.sources
import scala.language.implicitConversions
import org.apache.spark.Logging
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.AbstractSparkSQLParser
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Row}
import org.apache.spark.sql.execution.RunnableCommand
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -32,7 +32,8 @@ import org.apache.spark.util.Utils
/**
* A parser for foreign DDL commands.
*/
-private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
+private[sql] class DDLParser(
+ parseQuery: String => LogicalPlan) extends AbstractSparkSQLParser with Logging {
def apply(input: String, exceptionOnError: Boolean): Option[LogicalPlan] = {
try {
@@ -66,6 +67,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
protected val EXTENDED = Keyword("EXTENDED")
protected val AS = Keyword("AS")
protected val COMMENT = Keyword("COMMENT")
+ protected val REFRESH = Keyword("REFRESH")
// Data types.
protected val STRING = Keyword("STRING")
@@ -85,7 +87,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
protected val MAP = Keyword("MAP")
protected val STRUCT = Keyword("STRUCT")
- protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable
+ protected lazy val ddl: Parser[LogicalPlan] = createTable | describeTable | refreshTable
protected def start: Parser[LogicalPlan] = ddl
@@ -104,26 +106,37 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
* AS SELECT ...
*/
protected lazy val createTable: Parser[LogicalPlan] =
- (
- (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident
- ~ (tableCols).? ~ (USING ~> className) ~ (OPTIONS ~> options) ~ (AS ~> restInput).? ^^ {
+ // TODO: Support database.table.
+ (CREATE ~> TEMPORARY.? <~ TABLE) ~ (IF ~> NOT <~ EXISTS).? ~ ident ~
+ tableCols.? ~ (USING ~> className) ~ (OPTIONS ~> options).? ~ (AS ~> restInput).? ^^ {
case temp ~ allowExisting ~ tableName ~ columns ~ provider ~ opts ~ query =>
if (temp.isDefined && allowExisting.isDefined) {
throw new DDLException(
"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.")
}
+ val options = opts.getOrElse(Map.empty[String, String])
if (query.isDefined) {
if (columns.isDefined) {
throw new DDLException(
"a CREATE TABLE AS SELECT statement does not allow column definitions.")
}
+ // When the IF NOT EXISTS clause appears in the query, the save mode is SaveMode.Ignore.
+ val mode = if (allowExisting.isDefined) {
+ SaveMode.Ignore
+ } else if (temp.isDefined) {
+ SaveMode.Overwrite
+ } else {
+ SaveMode.ErrorIfExists
+ }
+
+ val queryPlan = parseQuery(query.get)
CreateTableUsingAsSelect(tableName,
provider,
temp.isDefined,
- opts,
- allowExisting.isDefined,
- query.get)
+ mode,
+ options,
+ queryPlan)
} else {
val userSpecifiedSchema = columns.flatMap(fields => Some(StructType(fields)))
CreateTableUsing(
@@ -131,17 +144,17 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
userSpecifiedSchema,
provider,
temp.isDefined,
- opts,
- allowExisting.isDefined)
+ options,
+ allowExisting.isDefined,
+ managedIfNoPath = false)
}
- }
- )
+ }
protected lazy val tableCols: Parser[Seq[StructField]] = "(" ~> repsep(column, ",") <~ ")"
/*
* describe [extended] table avroTable
- * This will display all columns of table `avroTable` includes column_name,column_type,nullable
+ * This will display all columns of table `avroTable`, including column_name, column_type and comment
*/
protected lazy val describeTable: Parser[LogicalPlan] =
(DESCRIBE ~> opt(EXTENDED)) ~ (ident <~ ".").? ~ ident ^^ {
@@ -155,6 +168,12 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
DescribeCommand(UnresolvedRelation(tblIdentifier, None), e.isDefined)
}
+ protected lazy val refreshTable: Parser[LogicalPlan] =
+ REFRESH ~> TABLE ~> (ident <~ ".").? ~ ident ^^ {
+ case maybeDatabaseName ~ tableName =>
+ RefreshTable(maybeDatabaseName.getOrElse("default"), tableName)
+ }
+
protected lazy val options: Parser[Map[String, String]] =
"(" ~> repsep(pair, ",") <~ ")" ^^ { case s: Seq[(String, String)] => s.toMap }
@@ -166,10 +185,10 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm =>
val meta = cm match {
case Some(comment) =>
- new MetadataBuilder().putString(COMMENT.str.toLowerCase(), comment).build()
+ new MetadataBuilder().putString(COMMENT.str.toLowerCase, comment).build()
case None => Metadata.empty
}
- StructField(columnName, typ, true, meta)
+ StructField(columnName, typ, nullable = true, meta)
}
protected lazy val primitiveType: Parser[DataType] =
@@ -223,64 +242,73 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging {
primitiveType
}
-object ResolvedDataSource {
- def apply(
- sqlContext: SQLContext,
- userSpecifiedSchema: Option[StructType],
- provider: String,
- options: Map[String, String]): ResolvedDataSource = {
+private[sql] object ResolvedDataSource {
+
+ private val builtinSources = Map(
+ "jdbc" -> classOf[org.apache.spark.sql.jdbc.DefaultSource],
+ "json" -> classOf[org.apache.spark.sql.json.DefaultSource],
+ "parquet" -> classOf[org.apache.spark.sql.parquet.DefaultSource]
+ )
+
+ /** Given a provider name, look up the data source class definition. */
+ def lookupDataSource(provider: String): Class[_] = {
+ if (builtinSources.contains(provider)) {
+ return builtinSources(provider)
+ }
+
val loader = Utils.getContextOrSparkClassLoader
- val clazz: Class[_] = try loader.loadClass(provider) catch {
+ try {
+ loader.loadClass(provider)
+ } catch {
case cnf: java.lang.ClassNotFoundException =>
- try loader.loadClass(provider + ".DefaultSource") catch {
+ try {
+ loader.loadClass(provider + ".DefaultSource")
+ } catch {
case cnf: java.lang.ClassNotFoundException =>
sys.error(s"Failed to load class for data source: $provider")
}
}
+ }
+ /** Create a [[ResolvedDataSource]] for reading data in. */
+ def apply(
+ sqlContext: SQLContext,
+ userSpecifiedSchema: Option[StructType],
+ provider: String,
+ options: Map[String, String]): ResolvedDataSource = {
+ val clazz: Class[_] = lookupDataSource(provider)
val relation = userSpecifiedSchema match {
- case Some(schema: StructType) => {
- clazz.newInstance match {
- case dataSource: SchemaRelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
- case dataSource: org.apache.spark.sql.sources.RelationProvider =>
- sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
- }
+ case Some(schema: StructType) => clazz.newInstance() match {
+ case dataSource: SchemaRelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options), schema)
+ case dataSource: org.apache.spark.sql.sources.RelationProvider =>
+ sys.error(s"${clazz.getCanonicalName} does not allow user-specified schemas.")
}
- case None => {
- clazz.newInstance match {
- case dataSource: RelationProvider =>
- dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
- case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
- sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
- }
+
+ case None => clazz.newInstance() match {
+ case dataSource: RelationProvider =>
+ dataSource.createRelation(sqlContext, new CaseInsensitiveMap(options))
+ case dataSource: org.apache.spark.sql.sources.SchemaRelationProvider =>
+ sys.error(s"A schema needs to be specified when using ${clazz.getCanonicalName}.")
}
}
-
new ResolvedDataSource(clazz, relation)
}
+ /** Create a [[ResolvedDataSource]] for saving the content of the given [[DataFrame]]. */
def apply(
sqlContext: SQLContext,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
data: DataFrame): ResolvedDataSource = {
- val loader = Utils.getContextOrSparkClassLoader
- val clazz: Class[_] = try loader.loadClass(provider) catch {
- case cnf: java.lang.ClassNotFoundException =>
- try loader.loadClass(provider + ".DefaultSource") catch {
- case cnf: java.lang.ClassNotFoundException =>
- sys.error(s"Failed to load class for data source: $provider")
- }
- }
-
- val relation = clazz.newInstance match {
+ val clazz: Class[_] = lookupDataSource(provider)
+ val relation = clazz.newInstance() match {
case dataSource: CreatableRelationProvider =>
- dataSource.createRelation(sqlContext, options, data)
+ dataSource.createRelation(sqlContext, mode, options, data)
case _ =>
sys.error(s"${clazz.getCanonicalName} does not allow create table as select.")
}
-
new ResolvedDataSource(clazz, relation)
}
}
@@ -298,39 +326,47 @@ private[sql] case class DescribeCommand(
isExtended: Boolean) extends Command {
override val output = Seq(
// Column names are based on Hive.
- AttributeReference("col_name", StringType, nullable = false,
+ AttributeReference("col_name", StringType, nullable = false,
new MetadataBuilder().putString("comment", "name of the column").build())(),
- AttributeReference("data_type", StringType, nullable = false,
+ AttributeReference("data_type", StringType, nullable = false,
new MetadataBuilder().putString("comment", "data type of the column").build())(),
- AttributeReference("comment", StringType, nullable = false,
+ AttributeReference("comment", StringType, nullable = false,
new MetadataBuilder().putString("comment", "comment of the column").build())())
}
+/**
+ * Used to represent the operation of creating a table using a data source.
+ * @param allowExisting If it is true, we will do nothing when the table already exists.
+ *                      If it is false, an exception will be thrown.
+ */
private[sql] case class CreateTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
temporary: Boolean,
options: Map[String, String],
- allowExisting: Boolean) extends Command
-
-private[sql] case class CreateTableUsingAsSelect(
- tableName: String,
- provider: String,
- temporary: Boolean,
- options: Map[String, String],
allowExisting: Boolean,
- query: String) extends Command
+ managedIfNoPath: Boolean) extends Command
-private[sql] case class CreateTableUsingAsLogicalPlan(
+/**
+ * A node used to support CTAS statements and saveAsTable for the data source API.
+ * This node is a [[UnaryNode]] instead of a [[Command]] because we want the analyzer
+ * to analyze the logical plan that will be used to populate the table,
+ * so that [[PreWriteCheck]] can detect cases that are not allowed.
+ */
+private[sql] case class CreateTableUsingAsSelect(
tableName: String,
provider: String,
temporary: Boolean,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
- query: LogicalPlan) extends Command
+ child: LogicalPlan) extends UnaryNode {
+ override def output = Seq.empty[Attribute]
+ // TODO: Override resolved after we support databaseName.
+ // override lazy val resolved = databaseName != None && childrenResolved
+}
-private [sql] case class CreateTempTableUsing(
+private[sql] case class CreateTempTableUsing(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
@@ -338,32 +374,42 @@ private [sql] case class CreateTempTableUsing(
def run(sqlContext: SQLContext) = {
val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options)
- sqlContext.registerRDDAsTable(
+ sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
-private [sql] case class CreateTempTableUsingAsSelect(
+private[sql] case class CreateTempTableUsingAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
query: LogicalPlan) extends RunnableCommand {
def run(sqlContext: SQLContext) = {
val df = DataFrame(sqlContext, query)
- val resolved = ResolvedDataSource(sqlContext, provider, options, df)
- sqlContext.registerRDDAsTable(
+ val resolved = ResolvedDataSource(sqlContext, provider, mode, options, df)
+ sqlContext.registerDataFrameAsTable(
DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName)
Seq.empty
}
}
+private[sql] case class RefreshTable(databaseName: String, tableName: String)
+ extends RunnableCommand {
+
+ override def run(sqlContext: SQLContext): Seq[Row] = {
+ sqlContext.catalog.refreshTable(databaseName, tableName)
+ Seq.empty[Row]
+ }
+}
+
/**
* Builds a map in which keys are case insensitive
*/
-protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
+protected[sql] class CaseInsensitiveMap(map: Map[String, String]) extends Map[String, String]
with Serializable {
val baseMap = map.map(kv => kv.copy(_1 = kv._1.toLowerCase))
@@ -375,11 +421,10 @@ protected class CaseInsensitiveMap(map: Map[String, String]) extends Map[String,
override def iterator: Iterator[(String, String)] = baseMap.iterator
- override def -(key: String): Map[String, String] = baseMap - key.toLowerCase()
+ override def -(key: String): Map[String, String] = baseMap - key.toLowerCase
}
/**
* The exception thrown from the DDL parser.
- * @param message
*/
protected[sql] class DDLException(message: String) extends Exception(message)
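
A brief usage sketch of the DDL forms handled by the updated parser; the table name, path, and database name below are made up for illustration. As implemented in createTable above, IF NOT EXISTS maps to SaveMode.Ignore, TEMPORARY to SaveMode.Overwrite, and everything else to SaveMode.ErrorIfExists.

// Assumes an existing SQLContext named sqlContext and a registered table `src`.
sqlContext.sql(
  """CREATE TEMPORARY TABLE jsonCopy
    |USING org.apache.spark.sql.json
    |OPTIONS (path '/tmp/json-copy')
    |AS SELECT * FROM src""".stripMargin)

// Ask the catalog to re-read cached metadata for a (non-temporary) data source table.
sqlContext.sql("REFRESH TABLE some_db.someTable")
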
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
index 4a9fefc12b9ad..1e4505e36d2f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala
@@ -25,3 +25,8 @@ case class GreaterThanOrEqual(attribute: String, value: Any) extends Filter
case class LessThan(attribute: String, value: Any) extends Filter
case class LessThanOrEqual(attribute: String, value: Any) extends Filter
case class In(attribute: String, values: Array[Any]) extends Filter
+case class IsNull(attribute: String) extends Filter
+case class IsNotNull(attribute: String) extends Filter
+case class And(left: Filter, right: Filter) extends Filter
+case class Or(left: Filter, right: Filter) extends Filter
+case class Not(child: Filter) extends Filter
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
index 5eecc303ef72b..0c4b706eeebae 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.annotation.{Experimental, DeveloperApi}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{SaveMode, DataFrame, Row, SQLContext}
import org.apache.spark.sql.catalyst.expressions.{Expression, Attribute}
import org.apache.spark.sql.types.StructType
@@ -79,8 +79,27 @@ trait SchemaRelationProvider {
@DeveloperApi
trait CreatableRelationProvider {
+ /**
+ * Creates a relation with the given parameters based on the contents of the given
+ * DataFrame. The mode specifies the expected behavior of createRelation when
+ * data already exists.
+ * Right now, there are three modes: Append, Overwrite, and ErrorIfExists.
+ * Append mode means that when saving a DataFrame to a data source, if data already exists,
+ * contents of the DataFrame are expected to be appended to existing data.
+ * Overwrite mode means that when saving a DataFrame to a data source, if data already exists,
+ * existing data is expected to be overwritten by the contents of the DataFrame.
+ * ErrorIfExists mode means that when saving a DataFrame to a data source,
+ * if data already exists, an exception is expected to be thrown.
+ *
+ * @param sqlContext the active SQLContext
+ * @param mode the expected behavior when data already exists at the destination
+ * @param parameters user-specified options for this data source
+ * @param data the DataFrame whose contents should be written out
+ * @return the relation created over the written data
+ */
def createRelation(
sqlContext: SQLContext,
+ mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation
}
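
To make the SaveMode contract above concrete, here is a hedged, self-contained sketch of a provider that writes a DataFrame as JSON text files on the local filesystem. The class name, the "path" option, and the storage format are illustration-only assumptions, not part of this patch, and Append is simplified (a real source would write into a fresh subdirectory).

import java.io.File

import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}
import org.apache.spark.sql.types.StructType

// Hypothetical provider, for illustration only.
class ExampleJsonProvider extends CreatableRelationProvider {
  override def createRelation(
      ctx: SQLContext,
      mode: SaveMode,
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    val path = parameters.getOrElse("path", sys.error("'path' must be specified"))
    val dir = new File(path)

    def deleteRecursively(f: File): Unit = {
      Option(f.listFiles()).toSeq.flatten.foreach(deleteRecursively)
      f.delete()
    }

    mode match {
      case SaveMode.ErrorIfExists if dir.exists() =>
        sys.error(s"path $path already exists.")
      case SaveMode.Ignore if dir.exists() =>
        // Existing data is left untouched and no write happens.
      case SaveMode.Overwrite =>
        deleteRecursively(dir)
        data.toJSON.saveAsTextFile(path)
      case _ =>
        // Append, or a first write under ErrorIfExists / Ignore.
        data.toJSON.saveAsTextFile(path)
    }

    new BaseRelation {
      override val sqlContext: SQLContext = ctx
      override val schema: StructType = data.schema
    }
  }
}
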
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
index 4ed22d363da5b..8440581074877 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/rules.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.sources
+import org.apache.spark.sql.{SaveMode, AnalysisException}
+import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, Catalog}
import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Alias}
+import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType
@@ -26,11 +29,9 @@ import org.apache.spark.sql.types.DataType
* A rule to do pre-insert data type casting and field renaming. Before we insert into
* an [[InsertableRelation]], we will use this rule to make sure that
* the columns to be inserted have the correct data type and fields have the correct names.
- * @param resolver The resolver used by the Analyzer.
*/
private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = {
- plan.transform {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Wait until children are resolved.
case p: LogicalPlan if !p.childrenResolved => p
@@ -46,7 +47,6 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
}
castAndRenameChildOutput(i, l.output, child)
}
- }
}
/** If necessary, cast data types and rename fields to the expected types and names. */
@@ -74,3 +74,65 @@ private[sql] object PreInsertCastAndRename extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * A rule to do various checks before inserting into or writing to a data source table.
+ */
+private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => Unit) {
+ def failAnalysis(msg: String) = { throw new AnalysisException(msg) }
+
+ def apply(plan: LogicalPlan): Unit = {
+ plan.foreach {
+ case i @ logical.InsertIntoTable(
+ l @ LogicalRelation(t: InsertableRelation), partition, query, overwrite) =>
+ // Right now, we do not support insert into a data source table with partition specs.
+ if (partition.nonEmpty) {
+ failAnalysis(s"Insert into a partition is not allowed because $l is not partitioned.")
+ } else {
+ // Get all input data source relations of the query.
+ val srcRelations = query.collect {
+ case LogicalRelation(src: BaseRelation) => src
+ }
+ if (srcRelations.contains(t)) {
+ failAnalysis(
+ "Cannot insert overwrite into table that is also being read from.")
+ } else {
+ // OK
+ }
+ }
+
+ case i @ logical.InsertIntoTable(
+ l: LogicalRelation, partition, query, overwrite) if !l.isInstanceOf[InsertableRelation] =>
+ // The relation in l is not an InsertableRelation.
+ failAnalysis(s"$l does not allow insertion.")
+
+ case CreateTableUsingAsSelect(tableName, _, _, SaveMode.Overwrite, _, query) =>
+ // When the SaveMode is Overwrite, we need to check if the table is an input table of
+ // the query. If so, we will throw an AnalysisException to let users know it is not allowed.
+ if (catalog.tableExists(Seq(tableName))) {
+ // Need to remove SubQuery operator.
+ EliminateSubQueries(catalog.lookupRelation(Seq(tableName))) match {
+ // Only do the check if the table is a data source table
+ // (the relation is a BaseRelation).
+ case l @ LogicalRelation(dest: BaseRelation) =>
+ // Get all input data source relations of the query.
+ val srcRelations = query.collect {
+ case LogicalRelation(src: BaseRelation) => src
+ }
+ if (srcRelations.contains(dest)) {
+ failAnalysis(
+ s"Cannot overwrite table $tableName that is also being read from.")
+ } else {
+ // OK
+ }
+
+ case _ => // OK
+ }
+ } else {
+ // OK
+ }
+
+ case _ => // OK
+ }
+ }
+}
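
For intuition, a query that both reads and overwrites the same data source table now fails at analysis time instead of misbehaving at runtime. The table and column names below are hypothetical; the exception message is quoted from the check above.

// Assumes a data source table named jsonTable is already registered.
sqlContext.sql("INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable")
// throws org.apache.spark.sql.AnalysisException:
//   Cannot insert overwrite into table that is also being read from.
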
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
deleted file mode 100644
index 639436368c4a3..0000000000000
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java
+++ /dev/null
@@ -1,120 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.api.java;
-
-import com.google.common.collect.ImmutableMap;
-
-import org.apache.spark.sql.Column;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.types.DataTypes;
-
-import static org.apache.spark.sql.Dsl.*;
-
-/**
- * This test doesn't actually run anything. It is here to check the API compatibility for Java.
- */
-public class JavaDsl {
-
- public static void testDataFrame(final DataFrame df) {
- DataFrame df1 = df.select("colA");
- df1 = df.select("colA", "colB");
-
- df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1));
-
- df1 = df.filter(col("colA"));
-
- java.util.Map aggExprs = ImmutableMap.builder()
- .put("colA", "sum")
- .put("colB", "avg")
- .build();
-
- df1 = df.agg(aggExprs);
-
- df1 = df.groupBy("groupCol").agg(aggExprs);
-
- df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer");
-
- df.orderBy("colA");
- df.orderBy("colA", "colB", "colC");
- df.orderBy(col("colA").desc());
- df.orderBy(col("colA").desc(), col("colB").asc());
-
- df.sort("colA");
- df.sort("colA", "colB", "colC");
- df.sort(col("colA").desc());
- df.sort(col("colA").desc(), col("colB").asc());
-
- df.as("b");
-
- df.limit(5);
-
- df.unionAll(df1);
- df.intersect(df1);
- df.except(df1);
-
- df.sample(true, 0.1, 234);
-
- df.head();
- df.head(5);
- df.first();
- df.count();
- }
-
- public static void testColumn(final Column c) {
- c.asc();
- c.desc();
-
- c.endsWith("abcd");
- c.startsWith("afgasdf");
-
- c.like("asdf%");
- c.rlike("wef%asdf");
-
- c.as("newcol");
-
- c.cast("int");
- c.cast(DataTypes.IntegerType);
- }
-
- public static void testDsl() {
- // Creating a column.
- Column c = col("abcd");
- Column c1 = column("abcd");
-
- // Literals
- Column l1 = lit(1);
- Column l2 = lit(1.0);
- Column l3 = lit("abcd");
-
- // Functions
- Column a = upper(c);
- a = lower(c);
- a = sqrt(c);
- a = abs(c);
-
- // Aggregates
- a = min(c);
- a = max(c);
- a = sum(c);
- a = sumDistinct(c);
- a = countDistinct(c, a);
- a = avg(c);
- a = first(c);
- a = last(c);
- }
-}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java b/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
deleted file mode 100644
index 80bd74f5b5525..0000000000000
--- a/sql/core/src/test/java/org/apache/spark/sql/jdbc/JavaJDBCTest.java
+++ /dev/null
@@ -1,102 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.jdbc;
-
-import org.junit.*;
-import static org.junit.Assert.*;
-import java.sql.Connection;
-import java.sql.DriverManager;
-
-import org.apache.spark.SparkEnv;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.Row;
-import org.apache.spark.sql.api.java.*;
-import org.apache.spark.sql.test.TestSQLContext$;
-
-public class JavaJDBCTest {
- static String url = "jdbc:h2:mem:testdb1";
-
- static Connection conn = null;
-
- // This variable will always be null if TestSQLContext is intact when running
- // these tests. Some Java tests do not play nicely with others, however;
- // they create a SparkContext of their own at startup and stop it at exit.
- // This renders TestSQLContext inoperable, meaning we have to do the same
- // thing. If this variable is nonnull, that means we allocated a
- // SparkContext of our own and that we need to stop it at teardown.
- static JavaSparkContext localSparkContext = null;
-
- static SQLContext sql = TestSQLContext$.MODULE$;
-
- @Before
- public void beforeTest() throws Exception {
- if (SparkEnv.get() == null) { // A previous test destroyed TestSQLContext.
- localSparkContext = new JavaSparkContext("local", "JavaAPISuite");
- sql = new SQLContext(localSparkContext);
- }
- Class.forName("org.h2.Driver");
- conn = DriverManager.getConnection(url);
- conn.prepareStatement("create schema test").executeUpdate();
- conn.prepareStatement("create table test.people (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate();
- conn.prepareStatement("insert into test.people values ('fred', 1)").executeUpdate();
- conn.prepareStatement("insert into test.people values ('mary', 2)").executeUpdate();
- conn.prepareStatement("insert into test.people values ('joe', 3)").executeUpdate();
- conn.commit();
- }
-
- @After
- public void afterTest() throws Exception {
- if (localSparkContext != null) {
- localSparkContext.stop();
- localSparkContext = null;
- }
- try {
- conn.close();
- } finally {
- conn = null;
- }
- }
-
- @Test
- public void basicTest() throws Exception {
- DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
- Row[] rows = rdd.collect();
- assertEquals(rows.length, 3);
- }
-
- @Test
- public void partitioningTest() throws Exception {
- String[] parts = new String[2];
- parts[0] = "THEID < 2";
- parts[1] = "THEID = 2"; // Deliberately forget about one of them.
- DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE", parts);
- Row[] rows = rdd.collect();
- assertEquals(rows.length, 2);
- }
-
- @Test
- public void writeTest() throws Exception {
- DataFrame rdd = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLE");
- JDBCUtils.createJDBCTable(rdd, url, "TEST.PEOPLECOPY", false);
- DataFrame rdd2 = JDBCUtils.jdbcRDD(sql, url, "TEST.PEOPLECOPY");
- Row[] rows = rdd2.collect();
- assertEquals(rows.length, 3);
- }
-}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
similarity index 90%
rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index 2e6e977fdc752..c344a9b095c52 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;
import java.io.Serializable;
import java.util.ArrayList;
@@ -39,18 +39,18 @@
// see http://stackoverflow.com/questions/758570/.
public class JavaApplySchemaSuite implements Serializable {
private transient JavaSparkContext javaCtx;
- private transient SQLContext javaSqlCtx;
+ private transient SQLContext sqlContext;
@Before
public void setUp() {
- javaSqlCtx = TestSQLContext$.MODULE$;
- javaCtx = new JavaSparkContext(javaSqlCtx.sparkContext());
+ sqlContext = TestSQLContext$.MODULE$;
+ javaCtx = new JavaSparkContext(sqlContext.sparkContext());
}
@After
public void tearDown() {
javaCtx = null;
- javaSqlCtx = null;
+ sqlContext = null;
}
public static class Person implements Serializable {
@@ -98,9 +98,9 @@ public Row call(Person person) throws Exception {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = javaSqlCtx.applySchema(rowRDD, schema);
+ DataFrame df = sqlContext.applySchema(rowRDD, schema);
df.registerTempTable("people");
- Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect();
+ Row[] actual = sqlContext.sql("SELECT * FROM people").collect();
List expected = new ArrayList(2);
expected.add(RowFactory.create("Michael", 29));
@@ -109,8 +109,6 @@ public Row call(Person person) throws Exception {
Assert.assertEquals(expected, Arrays.asList(actual));
}
-
-
@Test
public void dataFrameRDDOperations() {
List personList = new ArrayList(2);
@@ -135,9 +133,9 @@ public Row call(Person person) throws Exception {
fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false));
StructType schema = DataTypes.createStructType(fields);
- DataFrame df = javaSqlCtx.applySchema(rowRDD, schema);
+ DataFrame df = sqlContext.applySchema(rowRDD, schema);
df.registerTempTable("people");
- List actual = javaSqlCtx.sql("SELECT * FROM people").toJavaRDD().map(new Function() {
+ List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() {
public String call(Row row) {
return row.getString(0) + "_" + row.get(1).toString();
@@ -164,7 +162,7 @@ public void applySchemaToJSON() {
fields.add(DataTypes.createStructField("bigInteger", DataTypes.createDecimalType(), true));
fields.add(DataTypes.createStructField("boolean", DataTypes.BooleanType, true));
fields.add(DataTypes.createStructField("double", DataTypes.DoubleType, true));
- fields.add(DataTypes.createStructField("integer", DataTypes.IntegerType, true));
+ fields.add(DataTypes.createStructField("integer", DataTypes.LongType, true));
fields.add(DataTypes.createStructField("long", DataTypes.LongType, true));
fields.add(DataTypes.createStructField("null", DataTypes.StringType, true));
fields.add(DataTypes.createStructField("string", DataTypes.StringType, true));
@@ -189,18 +187,18 @@ public void applySchemaToJSON() {
null,
"this is another simple string."));
- DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD);
+ DataFrame df1 = sqlContext.jsonRDD(jsonRDD);
StructType actualSchema1 = df1.schema();
Assert.assertEquals(expectedSchema, actualSchema1);
df1.registerTempTable("jsonTable1");
- List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList();
+ List actual1 = sqlContext.sql("select * from jsonTable1").collectAsList();
Assert.assertEquals(expectedResult, actual1);
- DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD, expectedSchema);
+ DataFrame df2 = sqlContext.jsonRDD(jsonRDD, expectedSchema);
StructType actualSchema2 = df2.schema();
Assert.assertEquals(expectedSchema, actualSchema2);
df2.registerTempTable("jsonTable2");
- List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList();
+ List actual2 = sqlContext.sql("select * from jsonTable2").collectAsList();
Assert.assertEquals(expectedResult, actual2);
}
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
new file mode 100644
index 0000000000000..2d586f784ac5a
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSQLContext$;
+import static org.apache.spark.sql.functions.*;
+
+
+public class JavaDataFrameSuite {
+ private transient SQLContext context;
+
+ @Before
+ public void setUp() {
+ // Trigger static initializer of TestData
+ TestData$.MODULE$.testData();
+ context = TestSQLContext$.MODULE$;
+ }
+
+ @After
+ public void tearDown() {
+ context = null;
+ }
+
+ @Test
+ public void testExecution() {
+ DataFrame df = context.table("testData").filter("key = 1");
+ Assert.assertEquals(df.select("key").collect()[0].get(0), 1);
+ }
+
+ /**
+ * See SPARK-5904. Abstract vararg methods defined in Scala do not work in Java.
+ */
+ @Test
+ public void testVarargMethods() {
+ DataFrame df = context.table("testData");
+
+ df.toDF("key1", "value1");
+
+ df.select("key", "value");
+ df.select(col("key"), col("value"));
+ df.selectExpr("key", "value + 1");
+
+ df.sort("key", "value");
+ df.sort(col("key"), col("value"));
+ df.orderBy("key", "value");
+ df.orderBy(col("key"), col("value"));
+
+ df.groupBy("key", "value").agg(col("key"), col("value"), sum("value"));
+ df.groupBy(col("key"), col("value")).agg(col("key"), col("value"), sum("value"));
+ df.agg(first("key"), sum("value"));
+
+ df.groupBy().avg("key");
+ df.groupBy().mean("key");
+ df.groupBy().max("key");
+ df.groupBy().min("key");
+ df.groupBy().sum("key");
+
+ // Varargs in column expressions
+ df.groupBy().agg(countDistinct("key", "value"));
+ df.groupBy().agg(countDistinct(col("key"), col("value")));
+ df.select(coalesce(col("key")));
+ }
+
+ @Ignore
+ public void testShow() {
+ // This test case is intentionally ignored; it exists only to make sure the code compiles.
+ DataFrame df = context.table("testData");
+ df.show();
+ df.show(1000);
+ }
+}
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
similarity index 99%
rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
index fbfcd3f59d910..4ce1d1dddb26a 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaRowSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaRowSuite.java
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;
import java.math.BigDecimal;
import java.sql.Date;
diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
similarity index 94%
rename from sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
rename to sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index a21a15409080c..79d92734ff375 100644
--- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -15,24 +15,26 @@
* limitations under the License.
*/
-package org.apache.spark.sql.api.java;
+package test.org.apache.spark.sql;
import java.io.Serializable;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
-import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.api.java.UDF1;
+import org.apache.spark.sql.api.java.UDF2;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
// serialized, as an alternative to converting these anonymous classes to static inner classes;
// see http://stackoverflow.com/questions/758570/.
-public class JavaAPISuite implements Serializable {
+public class JavaUDFSuite implements Serializable {
private transient JavaSparkContext sc;
private transient SQLContext sqlContext;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
new file mode 100644
index 0000000000000..b76f7d421f643
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.*;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.test.TestSQLContext$;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaSaveLoadSuite {
+
+ private transient JavaSparkContext sc;
+ private transient SQLContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestSQLContext$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+
+ List jsonObjects = new ArrayList(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @Test
+ public void saveAndLoad() {
+ Map options = new HashMap();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ }
+
+ @Test
+ public void saveAndLoadWithSchema() {
+ Map options = new HashMap();
+ options.put("path", path.toString());
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, options);
+
+ List fields = new ArrayList();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF = sqlContext.load("org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(loadedDF, sqlContext.sql("SELECT b FROM jsonTable").collectAsList());
+ }
+}
diff --git a/sql/core/src/test/resources/log4j.properties b/sql/core/src/test/resources/log4j.properties
index fbed0a782dd3e..28e90b9520b2c 100644
--- a/sql/core/src/test/resources/log4j.properties
+++ b/sql/core/src/test/resources/log4j.properties
@@ -39,6 +39,9 @@ log4j.appender.FA.Threshold = INFO
log4j.additivity.parquet.hadoop.ParquetRecordReader=false
log4j.logger.parquet.hadoop.ParquetRecordReader=OFF
+log4j.additivity.parquet.hadoop.ParquetOutputCommitter=false
+log4j.logger.parquet.hadoop.ParquetOutputCommitter=OFF
+
log4j.additivity.org.apache.hadoop.hive.serde2.lazy.LazyStruct=false
log4j.logger.org.apache.hadoop.hive.serde2.lazy.LazyStruct=OFF
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index 1318750a4a3b0..c240f2be955ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,24 +18,21 @@
package org.apache.spark.sql
import scala.concurrent.duration._
-import scala.language.implicitConversions
-import scala.language.postfixOps
+import scala.language.{implicitConversions, postfixOps}
import org.scalatest.concurrent.Eventually._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.storage.{StorageLevel, RDDBlockId}
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.storage.{RDDBlockId, StorageLevel}
case class BigData(s: String)
class CachedTableSuite extends QueryTest {
TestData // Load test tables.
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
def rddIdOf(tableName: String): Int = {
val executedPlan = table(tableName).queryExecution.executedPlan
executedPlan.collect {
@@ -60,15 +57,15 @@ class CachedTableSuite extends QueryTest {
test("unpersist an uncached table will not raise exception") {
assert(None == cacheManager.lookupCachedData(testData))
- testData.unpersist(true)
+ testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
- testData.unpersist(false)
+ testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
testData.persist()
assert(None != cacheManager.lookupCachedData(testData))
- testData.unpersist(true)
+ testData.unpersist(blocking = true)
assert(None == cacheManager.lookupCachedData(testData))
- testData.unpersist(false)
+ testData.unpersist(blocking = false)
assert(None == cacheManager.lookupCachedData(testData))
}
@@ -95,7 +92,7 @@ class CachedTableSuite extends QueryTest {
test("too big for memory") {
val data = "*" * 10000
- sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).registerTempTable("bigData")
+ sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF().registerTempTable("bigData")
table("bigData").persist(StorageLevel.MEMORY_AND_DISK)
assert(table("bigData").count() === 200000L)
table("bigData").unpersist(blocking = true)
@@ -283,4 +280,20 @@ class CachedTableSuite extends QueryTest {
assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found"))
assert(!isCached("t2"))
}
+
+ test("Clear all cache") {
+ sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
+ sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
+ cacheTable("t1")
+ cacheTable("t2")
+ clearCache()
+ assert(cacheManager.isEmpty)
+
+ sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1")
+ sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2")
+ cacheTable("t1")
+ cacheTable("t2")
+ sql("Clear CACHE")
+ assert(cacheManager.isEmpty)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index fa4cdecbcb340..37c02aaa5460b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -17,8 +17,9 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType}
@@ -27,43 +28,10 @@ class ColumnExpressionSuite extends QueryTest {
// TODO: Add test cases for bitwise operations.
- test("computability check") {
- def shouldBeComputable(c: Column): Unit = assert(c.isComputable === true)
-
- def shouldNotBeComputable(c: Column): Unit = {
- assert(c.isComputable === false)
- intercept[UnsupportedOperationException] { c.head() }
- }
-
- shouldBeComputable(testData2("a"))
- shouldBeComputable(testData2("b"))
-
- shouldBeComputable(testData2("a") + testData2("b"))
- shouldBeComputable(testData2("a") + testData2("b") + 1)
-
- shouldBeComputable(-testData2("a"))
- shouldBeComputable(!testData2("a"))
-
- shouldBeComputable(testData2.select(($"a" + 1).as("c"))("c") + testData2("b"))
- shouldBeComputable(
- testData2.select(($"a" + 1).as("c"))("c") + testData2.select(($"b" / 2).as("d"))("d"))
- shouldBeComputable(
- testData2.select(($"a" + 1).as("c")).select(($"c" + 2).as("d"))("d") + testData2("b"))
-
- // Literals and unresolved columns should not be computable.
- shouldNotBeComputable(col("1"))
- shouldNotBeComputable(col("1") + 2)
- shouldNotBeComputable(lit(100))
- shouldNotBeComputable(lit(100) + 10)
- shouldNotBeComputable(-col("1"))
- shouldNotBeComputable(!col("1"))
-
- // Getting data from different frames should not be computable.
- shouldNotBeComputable(testData2("a") + testData("key"))
- shouldNotBeComputable(testData2("a") + 1 + testData("key"))
-
- // Aggregate functions alone should not be computable.
- shouldNotBeComputable(sum(testData2("a")))
+ test("collect on column produced by a binary operator") {
+ val df = Seq((1, 2, 3)).toDF("a", "b", "c")
+ checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
+ checkAnswer(df.select(df("a") + df("b").as("c")), Seq(Row(3)))
}
test("star") {
@@ -71,8 +39,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("star qualified by data frame object") {
- // This is not yet supported.
- val df = testData.toDataFrame
+ val df = testData.toDF
val goldAnswer = df.collect().toSeq
checkAnswer(df.select(df("*")), goldAnswer)
@@ -149,13 +116,13 @@ class ColumnExpressionSuite extends QueryTest {
test("isNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNull),
+ nullStrings.toDF.where($"s".isNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
}
test("isNotNull") {
checkAnswer(
- nullStrings.toDataFrame.where($"s".isNotNull),
+ nullStrings.toDF.where($"s".isNotNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
}
@@ -180,7 +147,7 @@ class ColumnExpressionSuite extends QueryTest {
}
test("!==") {
- val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(1, 1) ::
Row(1, 2) ::
Row(1, null) ::
@@ -240,7 +207,7 @@ class ColumnExpressionSuite extends QueryTest {
testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1)))
}
- val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize(
+ val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize(
Row(false, false) ::
Row(false, true) ::
Row(true, false) ::
@@ -342,4 +309,8 @@ class ColumnExpressionSuite extends QueryTest {
(1 to 100).map(n => Row(null))
)
}
+
+ test("lift alias out of cast") {
+ assert(col("1234").as("name").cast("int").expr === col("1234").cast("int").as("name").expr)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
new file mode 100644
index 0000000000000..2d2367d6e7292
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -0,0 +1,55 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc}
+import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+
+class DataFrameImplicitsSuite extends QueryTest {
+
+ test("RDD of tuples") {
+ checkAnswer(
+ sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
+ (1 to 10).map(i => Row(i, i.toString)))
+ }
+
+ test("Seq of tuples") {
+ checkAnswer(
+ (1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"),
+ (1 to 10).map(i => Row(i, i.toString)))
+ }
+
+ test("RDD[Int]") {
+ checkAnswer(
+ sc.parallelize(1 to 10).toDF("intCol"),
+ (1 to 10).map(i => Row(i)))
+ }
+
+ test("RDD[Long]") {
+ checkAnswer(
+ sc.parallelize(1L to 10L).toDF("longCol"),
+ (1L to 10L).map(i => Row(i)))
+ }
+
+ test("RDD[String]") {
+ checkAnswer(
+ sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"),
+ (1 to 10).map(i => Row(i.toString)))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 5aa3db720c886..ff441ef26f9c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -19,11 +19,12 @@ package org.apache.spark.sql
import scala.language.postfixOps
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.logicalPlanToSparkQuery
import org.apache.spark.sql.test.TestSQLContext.implicits._
+import org.apache.spark.sql.test.TestSQLContext.sql
class DataFrameSuite extends QueryTest {
@@ -53,18 +54,88 @@ class DataFrameSuite extends QueryTest {
TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
}
+ test("dataframe toString") {
+ assert(testData.toString === "[key: int, value: string]")
+ assert(testData("key").toString === "key")
+ assert($"test".toString === "test")
+ }
+
+ test("invalid plan toString, debug mode") {
+ val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis
+ TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true")
+
+ // Turn on debug mode so we can see invalid query plans.
+ import org.apache.spark.sql.execution.debug._
+ TestSQLContext.debug()
+
+ val badPlan = testData.select('badColumn)
+
+ assert(badPlan.toString contains badPlan.queryExecution.toString,
+ "toString on bad query plans should include the query execution but was:\n" +
+ badPlan.toString)
+
+ // Set the flag back to original value before this test.
+ TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString)
+ }
+
test("table scan") {
checkAnswer(
testData,
testData.collect().toSeq)
}
+ test("head and take") {
+ assert(testData.take(2) === testData.collect().take(2))
+ assert(testData.head(2) === testData.collect().take(2))
+ assert(testData.head(2).head.schema === testData.schema)
+ }
+
+ test("self join") {
+ val df1 = testData.select(testData("key")).as('df1)
+ val df2 = testData.select(testData("key")).as('df2)
+
+ checkAnswer(
+ df1.join(df2, $"df1.key" === $"df2.key"),
+ sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq)
+ }
+
+ test("simple explode") {
+ val df = Seq(Tuple1("a b c"), Tuple1("d e")).toDF("words")
+
+ checkAnswer(
+ df.explode("words", "word") { word: String => word.split(" ").toSeq }.select('word),
+      Row("a") :: Row("b") :: Row("c") :: Row("d") :: Row("e") :: Nil
+ )
+ }
+
+ test("explode") {
+ val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters")
+ val df2 =
+ df.explode('letters) {
+ case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq
+ }
+
+ checkAnswer(
+ df2
+ .select('_1 as 'letter, 'number)
+ .groupBy('letter)
+ .agg('letter, countDistinct('number)),
+ Row("a", 3) :: Row("b", 2) :: Row("c", 1) :: Nil
+ )
+ }
+
test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),
testData.collect().map(row => Row(math.abs(row.getInt(0)), row.getString(1))).toSeq)
}
+ test("selectExpr with alias") {
+ checkAnswer(
+ testData.selectExpr("key as k").select("k"),
+ testData.select("key").collect().toSeq)
+ }
+
test("filterExpr") {
checkAnswer(
testData.filter("key > 90"),
@@ -77,15 +148,42 @@ class DataFrameSuite extends QueryTest {
testData.select('key).collect().toSeq)
}
- test("agg") {
+ test("groupBy") {
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b")),
- Seq(Row(1,3), Row(2,3), Row(3,3))
+ Seq(Row(1, 3), Row(2, 3), Row(3, 3))
)
checkAnswer(
testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)),
Row(9)
)
+ checkAnswer(
+ testData2.groupBy("a").agg(col("a"), count("*")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("*" -> "count")),
+ Row(1, 2) :: Row(2, 2) :: Row(3, 2) :: Nil
+ )
+ checkAnswer(
+ testData2.groupBy("a").agg(Map("b" -> "sum")),
+ Row(1, 3) :: Row(2, 3) :: Row(3, 3) :: Nil
+ )
+
+ val df1 = Seq(("a", 1, 0, "b"), ("b", 2, 4, "c"), ("a", 2, 3, "d"))
+ .toDF("key", "value1", "value2", "rest")
+
+ checkAnswer(
+ df1.groupBy("key").min(),
+ df1.groupBy("key").min("value1", "value2").collect()
+ )
+ checkAnswer(
+ df1.groupBy("key").min("value2"),
+ Seq(Row("a", 0), Row("b", 4))
+ )
+ }
+
+ test("agg without groups") {
checkAnswer(
testData2.agg(sum('b)),
Row(9)
@@ -141,6 +239,10 @@ class DataFrameSuite extends QueryTest {
testData2.orderBy('a.asc, 'b.asc),
Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2)))
+ checkAnswer(
+ testData2.orderBy(asc("a"), desc("b")),
+ Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
+
checkAnswer(
testData2.orderBy('a.asc, 'b.desc),
Seq(Row(1,2), Row(1,1), Row(2,2), Row(2,1), Row(3,2), Row(3,1)))
@@ -154,20 +256,20 @@ class DataFrameSuite extends QueryTest {
Seq(Row(3,1), Row(3,2), Row(2,1), Row(2,2), Row(1,1), Row(1,2)))
checkAnswer(
- arrayData.orderBy('data.getItem(0).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
+ arrayData.toDF().orderBy('data.getItem(0).asc),
+ arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(0).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
+ arrayData.toDF().orderBy('data.getItem(0).desc),
+ arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).asc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
+ arrayData.toDF().orderBy('data.getItem(1).asc),
+ arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq)
checkAnswer(
- arrayData.orderBy('data.getItem(1).desc),
- arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
+ arrayData.toDF().orderBy('data.getItem(1).desc),
+ arrayData.toDF().collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq)
}
test("limit") {
@@ -176,11 +278,11 @@ class DataFrameSuite extends QueryTest {
testData.take(10).toSeq)
checkAnswer(
- arrayData.limit(1),
+ arrayData.toDF().limit(1),
arrayData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
checkAnswer(
- mapData.limit(1),
+ mapData.toDF().limit(1),
mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq)))
}
@@ -313,8 +415,8 @@ class DataFrameSuite extends QueryTest {
)
}
- test("addColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
+ test("withColumn") {
+ val df = testData.toDF().withColumn("newCol", col("key") + 1)
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
@@ -323,9 +425,9 @@ class DataFrameSuite extends QueryTest {
assert(df.schema.map(_.name).toSeq === Seq("key", "value", "newCol"))
}
- test("renameColumn") {
- val df = testData.toDataFrame.addColumn("newCol", col("key") + 1)
- .renameColumn("value", "valueRenamed")
+ test("withColumnRenamed") {
+ val df = testData.toDF().withColumn("newCol", col("key") + 1)
+ .withColumnRenamed("value", "valueRenamed")
checkAnswer(
df,
testData.collect().map { case Row(key: Int, value: String) =>
@@ -336,7 +438,12 @@ class DataFrameSuite extends QueryTest {
test("apply on query results (SPARK-5462)") {
val df = testData.sqlContext.sql("select key from testData")
- checkAnswer(df("key"), testData.select('key).collect().toSeq)
+ checkAnswer(df.select(df("key")), testData.select('key).collect().toSeq)
}
+ ignore("show") {
+    // This test case is intentionally ignored, but it is kept here to make sure it compiles correctly.
+ testData.select($"*").show()
+ testData.select($"*").show(1000)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index f0c939dbb195f..dd0948ad824be 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class JoinSuite extends QueryTest with BeforeAndAfterEach {
@@ -39,8 +40,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
def assertJoin(sqlString: String, c: Class[_]): Any = {
- val rdd = sql(sqlString)
- val physical = rdd.queryExecution.sparkPlan
+ val df = sql(sqlString)
+ val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
case j: HashOuterJoin => j
@@ -409,8 +410,8 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
}
test("left semi join") {
- val rdd = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
- checkAnswer(rdd,
+ val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+ checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::
Row(2, 1) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
new file mode 100644
index 0000000000000..f9f41eb358bd5
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfter
+
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
+
+class ListTablesSuite extends QueryTest with BeforeAndAfter {
+
+ import org.apache.spark.sql.test.TestSQLContext.implicits._
+
+ val df =
+    sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value")
+
+ before {
+ df.registerTempTable("ListTablesSuiteTable")
+ }
+
+ after {
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ }
+
+ test("get all tables") {
+ checkAnswer(
+ tables().filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ checkAnswer(
+ sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ }
+
+  test("getting all tables with a database name has no impact on returned table names") {
+ checkAnswer(
+ tables("DB").filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ checkAnswer(
+ sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+ Row("ListTablesSuiteTable", true))
+
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0)
+ }
+
+ test("query the returned DataFrame of tables") {
+ val expectedSchema = StructType(
+ StructField("tableName", StringType, false) ::
+ StructField("isTemporary", BooleanType, false) :: Nil)
+
+ Seq(tables(), sql("SHOW TABLes")).foreach {
+ case tableDF =>
+ assert(expectedSchema === tableDF.schema)
+
+ tableDF.registerTempTable("tables")
+ checkAnswer(
+ sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
+ Row(true, "ListTablesSuiteTable")
+ )
+ checkAnswer(
+ tables().filter("tableName = 'tables'").select("tableName", "isTemporary"),
+ Row("tables", true))
+ dropTempTable("tables")
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index f9ddd2ca5c567..9b4dd6c620fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql
import java.util.{Locale, TimeZone}
+import scala.collection.JavaConversions._
+
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.columnar.InMemoryRelation
@@ -33,29 +35,71 @@ class QueryTest extends PlanTest {
/**
   * Runs the plan and makes sure the answer contains all of the keywords, or that
   * none of the keywords are listed in the answer.
- * @param rdd the [[DataFrame]] to be executed
+ * @param df the [[DataFrame]] to be executed
   * @param exists true to make sure the keywords are listed in the output; otherwise,
   *               make sure none of the keywords are listed in the output
   * @param keywords the keywords to check for, as a string array
*/
- def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) {
- val outputs = rdd.collect().map(_.mkString).mkString
+ def checkExistence(df: DataFrame, exists: Boolean, keywords: String*) {
+ val outputs = df.collect().map(_.mkString).mkString
for (key <- keywords) {
if (exists) {
- assert(outputs.contains(key), s"Failed for $rdd ($key doens't exist in result)")
+ assert(outputs.contains(key), s"Failed for $df ($key doesn't exist in result)")
} else {
- assert(!outputs.contains(key), s"Failed for $rdd ($key existed in the result)")
+ assert(!outputs.contains(key), s"Failed for $df ($key existed in the result)")
}
}
}
/**
* Runs the plan and makes sure the answer matches the expected result.
- * @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ protected def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(df, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(df: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(df, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
+ /**
+ * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
*/
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
- val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
+ def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
+ val planWithCaching = query.queryExecution.withCachedData
+ val cachedData = planWithCaching collect {
+ case cached: InMemoryRelation => cached
+ }
+
+ assert(
+ cachedData.size == numCachedTables,
+      s"Expected query to contain $numCachedTables cached tables, but it had ${cachedData.size}\n" +
+ planWithCaching)
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+   * If an exception was thrown during the execution, or the contents of the DataFrame do not
+   * match the expected result, an error message will be returned. Otherwise, [[None]] will
+   * be returned.
+ * @param df the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(df: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
+ val isSorted = df.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
@@ -66,61 +110,47 @@ class QueryTest extends PlanTest {
case o => o
})
}
- if (!isSorted) converted.sortBy(_.toString) else converted
+ if (!isSorted) converted.sortBy(_.toString()) else converted
}
- val sparkAnswer = try rdd.collect().toSeq catch {
+ val sparkAnswer = try df.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
- |${rdd.queryExecution}
+ |${df.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
- |${rdd.logicalPlan}
+ |${df.logicalPlan}
|== Analyzed Plan ==
- |${rdd.queryExecution.analyzed}
+ |${df.queryExecution.analyzed}
|== Physical Plan ==
- |${rdd.queryExecution.executedPlan}
+ |${df.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString()),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
+ return None
}
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ def checkAnswer(df: DataFrame, expectedAnswer: java.util.List[Row]): String = {
+ checkAnswer(df, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
}
-
- /**
- * Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
- */
- def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = {
- val planWithCaching = query.queryExecution.withCachedData
- val cachedData = planWithCaching collect {
- case cached: InMemoryRelation => cached
- }
-
- assert(
- cachedData.size == numCachedTables,
- s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" +
- planWithCaching)
- }
-
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 11502edf972e9..097bf0dd23c89 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.types._
@@ -34,6 +34,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
TestData
import org.apache.spark.sql.test.TestSQLContext.implicits._
+ val sqlCtx = TestSQLContext
test("SPARK-4625 support SORT BY in SimpleSQLParser & DSL") {
checkAnswer(
@@ -589,7 +590,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
("1" :: "2" :: "3" :: "4" :: "A" :: "B" :: "C" :: "D" :: "E" :: "F" :: Nil).map(Row(_)))
// Column type mismatches where a coercion is not possible, in this case between integer
// and array types, trigger a TreeNodeException.
- intercept[TreeNodeException[_]] {
+ intercept[AnalysisException] {
sql("SELECT data FROM arrayData UNION SELECT 1 FROM arrayData").collect()
}
}
@@ -669,7 +670,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(values(0).toInt, values(1), values(2).toBoolean, v4)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = sqlCtx.createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
checkAnswer(
sql("SELECT * FROM applySchema1"),
@@ -699,7 +700,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df2 = applySchema(rowRDD2, schema2)
+ val df2 = sqlCtx.createDataFrame(rowRDD2, schema2)
df2.registerTempTable("applySchema2")
checkAnswer(
sql("SELECT * FROM applySchema2"),
@@ -724,7 +725,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD3, schema2)
+ val df3 = sqlCtx.createDataFrame(rowRDD3, schema2)
df3.registerTempTable("applySchema3")
checkAnswer(
@@ -769,7 +770,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
.build()
val schemaWithMeta = new StructType(Array(
schema("id"), schema("name").copy(metadata = metadata), schema("age")))
- val personWithMeta = applySchema(person.rdd, schemaWithMeta)
+ val personWithMeta = sqlCtx.createDataFrame(person.rdd, schemaWithMeta)
def validateMetadata(rdd: DataFrame): Unit = {
assert(rdd.schema("name").metadata.getString(docKey) == docValue)
}
@@ -805,10 +806,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("throw errors for non-aggregate attributes with aggregation") {
def checkAggregation(query: String, isInvalidQuery: Boolean = true) {
if (isInvalidQuery) {
- val e = intercept[TreeNodeException[LogicalPlan]](sql(query).queryExecution.analyzed)
- assert(
- e.getMessage.startsWith("Expression not in GROUP BY"),
- "Non-aggregate attribute(s) not detected\n")
+ val e = intercept[AnalysisException](sql(query).queryExecution.analyzed)
+ assert(e.getMessage contains "group by")
} else {
// Should not throw
sql(query).queryExecution.analyzed
@@ -1035,10 +1034,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Supporting relational operator '<=>' in Spark SQL") {
val nullCheckData1 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i)))
- rdd1.registerTempTable("nulldata1")
+ rdd1.toDF().registerTempTable("nulldata1")
val nullCheckData2 = TestData(1,"1") :: TestData(2,null) :: Nil
val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i)))
- rdd2.registerTempTable("nulldata2")
+ rdd2.toDF().registerTempTable("nulldata2")
checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " +
"nulldata2 on nulldata1.value <=> nulldata2.value"),
(1 to 2).map(i => Row(i)))
@@ -1047,7 +1046,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
test("Multi-column COUNT(DISTINCT ...)") {
val data = TestData(1,"val_1") :: TestData(2,"val_2") :: Nil
val rdd = sparkContext.parallelize((0 to 1).map(i => data(i)))
- rdd.registerTempTable("distinctData")
+ rdd.toDF().registerTempTable("distinctData")
checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2))
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index 93782619826f0..23df6e7eac043 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -82,7 +82,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1,2,3))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectData")
+ rdd.toDF().registerTempTable("reflectData")
assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
@@ -93,7 +93,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with nulls") {
val data = NullReflectData(null, null, null, null, null, null, null)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectNullData")
+ rdd.toDF().registerTempTable("reflectNullData")
assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -101,7 +101,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
test("query case class RDD with Nones") {
val data = OptionalReflectData(None, None, None, None, None, None, None)
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectOptionalData")
+ rdd.toDF().registerTempTable("reflectOptionalData")
assert(sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null)))
}
@@ -109,7 +109,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
// Equality is broken for Arrays, so we test that separately.
test("query binary data") {
val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil)
- rdd.registerTempTable("reflectBinary")
+ rdd.toDF().registerTempTable("reflectBinary")
val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
@@ -128,7 +128,7 @@ class ScalaReflectionRelationSuite extends FunSuite {
Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None),
Nested(None, "abc")))
val rdd = sparkContext.parallelize(data :: Nil)
- rdd.registerTempTable("reflectComplexData")
+ rdd.toDF().registerTempTable("reflectComplexData")
assert(sql("SELECT * FROM reflectComplexData").collect().head ===
new GenericRow(Array[Any](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
new file mode 100644
index 0000000000000..6f6d3c9c243d4
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.sql.test.TestSQLContext
+
+class SerializationSuite extends FunSuite {
+
+ test("[SPARK-5235] SQLContext should be serializable") {
+ val sqlContext = new SQLContext(TestSQLContext.sparkContext)
+ new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 0ed437edd05fd..637f59b2e68ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql
import java.sql.Timestamp
import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test._
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -29,11 +29,11 @@ case class TestData(key: Int, value: String)
object TestData {
val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toDataFrame
+ (1 to 100).map(i => TestData(i, i.toString))).toDF()
testData.registerTempTable("testData")
val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
negativeData.registerTempTable("negativeData")
case class LargeAndSmallInts(a: Int, b: Int)
@@ -44,7 +44,7 @@ object TestData {
LargeAndSmallInts(2147483645, 1) ::
LargeAndSmallInts(2, 2) ::
LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toDataFrame
+ LargeAndSmallInts(3, 2) :: Nil).toDF()
largeAndSmallInts.registerTempTable("largeAndSmallInts")
case class TestData2(a: Int, b: Int)
@@ -55,7 +55,7 @@ object TestData {
TestData2(2, 1) ::
TestData2(2, 2) ::
TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toDataFrame
+ TestData2(3, 2) :: Nil, 2).toDF()
testData2.registerTempTable("testData2")
case class DecimalData(a: BigDecimal, b: BigDecimal)
@@ -67,7 +67,7 @@ object TestData {
DecimalData(2, 1) ::
DecimalData(2, 2) ::
DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toDataFrame
+ DecimalData(3, 2) :: Nil).toDF()
decimalData.registerTempTable("decimalData")
case class BinaryData(a: Array[Byte], b: Int)
@@ -77,14 +77,14 @@ object TestData {
BinaryData("22".getBytes(), 5) ::
BinaryData("122".getBytes(), 3) ::
BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toDataFrame
+ BinaryData("123".getBytes(), 4) :: Nil).toDF()
binaryData.registerTempTable("binaryData")
case class TestData3(a: Int, b: Option[Int])
val testData3 =
TestSQLContext.sparkContext.parallelize(
TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toDataFrame
+ TestData3(2, Some(2)) :: Nil).toDF()
testData3.registerTempTable("testData3")
val emptyTableData = logical.LocalRelation($"a".int, $"b".int)
@@ -97,7 +97,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toDataFrame
+ UpperCaseData(6, "F") :: Nil).toDF()
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -106,7 +106,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toDataFrame
+ LowerCaseData(4, "d") :: Nil).toDF()
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
@@ -114,7 +114,7 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ArrayData(Seq(1,2,3), Seq(Seq(1,2,3))) ::
ArrayData(Seq(2,3,4), Seq(Seq(2,3,4))) :: Nil)
- arrayData.registerTempTable("arrayData")
+ arrayData.toDF().registerTempTable("arrayData")
case class MapData(data: scala.collection.Map[Int, String])
val mapData =
@@ -124,18 +124,18 @@ object TestData {
MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
MapData(Map(1 -> "a4", 2 -> "b4")) ::
MapData(Map(1 -> "a5")) :: Nil)
- mapData.registerTempTable("mapData")
+ mapData.toDF().registerTempTable("mapData")
case class StringData(s: String)
val repeatedData =
TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
- repeatedData.registerTempTable("repeatedData")
+ repeatedData.toDF().registerTempTable("repeatedData")
val nullableRepeatedData =
TestSQLContext.sparkContext.parallelize(
List.fill(2)(StringData(null)) ++
List.fill(2)(StringData("test")))
- nullableRepeatedData.registerTempTable("nullableRepeatedData")
+ nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
case class NullInts(a: Integer)
val nullInts =
@@ -144,7 +144,7 @@ object TestData {
NullInts(2) ::
NullInts(3) ::
NullInts(null) :: Nil
- )
+ ).toDF()
nullInts.registerTempTable("nullInts")
val allNulls =
@@ -152,7 +152,7 @@ object TestData {
NullInts(null) ::
NullInts(null) ::
NullInts(null) ::
- NullInts(null) :: Nil)
+ NullInts(null) :: Nil).toDF()
allNulls.registerTempTable("allNulls")
case class NullStrings(n: Int, s: String)
@@ -160,11 +160,15 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
NullStrings(1, "abc") ::
NullStrings(2, "ABC") ::
- NullStrings(3, null) :: Nil).toDataFrame
+ NullStrings(3, null) :: Nil).toDF()
nullStrings.registerTempTable("nullStrings")
case class TableName(tableName: String)
- TestSQLContext.sparkContext.parallelize(TableName("test") :: Nil).registerTempTable("tableName")
+ TestSQLContext
+ .sparkContext
+ .parallelize(TableName("test") :: Nil)
+ .toDF()
+ .registerTempTable("tableName")
val unparsedStrings =
TestSQLContext.sparkContext.parallelize(
@@ -177,22 +181,22 @@ object TestData {
val timestamps = TestSQLContext.sparkContext.parallelize((1 to 3).map { i =>
TimestampField(new Timestamp(i))
})
- timestamps.registerTempTable("timestamps")
+ timestamps.toDF().registerTempTable("timestamps")
case class IntField(i: Int)
// An RDD with 4 elements and 8 partitions
val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
- withEmptyParts.registerTempTable("withEmptyParts")
+ withEmptyParts.toDF().registerTempTable("withEmptyParts")
case class Person(id: Int, name: String, age: Int)
case class Salary(personId: Int, salary: Double)
val person = TestSQLContext.sparkContext.parallelize(
Person(0, "mike", 30) ::
- Person(1, "jim", 20) :: Nil)
+ Person(1, "jim", 20) :: Nil).toDF()
person.registerTempTable("person")
val salary = TestSQLContext.sparkContext.parallelize(
Salary(0, 2000.0) ::
- Salary(1, 1000.0) :: Nil)
+ Salary(1, 1000.0) :: Nil).toDF()
salary.registerTempTable("salary")
case class ComplexData(m: Map[Int, String], s: TestData, a: Seq[Int], b: Boolean)
@@ -200,6 +204,6 @@ object TestData {
TestSQLContext.sparkContext.parallelize(
ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true)
:: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false)
- :: Nil).toDataFrame
+ :: Nil).toDF()
complexData.registerTempTable("complexData")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 95923f9aad931..be105c6e83594 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,11 +17,11 @@
package org.apache.spark.sql
-import org.apache.spark.sql.Dsl.StringToColumn
import org.apache.spark.sql.test._
/* Implicits */
import TestSQLContext._
+import TestSQLContext.implicits._
case class FunctionResult(f1: String, f2: String)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 3c1657cd5fc3a..47fdb5543235c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql
+import java.io.File
+
import scala.beans.{BeanInfo, BeanProperty}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql}
import org.apache.spark.sql.test.TestSQLContext.implicits._
@@ -66,7 +68,7 @@ class UserDefinedTypeSuite extends QueryTest {
val points = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0))))
- val pointsRDD: RDD[MyLabeledPoint] = sparkContext.parallelize(points)
+ val pointsRDD = sparkContext.parallelize(points).toDF()
test("register user type: MyDenseVector for MyLabeledPoint") {
@@ -91,4 +93,26 @@ class UserDefinedTypeSuite extends QueryTest {
sql("SELECT testType(features) from points"),
Seq(Row(true), Row(true)))
}
+
+
+ test("UDTs with Parquet") {
+ val tempDir = File.createTempFile("parquet", "test")
+ tempDir.delete()
+ pointsRDD.saveAsParquetFile(tempDir.getCanonicalPath)
+ }
+
+ test("Repartition UDTs with Parquet") {
+ val tempDir = File.createTempFile("parquet", "test")
+ tempDir.delete()
+ pointsRDD.repartition(1).saveAsParquetFile(tempDir.getCanonicalPath)
+ }
+
+ // Tests to make sure that all operators correctly convert types on the way out.
+ test("Local UDTs") {
+ val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec")
+ df.collect()(0).getAs[MyDenseVector](1)
+ df.take(1)(0).getAs[MyDenseVector](1)
+ df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+ df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 86b1b5fda1c0f..38b0f666ab90b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.columnar
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.{QueryTest, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
@@ -28,8 +29,6 @@ class InMemoryColumnarQuerySuite extends QueryTest {
// Make sure the tables are loaded.
TestData
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
test("simple columnar query") {
val plan = executePlan(testData.logicalPlan).executedPlan
val scan = InMemoryRelation(useCompression = true, 5, MEMORY_ONLY, plan, None)
@@ -39,7 +38,8 @@ class InMemoryColumnarQuerySuite extends QueryTest {
test("default size avoids broadcast") {
// TODO: Improve this test when we have better statistics
- sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).registerTempTable("sizeTst")
+ sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ .toDF().registerTempTable("sizeTst")
cacheTable("sizeTst")
assert(
table("sizeTst").queryExecution.logical.statistics.sizeInBytes >
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 55a9f735b3506..e57bb06e7263b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -21,13 +21,12 @@ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll, FunSuite}
import org.apache.spark.sql._
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with BeforeAndAfter {
val originalColumnBatchSize = conf.columnBatchSize
val originalInMemoryPartitionPruning = conf.inMemoryPartitionPruning
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
override protected def beforeAll(): Unit = {
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
setConf(SQLConf.COLUMN_BATCH_SIZE, "10")
@@ -35,7 +34,7 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be
val pruningData = sparkContext.makeRDD((1 to 100).map { key =>
val string = if (((key - 1) / 10) % 2 == 0) null else key.toString
TestData(key, string)
- }, 5)
+ }, 5).toDF()
pruningData.registerTempTable("pruningData")
// Enable in-memory partition pruning
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df108a9d262bb..523be56df65ba 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution
import org.scalatest.FunSuite
import org.apache.spark.sql.{SQLConf, execution}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.test.TestSQLContext.planner._
import org.apache.spark.sql.types._
@@ -71,7 +72,7 @@ class PlannerSuite extends FunSuite {
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- applySchema(rowRDD, schema).registerTempTable("testLimit")
+ createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index d25c1390db15c..cd737c0b62767 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -18,9 +18,11 @@
package org.apache.spark.sql.jdbc
import java.math.BigDecimal
+import java.sql.DriverManager
+import java.util.{Calendar, GregorianCalendar}
+
import org.apache.spark.sql.test._
import org.scalatest.{FunSuite, BeforeAndAfter}
-import java.sql.DriverManager
import TestSQLContext._
class JDBCSuite extends FunSuite with BeforeAndAfter {
@@ -164,17 +166,16 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
}
test("Basic API") {
- assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE").collect.size == 3)
+ assert(TestSQLContext.jdbc(url, "TEST.PEOPLE").collect.size == 3)
}
test("Partitioning via JDBCPartitioningInfo API") {
- val parts = JDBCPartitioningInfo("THEID", 0, 4, 3)
- assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+ assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", "THEID", 0, 4, 3).collect.size == 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(TestSQLContext.jdbcRDD(url, "TEST.PEOPLE", parts).collect.size == 3)
+ assert(TestSQLContext.jdbc(url, "TEST.PEOPLE", parts).collect.size == 3)
}
test("H2 integral types") {
@@ -207,20 +208,25 @@ class JDBCSuite extends FunSuite with BeforeAndAfter {
assert(rows(0).getString(5).equals("I am a clob!"))
}
+
test("H2 time types") {
val rows = sql("SELECT * FROM timetypes").collect()
- assert(rows(0).getAs[java.sql.Timestamp](0).getHours == 12)
- assert(rows(0).getAs[java.sql.Timestamp](0).getMinutes == 34)
- assert(rows(0).getAs[java.sql.Timestamp](0).getSeconds == 56)
- assert(rows(0).getAs[java.sql.Date](1).getYear == 96)
- assert(rows(0).getAs[java.sql.Date](1).getMonth == 0)
- assert(rows(0).getAs[java.sql.Date](1).getDate == 1)
- assert(rows(0).getAs[java.sql.Timestamp](2).getYear == 102)
- assert(rows(0).getAs[java.sql.Timestamp](2).getMonth == 1)
- assert(rows(0).getAs[java.sql.Timestamp](2).getDate == 20)
- assert(rows(0).getAs[java.sql.Timestamp](2).getHours == 11)
- assert(rows(0).getAs[java.sql.Timestamp](2).getMinutes == 22)
- assert(rows(0).getAs[java.sql.Timestamp](2).getSeconds == 33)
+ val cal = new GregorianCalendar(java.util.Locale.ROOT)
+ cal.setTime(rows(0).getAs[java.sql.Timestamp](0))
+ assert(cal.get(Calendar.HOUR_OF_DAY) == 12)
+ assert(cal.get(Calendar.MINUTE) == 34)
+ assert(cal.get(Calendar.SECOND) == 56)
+ cal.setTime(rows(0).getAs[java.sql.Timestamp](1))
+ assert(cal.get(Calendar.YEAR) == 1996)
+ assert(cal.get(Calendar.MONTH) == 0)
+ assert(cal.get(Calendar.DAY_OF_MONTH) == 1)
+ cal.setTime(rows(0).getAs[java.sql.Timestamp](2))
+ assert(cal.get(Calendar.YEAR) == 2002)
+ assert(cal.get(Calendar.MONTH) == 1)
+ assert(cal.get(Calendar.DAY_OF_MONTH) == 20)
+ assert(cal.get(Calendar.HOUR) == 11)
+ assert(cal.get(Calendar.MINUTE) == 22)
+ assert(cal.get(Calendar.SECOND) == 33)
assert(rows(0).getAs[java.sql.Timestamp](2).getNanos == 543543543)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index e581ac9b50c2b..ee5c7620d1a22 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.jdbc
-import java.math.BigDecimal
+import java.sql.DriverManager
+
+import org.scalatest.{BeforeAndAfter, FunSuite}
+
import org.apache.spark.sql.Row
-import org.apache.spark.sql.types._
import org.apache.spark.sql.test._
-import org.scalatest.{FunSuite, BeforeAndAfter}
-import java.sql.DriverManager
-import TestSQLContext._
+import org.apache.spark.sql.types._
class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
val url = "jdbc:h2:mem:testdb2"
@@ -54,53 +54,53 @@ class JDBCWriteSuite extends FunSuite with BeforeAndAfter {
StructField("seq", IntegerType) :: Nil)
test("Basic CREATE") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
+ val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
- srdd.createJDBCTable(url, "TEST.BASICCREATETEST", false)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").count)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.BASICCREATETEST").collect()(0).length)
+ df.createJDBCTable(url, "TEST.BASICCREATETEST", false)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").count)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.BASICCREATETEST").collect()(0).length)
}
test("CREATE with overwrite") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
+ val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
- srdd.createJDBCTable(url, "TEST.DROPTEST", false)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
- assert(3 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+ df.createJDBCTable(url, "TEST.DROPTEST", false)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count)
+ assert(3 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length)
- srdd2.createJDBCTable(url, "TEST.DROPTEST", true)
- assert(1 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").count)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.DROPTEST").collect()(0).length)
+ df2.createJDBCTable(url, "TEST.DROPTEST", true)
+ assert(1 == TestSQLContext.jdbc(url, "TEST.DROPTEST").count)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.DROPTEST").collect()(0).length)
}
test("CREATE then INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
- srdd.createJDBCTable(url, "TEST.APPENDTEST", false)
- srdd2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
- assert(3 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").count)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.APPENDTEST").collect()(0).length)
+ df.createJDBCTable(url, "TEST.APPENDTEST", false)
+ df2.insertIntoJDBC(url, "TEST.APPENDTEST", false)
+ assert(3 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").count)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.APPENDTEST").collect()(0).length)
}
test("CREATE then INSERT to truncate") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr1x2), schema2)
+ val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr1x2), schema2)
- srdd.createJDBCTable(url, "TEST.TRUNCATETEST", false)
- srdd2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
- assert(1 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").count)
- assert(2 == TestSQLContext.jdbcRDD(url, "TEST.TRUNCATETEST").collect()(0).length)
+ df.createJDBCTable(url, "TEST.TRUNCATETEST", false)
+ df2.insertIntoJDBC(url, "TEST.TRUNCATETEST", true)
+ assert(1 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").count)
+ assert(2 == TestSQLContext.jdbc(url, "TEST.TRUNCATETEST").collect()(0).length)
}
test("Incompatible INSERT to append") {
- val srdd = TestSQLContext.applySchema(sc.parallelize(arr2x2), schema2)
- val srdd2 = TestSQLContext.applySchema(sc.parallelize(arr2x3), schema3)
+ val df = TestSQLContext.createDataFrame(sc.parallelize(arr2x2), schema2)
+ val df2 = TestSQLContext.createDataFrame(sc.parallelize(arr2x3), schema3)
- srdd.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
+ df.createJDBCTable(url, "TEST.INCOMPATIBLETEST", false)
intercept[org.apache.spark.SparkException] {
- srdd2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true)
+ df2.insertIntoJDBC(url, "TEST.INCOMPATIBLETEST", true)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
index 89920f2650c3a..5b8a76f461faf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegration.scala
@@ -18,18 +18,13 @@
package org.apache.spark.sql.jdbc
import java.math.BigDecimal
-import java.sql.{Date, DriverManager, Timestamp}
-import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+import java.sql.{Date, Timestamp}
+
+import com.spotify.docker.client.DockerClient
import com.spotify.docker.client.messages.ContainerConfig
-import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.SparkContext._
-import org.apache.spark.sql._
import org.apache.spark.sql.test._
-import TestSQLContext._
-
-import org.apache.spark.sql.jdbc._
class MySQLDatabase {
val docker: DockerClient = DockerClientFactory.get()
@@ -37,9 +32,9 @@ class MySQLDatabase {
println("Pulling mysql")
docker.pull("mysql")
println("Configuring container")
- val config = (ContainerConfig.builder().image("mysql")
- .env("MYSQL_ROOT_PASSWORD=rootpass")
- .build())
+ val config = ContainerConfig.builder().image("mysql")
+ .env("MYSQL_ROOT_PASSWORD=rootpass")
+ .build()
println("Creating container")
val id = docker.createContainer(config).id
println("Starting container " + id)
@@ -57,11 +52,10 @@ class MySQLDatabase {
println("Closing docker client")
DockerClientFactory.close(docker)
} catch {
- case e: Exception => {
+ case e: Exception =>
println(e)
println("You may need to clean this up manually.")
throw e
- }
}
}
}
@@ -86,10 +80,9 @@ class MySQLDatabase {
println("Database is up.")
return;
} catch {
- case e: java.sql.SQLException => {
+ case e: java.sql.SQLException =>
lastException = e
java.lang.Thread.sleep(250)
- }
}
}
}
@@ -143,8 +136,8 @@ class MySQLDatabase {
}
test("Basic test") {
- val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "tbl")
- val rows = rdd.collect
+ val df = TestSQLContext.jdbc(url(ip, "foo"), "tbl")
+ val rows = df.collect()
assert(rows.length == 2)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 2)
@@ -153,8 +146,8 @@ class MySQLDatabase {
}
test("Numeric types") {
- val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
- val rows = rdd.collect
+ val df = TestSQLContext.jdbc(url(ip, "foo"), "numbers")
+ val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 9)
@@ -181,8 +174,8 @@ class MySQLDatabase {
}
test("Date types") {
- val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
- val rows = rdd.collect
+ val df = TestSQLContext.jdbc(url(ip, "foo"), "dates")
+ val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 5)
@@ -199,8 +192,8 @@ class MySQLDatabase {
}
test("String types") {
- val rdd = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
- val rows = rdd.collect
+ val df = TestSQLContext.jdbc(url(ip, "foo"), "strings")
+ val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 9)
@@ -225,11 +218,11 @@ class MySQLDatabase {
}
test("Basic write test") {
- val rdd1 = TestSQLContext.jdbcRDD(url(ip, "foo"), "numbers")
- val rdd2 = TestSQLContext.jdbcRDD(url(ip, "foo"), "dates")
- val rdd3 = TestSQLContext.jdbcRDD(url(ip, "foo"), "strings")
- rdd1.createJDBCTable(url(ip, "foo"), "numberscopy", false)
- rdd2.createJDBCTable(url(ip, "foo"), "datescopy", false)
- rdd3.createJDBCTable(url(ip, "foo"), "stringscopy", false)
+ val df1 = TestSQLContext.jdbc(url(ip, "foo"), "numbers")
+ val df2 = TestSQLContext.jdbc(url(ip, "foo"), "dates")
+ val df3 = TestSQLContext.jdbc(url(ip, "foo"), "strings")
+ df1.createJDBCTable(url(ip, "foo"), "numberscopy", false)
+ df2.createJDBCTable(url(ip, "foo"), "datescopy", false)
+ df3.createJDBCTable(url(ip, "foo"), "stringscopy", false)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
index c174d7adb7204..e17be99ac31d5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegration.scala
@@ -17,13 +17,13 @@
package org.apache.spark.sql.jdbc
-import java.math.BigDecimal
-import org.apache.spark.sql.test._
-import org.scalatest.{FunSuite, BeforeAndAfterAll, Ignore}
import java.sql.DriverManager
-import TestSQLContext._
-import com.spotify.docker.client.{DefaultDockerClient, DockerClient}
+
+import com.spotify.docker.client.DockerClient
import com.spotify.docker.client.messages.ContainerConfig
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Ignore}
+
+import org.apache.spark.sql.test._
class PostgresDatabase {
val docker: DockerClient = DockerClientFactory.get()
@@ -31,9 +31,9 @@ class PostgresDatabase {
println("Pulling postgres")
docker.pull("postgres")
println("Configuring container")
- val config = (ContainerConfig.builder().image("postgres")
- .env("POSTGRES_PASSWORD=rootpass")
- .build())
+ val config = ContainerConfig.builder().image("postgres")
+ .env("POSTGRES_PASSWORD=rootpass")
+ .build()
println("Creating container")
val id = docker.createContainer(config).id
println("Starting container " + id)
@@ -51,11 +51,10 @@ class PostgresDatabase {
println("Closing docker client")
DockerClientFactory.close(docker)
} catch {
- case e: Exception => {
+ case e: Exception =>
println(e)
println("You may need to clean this up manually.")
throw e
- }
}
}
}
@@ -79,10 +78,9 @@ class PostgresDatabase {
println("Database is up.")
return;
} catch {
- case e: java.sql.SQLException => {
+ case e: java.sql.SQLException =>
lastException = e
java.lang.Thread.sleep(250)
- }
}
}
}
@@ -113,8 +111,8 @@ class PostgresDatabase {
}
test("Type mapping for various types") {
- val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar")
- val rows = rdd.collect
+ val df = TestSQLContext.jdbc(url(db.ip), "public.bar")
+ val rows = df.collect()
assert(rows.length == 1)
val types = rows(0).toSeq.map(x => x.getClass.toString)
assert(types.length == 10)
@@ -142,8 +140,8 @@ class PostgresDatabase {
}
test("Basic write test") {
- val rdd = TestSQLContext.jdbcRDD(url(db.ip), "public.bar")
- rdd.createJDBCTable(url(db.ip), "public.barcopy", false)
+ val df = TestSQLContext.jdbc(url(db.ip), "public.bar")
+ df.createJDBCTable(url(db.ip), "public.barcopy", false)
// Test only that it doesn't bomb out.
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
index 7870cf9b0a868..005f20b96df79 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala
@@ -21,11 +21,12 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType}
import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
@@ -222,7 +223,7 @@ class JsonSuite extends QueryTest {
StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
- StructField("integer", IntegerType, true) ::
+ StructField("integer", LongType, true) ::
StructField("long", LongType, true) ::
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
@@ -252,7 +253,7 @@ class JsonSuite extends QueryTest {
StructField("arrayOfBigInteger", ArrayType(DecimalType.Unlimited, false), true) ::
StructField("arrayOfBoolean", ArrayType(BooleanType, false), true) ::
StructField("arrayOfDouble", ArrayType(DoubleType, false), true) ::
- StructField("arrayOfInteger", ArrayType(IntegerType, false), true) ::
+ StructField("arrayOfInteger", ArrayType(LongType, false), true) ::
StructField("arrayOfLong", ArrayType(LongType, false), true) ::
StructField("arrayOfNull", ArrayType(StringType, true), true) ::
StructField("arrayOfString", ArrayType(StringType, false), true) ::
@@ -265,7 +266,7 @@ class JsonSuite extends QueryTest {
StructField("field1", BooleanType, true) ::
StructField("field2", DecimalType.Unlimited, true) :: Nil), true) ::
StructField("structWithArrayFields", StructType(
- StructField("field1", ArrayType(IntegerType, false), true) ::
+ StructField("field1", ArrayType(LongType, false), true) ::
StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil)
assert(expectedSchema === jsonDF.schema)
@@ -486,7 +487,7 @@ class JsonSuite extends QueryTest {
val jsonDF = jsonRDD(complexFieldValueTypeConflict)
val expectedSchema = StructType(
- StructField("array", ArrayType(IntegerType, false), true) ::
+ StructField("array", ArrayType(LongType, false), true) ::
StructField("num_struct", StringType, true) ::
StructField("str_array", StringType, true) ::
StructField("struct", StructType(
@@ -540,7 +541,7 @@ class JsonSuite extends QueryTest {
val expectedSchema = StructType(
StructField("a", BooleanType, true) ::
StructField("b", LongType, true) ::
- StructField("c", ArrayType(IntegerType, false), true) ::
+ StructField("c", ArrayType(LongType, false), true) ::
StructField("d", StructType(
StructField("field", BooleanType, true) :: Nil), true) ::
StructField("e", StringType, true) :: Nil)
@@ -560,7 +561,7 @@ class JsonSuite extends QueryTest {
StructField("bigInteger", DecimalType.Unlimited, true) ::
StructField("boolean", BooleanType, true) ::
StructField("double", DoubleType, true) ::
- StructField("integer", IntegerType, true) ::
+ StructField("integer", LongType, true) ::
StructField("long", LongType, true) ::
StructField("null", StringType, true) ::
StructField("string", StringType, true) :: Nil)
@@ -656,6 +657,62 @@ class JsonSuite extends QueryTest {
)
}
+ test("Applying schemas with MapType") {
+ val schemaWithSimpleMap = StructType(
+ StructField("map", MapType(StringType, IntegerType, true), false) :: Nil)
+ val jsonWithSimpleMap = jsonRDD(mapType1, schemaWithSimpleMap)
+
+ jsonWithSimpleMap.registerTempTable("jsonWithSimpleMap")
+
+ checkAnswer(
+ sql("select map from jsonWithSimpleMap"),
+ Row(Map("a" -> 1)) ::
+ Row(Map("b" -> 2)) ::
+ Row(Map("c" -> 3)) ::
+ Row(Map("c" -> 1, "d" -> 4)) ::
+ Row(Map("e" -> null)) :: Nil
+ )
+
+ checkAnswer(
+ sql("select map['c'] from jsonWithSimpleMap"),
+ Row(null) ::
+ Row(null) ::
+ Row(3) ::
+ Row(1) ::
+ Row(null) :: Nil
+ )
+
+ val innerStruct = StructType(
+ StructField("field1", ArrayType(IntegerType, true), true) ::
+ StructField("field2", IntegerType, true) :: Nil)
+ val schemaWithComplexMap = StructType(
+ StructField("map", MapType(StringType, innerStruct, true), false) :: Nil)
+
+ val jsonWithComplexMap = jsonRDD(mapType2, schemaWithComplexMap)
+
+ jsonWithComplexMap.registerTempTable("jsonWithComplexMap")
+
+ checkAnswer(
+ sql("select map from jsonWithComplexMap"),
+ Row(Map("a" -> Row(Seq(1, 2, 3, null), null))) ::
+ Row(Map("b" -> Row(null, 2))) ::
+ Row(Map("c" -> Row(Seq(), 4))) ::
+ Row(Map("c" -> Row(null, 3), "d" -> Row(Seq(null), null))) ::
+ Row(Map("e" -> null)) ::
+ Row(Map("f" -> Row(null, null))) :: Nil
+ )
+
+ checkAnswer(
+ sql("select map['a'].field1, map['c'].field2 from jsonWithComplexMap"),
+ Row(Seq(1, 2, 3, null), null) ::
+ Row(null, null) ::
+ Row(null, 4) ::
+ Row(null, 3) ::
+ Row(null, null) ::
+ Row(null, null) :: Nil
+ )
+ }
+
test("SPARK-2096 Correctly parse dot notations") {
val jsonDF = jsonRDD(complexFieldAndType2)
jsonDF.registerTempTable("jsonTable")
@@ -781,12 +838,12 @@ class JsonSuite extends QueryTest {
ArrayType(ArrayType(ArrayType(ArrayType(StringType, false), false), true), false), true) ::
StructField("field2",
ArrayType(ArrayType(
- StructType(StructField("Test", IntegerType, true) :: Nil), false), true), true) ::
+ StructType(StructField("Test", LongType, true) :: Nil), false), true), true) ::
StructField("field3",
ArrayType(ArrayType(
StructType(StructField("Test", StringType, true) :: Nil), true), false), true) ::
StructField("field4",
- ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil)
+ ArrayType(ArrayType(ArrayType(LongType, false), true), false), true) :: Nil)
assert(schema === jsonDF.schema)
@@ -820,12 +877,12 @@ class JsonSuite extends QueryTest {
Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5)
}
- val df1 = applySchema(rowRDD1, schema1)
+ val df1 = createDataFrame(rowRDD1, schema1)
df1.registerTempTable("applySchema1")
- val df2 = df1.toDataFrame
+ val df2 = df1.toDF
val result = df2.toJSON.collect()
- assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
- assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
+ assert(result(0) === "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}")
+ assert(result(3) === "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}")
val schema2 = StructType(
StructField("f1", StructType(
@@ -841,13 +898,13 @@ class JsonSuite extends QueryTest {
Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4))
}
- val df3 = applySchema(rowRDD2, schema2)
+ val df3 = createDataFrame(rowRDD2, schema2)
df3.registerTempTable("applySchema2")
- val df4 = df3.toDataFrame
+ val df4 = df3.toDF
val result2 = df4.toJSON.collect()
- assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
- assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
+ assert(result2(1) === "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}")
+ assert(result2(3) === "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}")
val jsonDF = jsonRDD(primitiveFieldAndType)
val primTable = jsonRDD(jsonDF.toJSON)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
index 3370b3c98b4be..15698f61e0837 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/json/TestJsonData.scala
@@ -146,6 +146,23 @@ object TestJsonData {
]]
}""" :: Nil)
+ val mapType1 =
+ TestSQLContext.sparkContext.parallelize(
+ """{"map": {"a": 1}}""" ::
+ """{"map": {"b": 2}}""" ::
+ """{"map": {"c": 3}}""" ::
+ """{"map": {"c": 1, "d": 4}}""" ::
+ """{"map": {"e": null}}""" :: Nil)
+
+ val mapType2 =
+ TestSQLContext.sparkContext.parallelize(
+ """{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
+ """{"map": {"b": {"field2": 2}}}""" ::
+ """{"map": {"c": {"field1": [], "field2": 4}}}""" ::
+ """{"map": {"c": {"field2": 3}, "d": {"field1": [null]}}}""" ::
+ """{"map": {"e": null}}""" ::
+ """{"map": {"f": {"field1": null}}}""" :: Nil)
+
val nullsInArrays =
TestSQLContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
index f8117c21773ae..4d32e84fc1115 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.parquet
+import org.scalatest.BeforeAndAfterAll
import parquet.filter2.predicate.Operators._
import parquet.filter2.predicate.{FilterPredicate, Operators}
@@ -40,11 +41,11 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, SQLConf}
* 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
* data type is nullable.
*/
-class ParquetFilterSuite extends QueryTest with ParquetTest {
+class ParquetFilterSuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
private def checkFilterPredicate(
- rdd: DataFrame,
+ df: DataFrame,
predicate: Predicate,
filterClass: Class[_ <: FilterPredicate],
checker: (DataFrame, Seq[Row]) => Unit,
@@ -52,7 +53,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
val output = predicate.collect { case a: Attribute => a }.distinct
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") {
- val query = rdd
+ val query = df
.select(output.map(e => Column(e)): _*)
.where(Column(predicate))
@@ -84,238 +85,252 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
private def checkFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: DataFrame): Unit = {
- checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected)
+ (implicit df: DataFrame): Unit = {
+ checkFilterPredicate(df, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected)
}
private def checkFilterPredicate[T]
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T)
- (implicit rdd: DataFrame): Unit = {
- checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
+ (implicit df: DataFrame): Unit = {
+ checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df)
}
private def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row])
- (implicit rdd: DataFrame): Unit = {
- def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = {
+ (implicit df: DataFrame): Unit = {
+ def checkBinaryAnswer(df: DataFrame, expected: Seq[Row]) = {
assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) {
- rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
+ df.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted
}
}
- checkFilterPredicate(rdd, predicate, filterClass, checkBinaryAnswer _, expected)
+ checkFilterPredicate(df, predicate, filterClass, checkBinaryAnswer _, expected)
}
private def checkBinaryFilterPredicate
(predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte])
- (implicit rdd: DataFrame): Unit = {
- checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd)
+ (implicit df: DataFrame): Unit = {
+ checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(df)
}
- def run(prefix: String): Unit = {
- test(s"$prefix: filter pushdown - boolean") {
- withParquetRDD((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false)))
+ test("filter pushdown - boolean") {
+ withParquetDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], Seq(Row(true), Row(false)))
- checkFilterPredicate('_1 === true, classOf[Eq[_]], true)
- checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false)
- }
+ checkFilterPredicate('_1 === true, classOf[Eq[_]], true)
+ checkFilterPredicate('_1 !== true, classOf[NotEq[_]], false)
}
+ }
- test(s"$prefix: filter pushdown - short") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd =>
- checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1)
- checkFilterPredicate(
- Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
-
- checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1)
- checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4)
- checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4)
-
- checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4)
-
- checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate(
- Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3)
- checkFilterPredicate(
- Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3,
- classOf[Operators.Or],
- Seq(Row(1), Row(4)))
- }
+ test("filter pushdown - short") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit df =>
+ checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate(
+ Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+
+ checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4)
+
+ checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate(
+ Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, classOf[Operators.And], 3)
+ checkFilterPredicate(
+ Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3,
+ classOf[Operators.Or],
+ Seq(Row(1), Row(4)))
}
+ }
- test(s"$prefix: filter pushdown - integer") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+ test("filter pushdown - integer") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
- }
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
+ }
- test(s"$prefix: filter pushdown - long") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+ test("filter pushdown - long") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
- }
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
+ }
- test(s"$prefix: filter pushdown - float") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+ test("filter pushdown - float") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
- }
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
+ }
- test(s"$prefix: filter pushdown - double") {
- withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
+ test("filter pushdown - double") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
- checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
+ checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1)
+ checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_)))
- checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
- checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
- checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
- checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1)
+ checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4)
+ checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1)
+ checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4)
- checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
- checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
- checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
- checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
- checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
+ checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1)
+ checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1)
+ checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4)
+ checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1)
+ checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4)
- checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
- checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
- checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
- }
+ checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4)
+ checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3)
+ checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4)))
}
+ }
- test(s"$prefix: filter pushdown - string") {
- withParquetRDD((1 to 4).map(i => Tuple1(i.toString))) { implicit rdd =>
- checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkFilterPredicate(
- '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
-
- checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1")
- checkFilterPredicate(
- '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString)))
-
- checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1")
- checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4")
- checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
- checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
-
- checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1")
- checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1")
- checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4")
- checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
- checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
-
- checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
- checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3")
- checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
- }
+ test("filter pushdown - string") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df =>
+ checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkFilterPredicate(
+ '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString)))
+
+ checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1")
+ checkFilterPredicate(
+ '_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString)))
+
+ checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1")
+ checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4")
+ checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1")
+ checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4")
+
+ checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1")
+ checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1")
+ checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4")
+ checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1")
+ checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4")
+
+ checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4")
+ checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3")
+ checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4")))
}
+ }
- test(s"$prefix: filter pushdown - binary") {
- implicit class IntToBinary(int: Int) {
- def b: Array[Byte] = int.toString.getBytes("UTF-8")
- }
+ test("filter pushdown - binary") {
+ implicit class IntToBinary(int: Int) {
+ def b: Array[Byte] = int.toString.getBytes("UTF-8")
+ }
- withParquetRDD((1 to 4).map(i => Tuple1(i.b))) { implicit rdd =>
- checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b)
+ withParquetDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
+ checkBinaryFilterPredicate('_1 === 1.b, classOf[Eq[_]], 1.b)
- checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
- checkBinaryFilterPredicate(
- '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq)
+ checkBinaryFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row])
+ checkBinaryFilterPredicate(
+ '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.b)).toSeq)
- checkBinaryFilterPredicate(
- '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq)
+ checkBinaryFilterPredicate(
+ '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq)
- checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b)
- checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b)
- checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b)
+ checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b)
- checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b)
- checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
- checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b)
+ checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b)
+ checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
- checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b)
- checkBinaryFilterPredicate(
- '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b)))
- }
+ checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b)
+ checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b)
+ checkBinaryFilterPredicate(
+ '_1 < 2.b || '_1 > 3.b, classOf[Operators.Or], Seq(Row(1.b), Row(4.b)))
}
}
+}
+
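+// The two subclasses below run the shared ParquetFilterSuiteBase tests twice: once with
+// the Parquet data source API enabled and once with it disabled. beforeAll flips
+// SQLConf.PARQUET_USE_DATA_SOURCE_API and afterAll restores the original value.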
+class ParquetDataSourceOnFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
+
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ }
+
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+}
+
+class ParquetDataSourceOffFilterSuite extends ParquetFilterSuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
- run("Parquet data source enabled")
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
}
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") {
- run("Parquet data source disabled")
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
index c8ebbbc7d2eac..36f3406a7825f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala
@@ -21,6 +21,9 @@ import scala.collection.JavaConversions._
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.{FileSystem, Path}
+import org.scalatest.BeforeAndAfterAll
import parquet.example.data.simple.SimpleGroup
import parquet.example.data.{Group, GroupWriter}
import parquet.hadoop.api.WriteSupport
@@ -30,15 +33,13 @@ import parquet.hadoop.{ParquetFileWriter, ParquetWriter}
import parquet.io.api.RecordConsumer
import parquet.schema.{MessageType, MessageTypeParser}
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.fs.{FileSystem, Path}
-import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf}
-import org.apache.spark.sql.Dsl._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.TestSQLContext.implicits._
import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf, SaveMode}
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
// with an empty configuration (it is after all not intended to be used in this way?)
@@ -63,239 +64,293 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
/**
* A test suite that tests basic Parquet I/O.
*/
-class ParquetIOSuite extends QueryTest with ParquetTest {
+class ParquetIOSuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
+ import sqlContext.implicits.localSeqToDataFrameHolder
+
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
*/
protected def checkParquetFile[T <: Product : ClassTag: TypeTag](data: Seq[T]): Unit = {
- withParquetRDD(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
+ withParquetDataFrame(data)(r => checkAnswer(r, data.map(Row.fromTuple)))
}
- def run(prefix: String): Unit = {
- test(s"$prefix: basic data types (without binary)") {
- val data = (1 to 4).map { i =>
- (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
- }
- checkParquetFile(data)
+ test("basic data types (without binary)") {
+ val data = (1 to 4).map { i =>
+ (i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
}
+ checkParquetFile(data)
+ }
- test(s"$prefix: raw binary") {
- val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte)))
- withParquetRDD(data) { rdd =>
- assertResult(data.map(_._1.mkString(",")).sorted) {
- rdd.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted
- }
+ test("raw binary") {
+ val data = (1 to 4).map(i => Tuple1(Array.fill(3)(i.toByte)))
+ withParquetDataFrame(data) { df =>
+ assertResult(data.map(_._1.mkString(",")).sorted) {
+ df.collect().map(_.getAs[Array[Byte]](0).mkString(",")).sorted
}
}
+ }
- test(s"$prefix: string") {
- val data = (1 to 4).map(i => Tuple1(i.toString))
- // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL
- // as we store Spark SQL schema in the extra metadata.
- withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data))
- withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data))
- }
+ test("string") {
+ val data = (1 to 4).map(i => Tuple1(i.toString))
+ // Property spark.sql.parquet.binaryAsString shouldn't affect Parquet files written by Spark SQL
+ // as we store Spark SQL schema in the extra metadata.
+ withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "false")(checkParquetFile(data))
+ withSQLConf(SQLConf.PARQUET_BINARY_AS_STRING -> "true")(checkParquetFile(data))
+ }
- test(s"$prefix: fixed-length decimals") {
- import org.apache.spark.sql.test.TestSQLContext.implicits._
-
- def makeDecimalRDD(decimal: DecimalType): DataFrame =
- sparkContext
- .parallelize(0 to 1000)
- .map(i => Tuple1(i / 100.0))
- // Parquet doesn't allow column names with spaces, have to add an alias here
- .select($"_1" cast decimal as "dec")
-
- for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
- withTempPath { dir =>
- val data = makeDecimalRDD(DecimalType(precision, scale))
- data.saveAsParquetFile(dir.getCanonicalPath)
- checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq)
- }
+ test("fixed-length decimals") {
+
+ def makeDecimalRDD(decimal: DecimalType): DataFrame =
+ sparkContext
+ .parallelize(0 to 1000)
+ .map(i => Tuple1(i / 100.0))
+ .toDF()
+ // Parquet doesn't allow column names with spaces, have to add an alias here
+ .select($"_1" cast decimal as "dec")
+
+ for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) {
+ withTempPath { dir =>
+ val data = makeDecimalRDD(DecimalType(precision, scale))
+ data.saveAsParquetFile(dir.getCanonicalPath)
+ checkAnswer(parquetFile(dir.getCanonicalPath), data.collect().toSeq)
}
+ }
- // Decimals with precision above 18 are not yet supported
- intercept[RuntimeException] {
- withTempPath { dir =>
- makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath)
- parquetFile(dir.getCanonicalPath).collect()
- }
+ // Decimals with precision above 18 are not yet supported
+ intercept[RuntimeException] {
+ withTempPath { dir =>
+ makeDecimalRDD(DecimalType(19, 10)).saveAsParquetFile(dir.getCanonicalPath)
+ parquetFile(dir.getCanonicalPath).collect()
}
+ }
- // Unlimited-length decimals are not yet supported
- intercept[RuntimeException] {
- withTempPath { dir =>
- makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath)
- parquetFile(dir.getCanonicalPath).collect()
- }
+ // Unlimited-length decimals are not yet supported
+ intercept[RuntimeException] {
+ withTempPath { dir =>
+ makeDecimalRDD(DecimalType.Unlimited).saveAsParquetFile(dir.getCanonicalPath)
+ parquetFile(dir.getCanonicalPath).collect()
}
}
+ }
+
+ test("map") {
+ val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i")))
+ checkParquetFile(data)
+ }
- test(s"$prefix: map") {
- val data = (1 to 4).map(i => Tuple1(Map(i -> s"val_$i")))
- checkParquetFile(data)
+ test("array") {
+ val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1)))
+ checkParquetFile(data)
+ }
+
+ test("struct") {
+ val data = (1 to 4).map(i => Tuple1((i, s"val_$i")))
+ withParquetDataFrame(data) { df =>
+ // Structs are converted to `Row`s
+ checkAnswer(df, data.map { case Tuple1(struct) =>
+ Row(Row(struct.productIterator.toSeq: _*))
+ })
}
+ }
- test(s"$prefix: array") {
- val data = (1 to 4).map(i => Tuple1(Seq(i, i + 1)))
- checkParquetFile(data)
+ test("nested struct with array of array as field") {
+ val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i")))))
+ withParquetDataFrame(data) { df =>
+ // Structs are converted to `Row`s
+ checkAnswer(df, data.map { case Tuple1(struct) =>
+ Row(Row(struct.productIterator.toSeq: _*))
+ })
}
+ }
- test(s"$prefix: struct") {
- val data = (1 to 4).map(i => Tuple1((i, s"val_$i")))
- withParquetRDD(data) { rdd =>
- // Structs are converted to `Row`s
- checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Row(Row(struct.productIterator.toSeq: _*))
- })
- }
+ test("nested map with struct as value type") {
+ val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i"))))
+ withParquetDataFrame(data) { df =>
+ checkAnswer(df, data.map { case Tuple1(m) =>
+ Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
+ })
}
+ }
- test(s"$prefix: nested struct with array of array as field") {
- val data = (1 to 4).map(i => Tuple1((i, Seq(Seq(s"val_$i")))))
- withParquetRDD(data) { rdd =>
- // Structs are converted to `Row`s
- checkAnswer(rdd, data.map { case Tuple1(struct) =>
- Row(Row(struct.productIterator.toSeq: _*))
- })
- }
+ test("nulls") {
+ val allNulls = (
+ null.asInstanceOf[java.lang.Boolean],
+ null.asInstanceOf[Integer],
+ null.asInstanceOf[java.lang.Long],
+ null.asInstanceOf[java.lang.Float],
+ null.asInstanceOf[java.lang.Double])
+
+ withParquetDataFrame(allNulls :: Nil) { df =>
+ val rows = df.collect()
+ assert(rows.size === 1)
+ assert(rows.head === Row(Seq.fill(5)(null): _*))
}
+ }
- test(s"$prefix: nested map with struct as value type") {
- val data = (1 to 4).map(i => Tuple1(Map(i -> (i, s"val_$i"))))
- withParquetRDD(data) { rdd =>
- checkAnswer(rdd, data.map { case Tuple1(m) =>
- Row(m.mapValues(struct => Row(struct.productIterator.toSeq: _*)))
- })
- }
+ test("nones") {
+ val allNones = (
+ None.asInstanceOf[Option[Int]],
+ None.asInstanceOf[Option[Long]],
+ None.asInstanceOf[Option[String]])
+
+ withParquetDataFrame(allNones :: Nil) { df =>
+ val rows = df.collect()
+ assert(rows.size === 1)
+ assert(rows.head === Row(Seq.fill(3)(null): _*))
}
+ }
- test(s"$prefix: nulls") {
- val allNulls = (
- null.asInstanceOf[java.lang.Boolean],
- null.asInstanceOf[Integer],
- null.asInstanceOf[java.lang.Long],
- null.asInstanceOf[java.lang.Float],
- null.asInstanceOf[java.lang.Double])
-
- withParquetRDD(allNulls :: Nil) { rdd =>
- val rows = rdd.collect()
- assert(rows.size === 1)
- assert(rows.head === Row(Seq.fill(5)(null): _*))
- }
+ test("compression codec") {
+ def compressionCodecFor(path: String) = {
+ val codecs = ParquetTypesConverter
+ .readMetaData(new Path(path), Some(configuration))
+ .getBlocks
+ .flatMap(_.getColumns)
+ .map(_.getCodec.name())
+ .distinct
+
+ assert(codecs.size === 1)
+ codecs.head
}
- test(s"$prefix: nones") {
- val allNones = (
- None.asInstanceOf[Option[Int]],
- None.asInstanceOf[Option[Long]],
- None.asInstanceOf[Option[String]])
+ val data = (0 until 10).map(i => (i, i.toString))
- withParquetRDD(allNones :: Nil) { rdd =>
- val rows = rdd.collect()
- assert(rows.size === 1)
- assert(rows.head === Row(Seq.fill(3)(null): _*))
+ def checkCompressionCodec(codec: CompressionCodecName): Unit = {
+ withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
+ withParquetFile(data) { path =>
+ assertResult(conf.parquetCompressionCodec.toUpperCase) {
+ compressionCodecFor(path)
+ }
+ }
}
}
- test(s"$prefix: compression codec") {
- def compressionCodecFor(path: String) = {
- val codecs = ParquetTypesConverter
- .readMetaData(new Path(path), Some(configuration))
- .getBlocks
- .flatMap(_.getColumns)
- .map(_.getCodec.name())
- .distinct
-
- assert(codecs.size === 1)
- codecs.head
- }
+ // Checks default compression codec
+ checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
- val data = (0 until 10).map(i => (i, i.toString))
+ checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
+ checkCompressionCodec(CompressionCodecName.GZIP)
+ checkCompressionCodec(CompressionCodecName.SNAPPY)
+ }
- def checkCompressionCodec(codec: CompressionCodecName): Unit = {
- withSQLConf(SQLConf.PARQUET_COMPRESSION -> codec.name()) {
- withParquetFile(data) { path =>
- assertResult(conf.parquetCompressionCodec.toUpperCase) {
- compressionCodecFor(path)
- }
- }
- }
+ test("read raw Parquet file") {
+ def makeRawParquetFile(path: Path): Unit = {
+ val schema = MessageTypeParser.parseMessageType(
+ """
+ |message root {
+ | required boolean _1;
+ | required int32 _2;
+ | required int64 _3;
+ | required float _4;
+ | required double _5;
+ |}
+ """.stripMargin)
+
+ val writeSupport = new TestGroupWriteSupport(schema)
+ val writer = new ParquetWriter[Group](path, writeSupport)
+
+ (0 until 10).foreach { i =>
+ val record = new SimpleGroup(schema)
+ record.add(0, i % 2 == 0)
+ record.add(1, i)
+ record.add(2, i.toLong)
+ record.add(3, i.toFloat)
+ record.add(4, i.toDouble)
+ writer.write(record)
}
- // Checks default compression codec
- checkCompressionCodec(CompressionCodecName.fromConf(conf.parquetCompressionCodec))
+ writer.close()
+ }
- checkCompressionCodec(CompressionCodecName.UNCOMPRESSED)
- checkCompressionCodec(CompressionCodecName.GZIP)
- checkCompressionCodec(CompressionCodecName.SNAPPY)
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ makeRawParquetFile(path)
+ checkAnswer(parquetFile(path.toString), (0 until 10).map { i =>
+ Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
+ })
}
+ }
- test(s"$prefix: read raw Parquet file") {
- def makeRawParquetFile(path: Path): Unit = {
- val schema = MessageTypeParser.parseMessageType(
- """
- |message root {
- | required boolean _1;
- | required int32 _2;
- | required int64 _3;
- | required float _4;
- | required double _5;
- |}
- """.stripMargin)
-
- val writeSupport = new TestGroupWriteSupport(schema)
- val writer = new ParquetWriter[Group](path, writeSupport)
-
- (0 until 10).foreach { i =>
- val record = new SimpleGroup(schema)
- record.add(0, i % 2 == 0)
- record.add(1, i)
- record.add(2, i.toLong)
- record.add(3, i.toFloat)
- record.add(4, i.toDouble)
- writer.write(record)
- }
+ test("write metadata") {
+ withTempPath { file =>
+ val path = new Path(file.toURI.toString)
+ val fs = FileSystem.getLocal(configuration)
+ val attributes = ScalaReflection.attributesFor[(Int, String)]
+ ParquetTypesConverter.writeMetaData(attributes, path, configuration)
- writer.close()
- }
+ assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)))
+ assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
- withTempDir { dir =>
- val path = new Path(dir.toURI.toString, "part-r-0.parquet")
- makeRawParquetFile(path)
- checkAnswer(parquetFile(path.toString), (0 until 10).map { i =>
- Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble)
- })
- }
+ val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration))
+ val actualSchema = metaData.getFileMetaData.getSchema
+ val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes)
+
+ actualSchema.checkContains(expectedSchema)
+ expectedSchema.checkContains(actualSchema)
}
+ }
- test(s"$prefix: write metadata") {
- withTempPath { file =>
- val path = new Path(file.toURI.toString)
- val fs = FileSystem.getLocal(configuration)
- val attributes = ScalaReflection.attributesFor[(Int, String)]
- ParquetTypesConverter.writeMetaData(attributes, path, configuration)
+ test("save - overwrite") {
+ withParquetFile((1 to 10).map(i => (i, i.toString))) { file =>
+ val newData = (11 to 20).map(i => (i, i.toString))
+ newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Overwrite, Map("path" -> file))
+ checkAnswer(parquetFile(file), newData.map(Row.fromTuple))
+ }
+ }
- assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_COMMON_METADATA_FILE)))
- assert(fs.exists(new Path(path, ParquetFileWriter.PARQUET_METADATA_FILE)))
+ test("save - ignore") {
+ val data = (1 to 10).map(i => (i, i.toString))
+ withParquetFile(data) { file =>
+ val newData = (11 to 20).map(i => (i, i.toString))
+ newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Ignore, Map("path" -> file))
+ checkAnswer(parquetFile(file), data.map(Row.fromTuple))
+ }
+ }
- val metaData = ParquetTypesConverter.readMetaData(path, Some(configuration))
- val actualSchema = metaData.getFileMetaData.getSchema
- val expectedSchema = ParquetTypesConverter.convertFromAttributes(attributes)
+ test("save - throw") {
+ val data = (1 to 10).map(i => (i, i.toString))
+ withParquetFile(data) { file =>
+ val newData = (11 to 20).map(i => (i, i.toString))
+ val errorMessage = intercept[Throwable] {
+ newData.toDF().save(
+ "org.apache.spark.sql.parquet", SaveMode.ErrorIfExists, Map("path" -> file))
+ }.getMessage
+ assert(errorMessage.contains("already exists"))
+ }
+ }
- actualSchema.checkContains(expectedSchema)
- expectedSchema.checkContains(actualSchema)
- }
+ test("save - append") {
+ val data = (1 to 10).map(i => (i, i.toString))
+ withParquetFile(data) { file =>
+ val newData = (11 to 20).map(i => (i, i.toString))
+ newData.toDF().save("org.apache.spark.sql.parquet", SaveMode.Append, Map("path" -> file))
+ checkAnswer(parquetFile(file), (data ++ newData).map(Row.fromTuple))
}
}
+}
+
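+// Run the shared ParquetIOSuiteBase tests with the Parquet data source API enabled and
+// disabled, toggling SQLConf.PARQUET_USE_DATA_SOURCE_API in beforeAll/afterAll.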
+class ParquetDataSourceOnIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
+
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ }
+
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+}
+
+class ParquetDataSourceOffIOSuite extends ParquetIOSuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
- run("Parquet data source enabled")
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
}
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") {
- run("Parquet data source disabled")
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
index ae606d11a8f68..adb3c9391f6c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetPartitionDiscoverySuite.scala
@@ -19,17 +19,25 @@ package org.apache.spark.sql.parquet
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
-import org.scalatest.FunSuite
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.parquet.ParquetRelation2._
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{QueryTest, Row, SQLContext}
-class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest {
+// The data where the partitioning key exists only in the directory structure.
+case class ParquetData(intField: Int, stringField: String)
+
+// The data that also includes the partitioning key
+case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
+
+class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
override val sqlContext: SQLContext = TestSQLContext
+ import sqlContext._
+ import sqlContext.implicits._
+
val defaultPartitionName = "__NULL__"
test("column type inference") {
@@ -112,6 +120,17 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest {
Partition(Row(10, "20"), "hdfs://host:9000/path/a=10/b=20"),
Partition(Row(10.5, "hello"), "hdfs://host:9000/path/a=10.5/b=hello"))))
+ check(Seq(
+ s"hdfs://host:9000/path/a=10/b=20",
+ s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"),
+ PartitionSpec(
+ StructType(Seq(
+ StructField("a", IntegerType),
+ StructField("b", StringType))),
+ Seq(
+ Partition(Row(10, "20"), s"hdfs://host:9000/path/a=10/b=20"),
+ Partition(Row(null, "hello"), s"hdfs://host:9000/path/a=$defaultPartitionName/b=hello"))))
+
check(Seq(
s"hdfs://host:9000/path/a=10/b=$defaultPartitionName",
s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"),
@@ -123,4 +142,202 @@ class ParquetPartitionDiscoverySuite extends FunSuite with ParquetTest {
Partition(Row(10, null), s"hdfs://host:9000/path/a=10/b=$defaultPartitionName"),
Partition(Row(10.5, null), s"hdfs://host:9000/path/a=10.5/b=$defaultPartitionName"))))
}
+
+ test("read partitioned table - normal case") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeParquetFile(
+ (1 to 10).map(i => ParquetData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ parquetFile(base.getCanonicalPath).registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, i.toString, 1, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, i.toString, pi, "foo"))
+ }
+ }
+ }
+
+ test("read partitioned table - partition key included in Parquet file") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } {
+ makeParquetFile(
+ (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ parquetFile(base.getCanonicalPath).registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", "bar")
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT intField, pi FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ _ <- Seq("foo", "bar")
+ } yield Row(i, pi))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi = 1"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", "bar")
+ } yield Row(i, 1, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps = 'foo'"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, "foo"))
+ }
+ }
+ }
+
+ test("read partitioned table - with nulls") {
+ withTempDir { base =>
+ for {
+ // Must be `Integer` rather than `Int` here. `null.asInstanceOf[Int]` results in a zero...
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeParquetFile(
+ (1 to 10).map(i => ParquetData(i, i.toString)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ val parquetRelation = load(
+ "org.apache.spark.sql.parquet",
+ Map(
+ "path" -> base.getCanonicalPath,
+ ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName))
+
+ parquetRelation.registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, pi, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE pi IS NULL"),
+ for {
+ i <- 1 to 10
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, i.toString, null, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, null.asInstanceOf[Integer])
+ } yield Row(i, i.toString, pi, null))
+ }
+ }
+ }
+
+ test("read partitioned table - with nulls and partition keys are included in Parquet file") {
+ withTempDir { base =>
+ for {
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } {
+ makeParquetFile(
+ (1 to 10).map(i => ParquetDataWithKey(i, pi, i.toString, ps)),
+ makePartitionDir(base, defaultPartitionName, "pi" -> pi, "ps" -> ps))
+ }
+
+ val parquetRelation = load(
+ "org.apache.spark.sql.parquet",
+ Map(
+ "path" -> base.getCanonicalPath,
+ ParquetRelation2.DEFAULT_PARTITION_NAME -> defaultPartitionName))
+
+ parquetRelation.registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ ps <- Seq("foo", null.asInstanceOf[String])
+ } yield Row(i, pi, i.toString, ps))
+
+ checkAnswer(
+ sql("SELECT * FROM t WHERE ps IS NULL"),
+ for {
+ i <- 1 to 10
+ pi <- Seq(1, 2)
+ } yield Row(i, pi, i.toString, null))
+ }
+ }
+ }
+
+ test("read partitioned table - merging compatible schemas") {
+ withTempDir { base =>
+ makeParquetFile(
+ (1 to 10).map(i => Tuple1(i)).toDF("intField"),
+ makePartitionDir(base, defaultPartitionName, "pi" -> 1))
+
+ makeParquetFile(
+ (1 to 10).map(i => (i, i.toString)).toDF("intField", "stringField"),
+ makePartitionDir(base, defaultPartitionName, "pi" -> 2))
+
+ load(base.getCanonicalPath, "org.apache.spark.sql.parquet").registerTempTable("t")
+
+ withTempTable("t") {
+ checkAnswer(
+ sql("SELECT * FROM t"),
+ (1 to 10).map(i => Row(i, null, 1)) ++ (1 to 10).map(i => Row(i, i.toString, 2)))
+ }
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index cba06835f9a61..b98ba09ccfc2d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -17,103 +17,122 @@
package org.apache.spark.sql.parquet
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.{SQLConf, QueryTest}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.{QueryTest, SQLConf}
/**
* A test suite that tests various Parquet queries.
*/
-class ParquetQuerySuite extends QueryTest with ParquetTest {
+class ParquetQuerySuiteBase extends QueryTest with ParquetTest {
val sqlContext = TestSQLContext
- def run(prefix: String): Unit = {
- test(s"$prefix: simple projection") {
- withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
- checkAnswer(sql("SELECT _1 FROM t"), (0 until 10).map(Row.apply(_)))
- }
+ test("simple select queries") {
+ withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
+ checkAnswer(sql("SELECT _1 FROM t where t._1 > 5"), (6 until 10).map(Row.apply(_)))
+ checkAnswer(sql("SELECT _1 FROM t as tmp where tmp._1 < 5"), (0 until 5).map(Row.apply(_)))
}
+ }
- test(s"$prefix: appending") {
- val data = (0 until 10).map(i => (i, i.toString))
- withParquetTable(data, "t") {
- sql("INSERT INTO TABLE t SELECT * FROM t")
- checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
- }
+ test("appending") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ withParquetTable(data, "t") {
+ sql("INSERT INTO TABLE t SELECT * FROM tmp")
+ checkAnswer(table("t"), (data ++ data).map(Row.fromTuple))
}
+ catalog.unregisterTable(Seq("tmp"))
+ }
- // This test case will trigger the NPE mentioned in
- // https://issues.apache.org/jira/browse/PARQUET-151.
- ignore(s"$prefix: overwriting") {
- val data = (0 until 10).map(i => (i, i.toString))
- withParquetTable(data, "t") {
- sql("INSERT OVERWRITE TABLE t SELECT * FROM t")
- checkAnswer(table("t"), data.map(Row.fromTuple))
- }
+ test("overwriting") {
+ val data = (0 until 10).map(i => (i, i.toString))
+ createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ withParquetTable(data, "t") {
+ sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
+ checkAnswer(table("t"), data.map(Row.fromTuple))
}
+ catalog.unregisterTable(Seq("tmp"))
+ }
- test(s"$prefix: self-join") {
- // 4 rows, cells of column 1 of row 2 and row 4 are null
- val data = (1 to 4).map { i =>
- val maybeInt = if (i % 2 == 0) None else Some(i)
- (maybeInt, i.toString)
- }
-
- withParquetTable(data, "t") {
- val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
- val queryOutput = selfJoin.queryExecution.analyzed.output
+ test("self-join") {
+ // 4 rows, cells of column 1 of row 2 and row 4 are null
+ val data = (1 to 4).map { i =>
+ val maybeInt = if (i % 2 == 0) None else Some(i)
+ (maybeInt, i.toString)
+ }
- assertResult(4, s"Field count mismatches")(queryOutput.size)
- assertResult(2, s"Duplicated expression ID in query plan:\n $selfJoin") {
- queryOutput.filter(_.name == "_1").map(_.exprId).size
- }
+ withParquetTable(data, "t") {
+ val selfJoin = sql("SELECT * FROM t x JOIN t y WHERE x._1 = y._1")
+ val queryOutput = selfJoin.queryExecution.analyzed.output
- checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
+ assertResult(4, "Field count mismatches")(queryOutput.size)
+ assertResult(2, "Duplicated expression ID in query plan:\n $selfJoin") {
+ queryOutput.filter(_.name == "_1").map(_.exprId).size
}
+
+ checkAnswer(selfJoin, List(Row(1, "1", 1, "1"), Row(3, "3", 3, "3")))
}
+ }
- test(s"$prefix: nested data - struct with array field") {
- val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
- withParquetTable(data, "t") {
- checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map {
- case Tuple1((_, Seq(string))) => Row(string)
- })
- }
+ test("nested data - struct with array field") {
+ val data = (1 to 10).map(i => Tuple1((i, Seq(s"val_$i"))))
+ withParquetTable(data, "t") {
+ checkAnswer(sql("SELECT _1._2[0] FROM t"), data.map {
+ case Tuple1((_, Seq(string))) => Row(string)
+ })
}
+ }
- test(s"$prefix: nested data - array of struct") {
- val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i")))
- withParquetTable(data, "t") {
- checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map {
- case Tuple1(Seq((_, string))) => Row(string)
- })
- }
+ test("nested data - array of struct") {
+ val data = (1 to 10).map(i => Tuple1(Seq(i -> s"val_$i")))
+ withParquetTable(data, "t") {
+ checkAnswer(sql("SELECT _1[0]._2 FROM t"), data.map {
+ case Tuple1(Seq((_, string))) => Row(string)
+ })
}
+ }
- test(s"$prefix: SPARK-1913 regression: columns only referenced by pushed down filters should remain") {
- withParquetTable((1 to 10).map(Tuple1.apply), "t") {
- checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_)))
- }
+ test("SPARK-1913 regression: columns only referenced by pushed down filters should remain") {
+ withParquetTable((1 to 10).map(Tuple1.apply), "t") {
+ checkAnswer(sql("SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_)))
}
+ }
- test(s"$prefix: SPARK-5309 strings stored using dictionary compression in parquet") {
- withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") {
+ test("SPARK-5309 strings stored using dictionary compression in parquet") {
+ withParquetTable((0 until 1000).map(i => ("same", "run_" + i / 100, 1)), "t") {
- checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"),
- (0 until 10).map(i => Row("same", "run_" + i, 100)))
+ checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"),
+ (0 until 10).map(i => Row("same", "run_" + i, 100)))
- checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"),
- List(Row("same", "run_5", 100)))
- }
+ checkAnswer(sql("SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"),
+ List(Row("same", "run_5", 100)))
}
}
+}
+
+class ParquetDataSourceOnQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
+
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ }
+
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+}
+
+class ParquetDataSourceOffQuerySuite extends ParquetQuerySuiteBase with BeforeAndAfterAll {
+ val originalConf = sqlContext.conf.parquetUseDataSourceApi
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "true") {
- run("Parquet data source enabled")
+ override protected def beforeAll(): Unit = {
+ sqlContext.conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
}
- withSQLConf(SQLConf.PARQUET_USE_DATA_SOURCE_API -> "false") {
- run("Parquet data source disabled")
+ override protected def afterAll(): Unit = {
+ sqlContext.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
}
}
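
The rewrite above replaces the old run(prefix) wrapper with a shared ParquetQuerySuiteBase and two thin subclasses that flip the Parquet data source flag in beforeAll/afterAll. The following is a minimal, self-contained sketch of that pattern, assuming plain ScalaTest and a hypothetical FakeConf flag in place of SQLConf (no Spark dependency), purely to illustrate the structure:

import org.scalatest.{BeforeAndAfterAll, FunSuite}

// Hypothetical mutable flag standing in for the real configuration entry.
object FakeConf {
  @volatile var useDataSourceApi: Boolean = true
}

// Shared tests live in the base suite and observe whatever flag value the
// concrete subclass installed in beforeAll.
abstract class SharedQuerySuiteBase extends FunSuite {
  test("simple projection") {
    assert((0 until 10).sum === 45, s"flag was ${FakeConf.useDataSourceApi}")
  }
}

class FlagOnSuite extends SharedQuerySuiteBase with BeforeAndAfterAll {
  private var original: Boolean = _
  override protected def beforeAll(): Unit = {
    original = FakeConf.useDataSourceApi
    FakeConf.useDataSourceApi = true // run every shared test with the flag on
  }
  override protected def afterAll(): Unit = {
    FakeConf.useDataSourceApi = original // restore so other suites are unaffected
  }
}

class FlagOffSuite extends SharedQuerySuiteBase with BeforeAndAfterAll {
  private var original: Boolean = _
  override protected def beforeAll(): Unit = {
    original = FakeConf.useDataSourceApi
    FakeConf.useDataSourceApi = false // and again with the flag off
  }
  override protected def afterAll(): Unit = {
    FakeConf.useDataSourceApi = original
  }
}
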
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
index 2e6c2d5f9ab55..ad880e2bc3679 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetSchemaSuite.scala
@@ -36,8 +36,8 @@ class ParquetSchemaSuite extends FunSuite with ParquetTest {
private def testSchema[T <: Product: ClassTag: TypeTag](
testName: String, messageType: String, isThriftDerived: Boolean = false): Unit = {
test(testName) {
- val actual = ParquetTypesConverter.convertFromAttributes(ScalaReflection.attributesFor[T],
- isThriftDerived)
+ val actual = ParquetTypesConverter.convertFromAttributes(
+ ScalaReflection.attributesFor[T], isThriftDerived)
val expected = MessageTypeParser.parseMessageType(messageType)
actual.checkContains(expected)
expected.checkContains(actual)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index b02389978b625..60355414a40fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.sources
import java.io.File
+import org.apache.spark.sql.AnalysisException
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.catalyst.util
@@ -77,12 +78,10 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a, b FROM jsonTable"),
sql("SELECT a, b FROM jt").collect())
- dropTempTable("jsonTable")
-
- val message = intercept[RuntimeException]{
+ val message = intercept[DDLException] {
sql(
s"""
- |CREATE TEMPORARY TABLE jsonTable
+ |CREATE TEMPORARY TABLE IF NOT EXISTS jsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
| path '${path.toString}'
@@ -91,10 +90,25 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
""".stripMargin)
}.getMessage
assert(
- message.contains(s"path ${path.toString} already exists."),
+ message.contains(s"a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause."),
"CREATE TEMPORARY TABLE IF NOT EXISTS should not be allowed.")
- // Explicitly delete it.
+ // Overwrite the temporary table.
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${path.toString}'
+ |) AS
+ |SELECT a * 4 FROM jt
+ """.stripMargin)
+ checkAnswer(
+ sql("SELECT * FROM jsonTable"),
+ sql("SELECT a * 4 FROM jt").collect())
+
+ dropTempTable("jsonTable")
+ // Explicitly delete the data.
if (path.exists()) Utils.deleteRecursively(path)
sql(
@@ -104,12 +118,12 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
|OPTIONS (
| path '${path.toString}'
|) AS
- |SELECT a * 4 FROM jt
+ |SELECT b FROM jt
""".stripMargin)
checkAnswer(
sql("SELECT * FROM jsonTable"),
- sql("SELECT a * 4 FROM jt").collect())
+ sql("SELECT b FROM jt").collect())
dropTempTable("jsonTable")
}
@@ -144,4 +158,31 @@ class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
""".stripMargin)
}
}
+
+ test("it is not allowed to write to a table while querying it.") {
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${path.toString}'
+ |) AS
+ |SELECT a, b FROM jt
+ """.stripMargin)
+
+ val message = intercept[AnalysisException] {
+ sql(
+ s"""
+ |CREATE TEMPORARY TABLE jsonTable
+ |USING org.apache.spark.sql.json.DefaultSource
+ |OPTIONS (
+ | path '${path.toString}'
+ |) AS
+ |SELECT a, b FROM jsonTable
+ """.stripMargin)
+ }.getMessage
+ assert(
+ message.contains("Cannot overwrite table "),
+ "Writing to a table while querying it should not be allowed.")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 53f5f7426e9e6..91c6367371f15 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -29,7 +29,7 @@ abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
@transient
override protected[sql] lazy val analyzer: Analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = false) {
- override val extendedRules =
+ override val extendedResolutionRules =
PreInsertCastAndRename ::
Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 390538d35a348..41cd35683c196 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -47,16 +47,22 @@ case class SimpleFilteredScan(from: Int, to: Int)(@transient val sqlContext: SQL
FiltersPushed.list = filters
- val filterFunctions = filters.collect {
+ def translateFilter(filter: Filter): Int => Boolean = filter match {
case EqualTo("a", v) => (a: Int) => a == v
case LessThan("a", v: Int) => (a: Int) => a < v
case LessThanOrEqual("a", v: Int) => (a: Int) => a <= v
case GreaterThan("a", v: Int) => (a: Int) => a > v
case GreaterThanOrEqual("a", v: Int) => (a: Int) => a >= v
case In("a", values) => (a: Int) => values.map(_.asInstanceOf[Int]).toSet.contains(a)
+ case IsNull("a") => (a: Int) => false // Int can't be null
+ case IsNotNull("a") => (a: Int) => true
+ case Not(pred) => (a: Int) => !translateFilter(pred)(a)
+ case And(left, right) => (a: Int) => translateFilter(left)(a) && translateFilter(right)(a)
+ case Or(left, right) => (a: Int) => translateFilter(left)(a) || translateFilter(right)(a)
+ case _ => (a: Int) => true
}
- def eval(a: Int) = !filterFunctions.map(_(a)).contains(false)
+ def eval(a: Int) = !filters.map(translateFilter(_)(a)).contains(false)
sqlContext.sparkContext.parallelize(from to to).filter(eval).map(i =>
Row.fromSeq(rowBuilders.map(_(i)).reduceOption(_ ++ _).getOrElse(Seq.empty)))
@@ -136,6 +142,26 @@ class FilteredScanSuite extends DataSourceTest {
"SELECT * FROM oneToTenFiltered WHERE b = 2",
Seq(1).map(i => Row(i, i * 2)).toSeq)
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a IS NULL",
+ Seq.empty[Row])
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a IS NOT NULL",
+ (1 to 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1",
+ (2 to 4).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8",
+ Seq(1, 2, 9, 10).map(i => Row(i, i * 2)).toSeq)
+
+ sqlTest(
+ "SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)",
+ (6 to 10).map(i => Row(i, i * 2)).toSeq)
+
testPushDown("SELECT * FROM oneToTenFiltered WHERE A = 1", 1)
testPushDown("SELECT a FROM oneToTenFiltered WHERE A = 1", 1)
testPushDown("SELECT b FROM oneToTenFiltered WHERE A = 1", 1)
@@ -162,6 +188,10 @@ class FilteredScanSuite extends DataSourceTest {
testPushDown("SELECT * FROM oneToTenFiltered WHERE a = 20", 0)
testPushDown("SELECT * FROM oneToTenFiltered WHERE b = 1", 10)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 5 AND a > 1", 3)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE a < 3 OR a > 8", 4)
+ testPushDown("SELECT * FROM oneToTenFiltered WHERE NOT (a < 6)", 5)
+
def testPushDown(sqlString: String, expectedCount: Int): Unit = {
test(s"PushDown Returns $expectedCount: $sqlString") {
val queryExecution = sql(sqlString).queryExecution
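
The translateFilter rewrite above handles Not, And, and Or by recursing into child filters instead of matching only leaf predicates. Below is a standalone sketch of that recursion over a small local filter ADT; the case classes are illustrative stand-ins, not Spark's org.apache.spark.sql.sources filter classes:

sealed trait SimpleFilter
case class EqualTo(attr: String, value: Int) extends SimpleFilter
case class GreaterThan(attr: String, value: Int) extends SimpleFilter
case class Not(child: SimpleFilter) extends SimpleFilter
case class And(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter
case class Or(left: SimpleFilter, right: SimpleFilter) extends SimpleFilter

object FilterEval {
  // Translate a filter tree into a predicate over the single integer column "a".
  def translate(filter: SimpleFilter): Int => Boolean = filter match {
    case EqualTo("a", v)     => a => a == v
    case GreaterThan("a", v) => a => a > v
    case Not(child)          => a => !translate(child)(a)
    case And(l, r)           => a => translate(l)(a) && translate(r)(a)
    case Or(l, r)            => a => translate(l)(a) || translate(r)(a)
    case _                   => _ => true // unknown predicate: keep the row, never drop it
  }

  def main(args: Array[String]): Unit = {
    // Mirrors "a > 1 AND NOT (a > 4) OR a = 9" over rows 1..10.
    val pred = translate(Or(And(GreaterThan("a", 1), Not(GreaterThan("a", 4))), EqualTo("a", 9)))
    println((1 to 10).filter(pred)) // Vector(2, 3, 4, 9)
  }
}
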
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
similarity index 78%
rename from sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
rename to sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index 36e504e759152..b5b16f9546691 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertIntoSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -21,11 +21,11 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{AnalysisException, Row}
import org.apache.spark.sql.catalyst.util
import org.apache.spark.util.Utils
-class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
+class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
import caseInsensisitiveContext._
@@ -129,6 +129,18 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
}
}
+ test("it is not allowed to write to a table while querying it.") {
+ val message = intercept[AnalysisException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE jsonTable SELECT a, b FROM jsonTable
+ """.stripMargin)
+ }.getMessage
+ assert(
+ message.contains("Cannot insert overwrite into table that is also being read from."),
+ "INSERT OVERWRITE to a table while querying it should not be allowed.")
+ }
+
test("Caching") {
// Cached Query Execution
cacheTable("jsonTable")
@@ -173,4 +185,34 @@ class InsertIntoSuite extends DataSourceTest with BeforeAndAfterAll {
uncacheTable("jsonTable")
assertCached(sql("SELECT * FROM jsonTable"), 0)
}
+
+ test("it's not allowed to insert into a relation that is not an InsertableRelation") {
+ sql(
+ """
+ |CREATE TEMPORARY TABLE oneToTen
+ |USING org.apache.spark.sql.sources.SimpleScanSource
+ |OPTIONS (
+ | From '1',
+ | To '10'
+ |)
+ """.stripMargin)
+
+ checkAnswer(
+ sql("SELECT * FROM oneToTen"),
+ (1 to 10).map(Row(_)).toSeq
+ )
+
+ val message = intercept[AnalysisException] {
+ sql(
+ s"""
+ |INSERT OVERWRITE TABLE oneToTen SELECT CAST(a AS INT) FROM jt
+ """.stripMargin)
+ }.getMessage
+ assert(
+ message.contains("does not allow insertion."),
+ "It is not allowed to insert into a table that is not an InsertableRelation."
+ )
+
+ dropTempTable("oneToTen")
+ }
}
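
Several of the new negative tests above share the same shape: run the offending statement inside intercept, read the message from the returned exception, and assert on a fragment of it. A tiny self-contained example of that ScalaTest idiom, using require as a stand-in for the real failure:

import org.scalatest.FunSuite

class InterceptMessageExample extends FunSuite {
  test("intercept returns the thrown exception for further assertions") {
    val message = intercept[IllegalArgumentException] {
      require(false, "does not allow insertion.")
    }.getMessage
    // `require` prefixes the message with "requirement failed: ", so check with contains.
    assert(message.contains("does not allow insertion."))
  }
}
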
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
new file mode 100644
index 0000000000000..8331a14c9295c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala
@@ -0,0 +1,34 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.sources
+
+import org.scalatest.FunSuite
+
+class ResolvedDataSourceSuite extends FunSuite {
+
+ test("builtin sources") {
+ assert(ResolvedDataSource.lookupDataSource("jdbc") ===
+ classOf[org.apache.spark.sql.jdbc.DefaultSource])
+
+ assert(ResolvedDataSource.lookupDataSource("json") ===
+ classOf[org.apache.spark.sql.json.DefaultSource])
+
+ assert(ResolvedDataSource.lookupDataSource("parquet") ===
+ classOf[org.apache.spark.sql.parquet.DefaultSource])
+ }
+}
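
The new suite above pins down that short names such as "jdbc", "json", and "parquet" resolve to the corresponding DefaultSource classes. A minimal sketch of such a lookup is shown below; the alias map and fallback rule are illustrative only, not Spark's actual resolution logic:

object ShortNameLookup {
  private val builtin = Map(
    "jdbc"    -> "org.apache.spark.sql.jdbc.DefaultSource",
    "json"    -> "org.apache.spark.sql.json.DefaultSource",
    "parquet" -> "org.apache.spark.sql.parquet.DefaultSource")

  // Built-in aliases win; anything else is treated as a package owning a DefaultSource.
  def lookup(provider: String): String =
    builtin.getOrElse(provider, s"$provider.DefaultSource")

  def main(args: Array[String]): Unit = {
    require(lookup("json") == "org.apache.spark.sql.json.DefaultSource")
    require(lookup("com.example.mysource") == "com.example.mysource.DefaultSource")
    println("lookups resolved as expected")
  }
}
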
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index fe2f76cc397f5..607488ccfdd6a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -21,10 +21,10 @@ import java.io.File
import org.scalatest.BeforeAndAfterAll
-import org.apache.spark.sql.DataFrame
-import org.apache.spark.util.Utils
-
import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.{SaveMode, SQLConf, DataFrame}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
@@ -38,42 +38,60 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
override def beforeAll(): Unit = {
originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
path = util.getTempFilePath("datasource").getCanonicalFile
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
df = jsonRDD(rdd)
+ df.registerTempTable("jsonTable")
}
override def afterAll(): Unit = {
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
}
after {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
if (path.exists()) Utils.deleteRecursively(path)
}
def checkLoad(): Unit = {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
checkAnswer(load(path.toString), df.collect())
- checkAnswer(load("org.apache.spark.sql.json", ("path", path.toString)), df.collect())
+
+ // Test if we can pick up the data source name passed in load.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ checkAnswer(load(path.toString, "org.apache.spark.sql.json"), df.collect())
+ checkAnswer(load("org.apache.spark.sql.json", Map("path" -> path.toString)), df.collect())
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ checkAnswer(
+ load("org.apache.spark.sql.json", schema, Map("path" -> path.toString)),
+ sql("SELECT b FROM jsonTable").collect())
}
- test("save with overwrite and load") {
+ test("save with path and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
df.save(path.toString)
- checkLoad
+ checkLoad()
+ }
+
+ test("save with path and datasource, and load") {
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
}
test("save with data source and options, and load") {
- df.save("org.apache.spark.sql.json", ("path", path.toString))
- checkLoad
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.save("org.apache.spark.sql.json", SaveMode.ErrorIfExists, Map("path" -> path.toString))
+ checkLoad()
}
test("save and save again") {
- df.save(path.toString)
+ df.save(path.toString, "org.apache.spark.sql.json")
- val message = intercept[RuntimeException] {
- df.save(path.toString)
+ var message = intercept[RuntimeException] {
+ df.save(path.toString, "org.apache.spark.sql.json")
}.getMessage
assert(
@@ -82,7 +100,18 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
if (path.exists()) Utils.deleteRecursively(path)
- df.save(path.toString)
- checkLoad
+ df.save(path.toString, "org.apache.spark.sql.json")
+ checkLoad()
+
+ df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path.toString))
+ checkLoad()
+
+ message = intercept[RuntimeException] {
+ df.save("org.apache.spark.sql.json", SaveMode.Append, Map("path" -> path.toString))
+ }.getMessage
+
+ assert(
+ message.contains("Append mode is not supported"),
+ "We should complain that 'Append mode is not supported' for JSON source.")
}
}
\ No newline at end of file
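
A hedged sketch of the save/load round trip SaveLoadSuite now exercises, using the 1.3-era DataFrame.save and SQLContext.load overloads that appear in the diff; the local master and the /tmp output path are assumptions made for the example:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.{SQLContext, SaveMode}

object SaveLoadExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("save-load").setMaster("local[2]"))
    val sqlContext = new SQLContext(sc)

    val rdd = sc.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
    val df = sqlContext.jsonRDD(rdd)

    val path = "/tmp/save-load-example-json"
    // Explicit source name plus SaveMode; Overwrite replaces any existing data at `path`.
    df.save("org.apache.spark.sql.json", SaveMode.Overwrite, Map("path" -> path))

    // Load it back with an explicit source name instead of relying on the default source.
    val loaded = sqlContext.load(path, "org.apache.spark.sql.json")
    loaded.collect().foreach(println)

    sc.stop()
  }
}
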
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
index 525777aa454c4..6e07df18b0e15 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.hive.thriftserver
import org.apache.commons.logging.LogFactory
-import org.apache.hadoop.hive.common.LogUtils
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService}
@@ -55,8 +54,6 @@ object HiveThriftServer2 extends Logging {
System.exit(-1)
}
- LogUtils.initHiveLog4j()
-
logInfo("Starting SparkContext")
SparkSQLEnv.init()
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
old mode 100755
new mode 100644
index bb19ac232fcbe..401e97b162dea
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -292,9 +292,13 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
}
}
+ var counter = 0
try {
while (!out.checkError() && driver.getResults(res)) {
- res.foreach(out.println)
+ res.foreach { l =>
+ counter += 1
+ out.println(l)
+ }
res.clear()
}
} catch {
@@ -311,7 +315,11 @@ private[hive] class SparkSQLCLIDriver extends CliDriver with Logging {
ret = cret
}
- console.printInfo(s"Time taken: $timeTaken seconds", null)
+ var responseMsg = s"Time taken: $timeTaken seconds"
+ if (counter != 0) {
+ responseMsg += s", Fetched $counter row(s)"
+ }
+ console.printInfo(responseMsg, null)
// Destroy the driver to release all the locks.
driver.destroy()
} else {
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 60953576d0e37..8bca4b33b3ad1 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -121,6 +121,6 @@ class CliSuite extends FunSuite with BeforeAndAfterAll with Logging {
}
test("Single command with -e") {
- runCliWithin(1.minute, Seq("-e", "SHOW TABLES;"))("" -> "OK")
+ runCliWithin(1.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK")
}
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
deleted file mode 100644
index b52a51d11e4ad..0000000000000
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suite.scala
+++ /dev/null
@@ -1,387 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.thriftserver
-
-import java.io.File
-import java.net.ServerSocket
-import java.sql.{Date, DriverManager, Statement}
-import java.util.concurrent.TimeoutException
-
-import scala.collection.JavaConversions._
-import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration._
-import scala.concurrent.{Await, Promise}
-import scala.sys.process.{Process, ProcessLogger}
-import scala.util.Try
-
-import org.apache.hadoop.hive.conf.HiveConf.ConfVars
-import org.apache.hive.jdbc.HiveDriver
-import org.apache.hive.service.auth.PlainSaslHelper
-import org.apache.hive.service.cli.GetInfoType
-import org.apache.hive.service.cli.thrift.TCLIService.Client
-import org.apache.hive.service.cli.thrift._
-import org.apache.thrift.protocol.TBinaryProtocol
-import org.apache.thrift.transport.TSocket
-import org.scalatest.FunSuite
-
-import org.apache.spark.Logging
-import org.apache.spark.sql.catalyst.util.getTempFilePath
-import org.apache.spark.sql.hive.HiveShim
-
-/**
- * Tests for the HiveThriftServer2 using JDBC.
- *
- * NOTE: SPARK_PREPEND_CLASSES is explicitly disabled in this test suite. Assembly jar must be
- * rebuilt after changing HiveThriftServer2 related code.
- */
-class HiveThriftServer2Suite extends FunSuite with Logging {
- Class.forName(classOf[HiveDriver].getCanonicalName)
-
- object TestData {
- def getTestDataFilePath(name: String) = {
- Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
- }
-
- val smallKv = getTestDataFilePath("small_kv.txt")
- val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
- }
-
- def randomListeningPort = {
- // Let the system to choose a random available port to avoid collision with other parallel
- // builds.
- val socket = new ServerSocket(0)
- val port = socket.getLocalPort
- socket.close()
- port
- }
-
- def withJdbcStatement(
- serverStartTimeout: FiniteDuration = 1.minute,
- httpMode: Boolean = false)(
- f: Statement => Unit) {
- val port = randomListeningPort
-
- startThriftServer(port, serverStartTimeout, httpMode) {
- val jdbcUri = if (httpMode) {
- s"jdbc:hive2://${"localhost"}:$port/" +
- "default?hive.server2.transport.mode=http;hive.server2.thrift.http.path=cliservice"
- } else {
- s"jdbc:hive2://${"localhost"}:$port/"
- }
-
- val user = System.getProperty("user.name")
- val connection = DriverManager.getConnection(jdbcUri, user, "")
- val statement = connection.createStatement()
-
- try {
- f(statement)
- } finally {
- statement.close()
- connection.close()
- }
- }
- }
-
- def withCLIServiceClient(
- serverStartTimeout: FiniteDuration = 1.minute)(
- f: ThriftCLIServiceClient => Unit) {
- val port = randomListeningPort
-
- startThriftServer(port) {
- // Transport creation logics below mimics HiveConnection.createBinaryTransport
- val rawTransport = new TSocket("localhost", port)
- val user = System.getProperty("user.name")
- val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
- val protocol = new TBinaryProtocol(transport)
- val client = new ThriftCLIServiceClient(new Client(protocol))
-
- transport.open()
-
- try {
- f(client)
- } finally {
- transport.close()
- }
- }
- }
-
- def startThriftServer(
- port: Int,
- serverStartTimeout: FiniteDuration = 1.minute,
- httpMode: Boolean = false)(
- f: => Unit) {
- val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
- val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
-
- val warehousePath = getTempFilePath("warehouse")
- val metastorePath = getTempFilePath("metastore")
- val metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
-
- val command =
- if (httpMode) {
- s"""$startScript
- | --master local
- | --hiveconf hive.root.logger=INFO,console
- | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
- | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
- | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=http
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT}=$port
- | --driver-class-path ${sys.props("java.class.path")}
- | --conf spark.ui.enabled=false
- """.stripMargin.split("\\s+").toSeq
- } else {
- s"""$startScript
- | --master local
- | --hiveconf hive.root.logger=INFO,console
- | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
- | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
- | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_PORT}=$port
- | --driver-class-path ${sys.props("java.class.path")}
- | --conf spark.ui.enabled=false
- """.stripMargin.split("\\s+").toSeq
- }
-
- val serverRunning = Promise[Unit]()
- val buffer = new ArrayBuffer[String]()
- val LOGGING_MARK =
- s"starting ${HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")}, logging to "
- var logTailingProcess: Process = null
- var logFilePath: String = null
-
- def captureLogOutput(line: String): Unit = {
- buffer += line
- if (line.contains("ThriftBinaryCLIService listening on") ||
- line.contains("Started ThriftHttpCLIService in http")) {
- serverRunning.success(())
- }
- }
-
- def captureThriftServerOutput(source: String)(line: String): Unit = {
- if (line.startsWith(LOGGING_MARK)) {
- logFilePath = line.drop(LOGGING_MARK.length).trim
- // Ensure that the log file is created so that the `tail' command won't fail
- Try(new File(logFilePath).createNewFile())
- logTailingProcess = Process(s"/usr/bin/env tail -f $logFilePath")
- .run(ProcessLogger(captureLogOutput, _ => ()))
- }
- }
-
- val env = Seq(
- // Resets SPARK_TESTING to avoid loading Log4J configurations in testing class paths
- "SPARK_TESTING" -> "0")
-
- Process(command, None, env: _*).run(ProcessLogger(
- captureThriftServerOutput("stdout"),
- captureThriftServerOutput("stderr")))
-
- try {
- Await.result(serverRunning.future, serverStartTimeout)
- f
- } catch {
- case cause: Exception =>
- cause match {
- case _: TimeoutException =>
- logError(s"Failed to start Hive Thrift server within $serverStartTimeout", cause)
- case _ =>
- }
- logError(
- s"""
- |=====================================
- |HiveThriftServer2Suite failure output
- |=====================================
- |HiveThriftServer2 command line: ${command.mkString(" ")}
- |Binding port: $port
- |System user: ${System.getProperty("user.name")}
- |
- |${buffer.mkString("\n")}
- |=========================================
- |End HiveThriftServer2Suite failure output
- |=========================================
- """.stripMargin, cause)
- throw cause
- } finally {
- warehousePath.delete()
- metastorePath.delete()
- Process(stopScript, None, env: _*).run().exitValue()
- // The `spark-daemon.sh' script uses kill, which is not synchronous, have to wait for a while.
- Thread.sleep(3.seconds.toMillis)
- Option(logTailingProcess).map(_.destroy())
- Option(logFilePath).map(new File(_).delete())
- }
- }
-
- test("Test JDBC query execution") {
- withJdbcStatement() { statement =>
- val queries = Seq(
- "SET spark.sql.shuffle.partitions=3",
- "DROP TABLE IF EXISTS test",
- "CREATE TABLE test(key INT, val STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
- "CACHE TABLE test")
-
- queries.foreach(statement.execute)
-
- assertResult(5, "Row count mismatch") {
- val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
- resultSet.next()
- resultSet.getInt(1)
- }
- }
- }
-
- test("Test JDBC query execution in Http Mode") {
- withJdbcStatement(httpMode = true) { statement =>
- val queries = Seq(
- "SET spark.sql.shuffle.partitions=3",
- "DROP TABLE IF EXISTS test",
- "CREATE TABLE test(key INT, val STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
- "CACHE TABLE test")
-
- queries.foreach(statement.execute)
-
- assertResult(5, "Row count mismatch") {
- val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
- resultSet.next()
- resultSet.getInt(1)
- }
- }
- }
-
- test("SPARK-3004 regression: result set containing NULL") {
- withJdbcStatement() { statement =>
- val queries = Seq(
- "DROP TABLE IF EXISTS test_null",
- "CREATE TABLE test_null(key INT, val STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null")
-
- queries.foreach(statement.execute)
-
- val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
-
- (0 until 5).foreach { _ =>
- resultSet.next()
- assert(resultSet.getInt(1) === 0)
- assert(resultSet.wasNull())
- }
-
- assert(!resultSet.next())
- }
- }
-
- test("GetInfo Thrift API") {
- withCLIServiceClient() { client =>
- val user = System.getProperty("user.name")
- val sessionHandle = client.openSession(user, "")
-
- assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
- client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
- }
-
- assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
- client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
- }
-
- assertResult(true, "Spark version shouldn't be \"Unknown\"") {
- val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
- logInfo(s"Spark version: $version")
- version != "Unknown"
- }
- }
- }
-
- test("Checks Hive version") {
- withJdbcStatement() { statement =>
- val resultSet = statement.executeQuery("SET spark.sql.hive.version")
- resultSet.next()
- assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
- }
- }
-
- test("Checks Hive version in Http Mode") {
- withJdbcStatement(httpMode = true) { statement =>
- val resultSet = statement.executeQuery("SET spark.sql.hive.version")
- resultSet.next()
- assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
- }
- }
-
- test("SPARK-4292 regression: result set iterator issue") {
- withJdbcStatement() { statement =>
- val queries = Seq(
- "DROP TABLE IF EXISTS test_4292",
- "CREATE TABLE test_4292(key INT, val STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292")
-
- queries.foreach(statement.execute)
-
- val resultSet = statement.executeQuery("SELECT key FROM test_4292")
-
- Seq(238, 86, 311, 27, 165).foreach { key =>
- resultSet.next()
- assert(resultSet.getInt(1) === key)
- }
-
- statement.executeQuery("DROP TABLE IF EXISTS test_4292")
- }
- }
-
- test("SPARK-4309 regression: Date type support") {
- withJdbcStatement() { statement =>
- val queries = Seq(
- "DROP TABLE IF EXISTS test_date",
- "CREATE TABLE test_date(key INT, value STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date")
-
- queries.foreach(statement.execute)
-
- assertResult(Date.valueOf("2011-01-01")) {
- val resultSet = statement.executeQuery(
- "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1")
- resultSet.next()
- resultSet.getDate(1)
- }
- }
- }
-
- test("SPARK-4407 regression: Complex type support") {
- withJdbcStatement() { statement =>
- val queries = Seq(
- "DROP TABLE IF EXISTS test_map",
- "CREATE TABLE test_map(key INT, value STRING)",
- s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
-
- queries.foreach(statement.execute)
-
- assertResult("""{238:"val_238"}""") {
- val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1")
- resultSet.next()
- resultSet.getString(1)
- }
-
- assertResult("""["238","val_238"]""") {
- val resultSet = statement.executeQuery(
- "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1")
- resultSet.next()
- resultSet.getString(1)
- }
- }
- }
-}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
new file mode 100644
index 0000000000000..d783d487b5c60
--- /dev/null
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala
@@ -0,0 +1,412 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive.thriftserver
+
+import java.io.File
+import java.sql.{Date, DriverManager, Statement}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.concurrent.duration._
+import scala.concurrent.{Await, Promise}
+import scala.sys.process.{Process, ProcessLogger}
+import scala.util.{Random, Try}
+
+import org.apache.hadoop.hive.conf.HiveConf.ConfVars
+import org.apache.hive.jdbc.HiveDriver
+import org.apache.hive.service.auth.PlainSaslHelper
+import org.apache.hive.service.cli.GetInfoType
+import org.apache.hive.service.cli.thrift.TCLIService.Client
+import org.apache.hive.service.cli.thrift.ThriftCLIServiceClient
+import org.apache.thrift.protocol.TBinaryProtocol
+import org.apache.thrift.transport.TSocket
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.util
+import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.util.Utils
+
+object TestData {
+ def getTestDataFilePath(name: String) = {
+ Thread.currentThread().getContextClassLoader.getResource(s"data/files/$name")
+ }
+
+ val smallKv = getTestDataFilePath("small_kv.txt")
+ val smallKvWithNull = getTestDataFilePath("small_kv_with_null.txt")
+}
+
+class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest {
+ override def mode = ServerMode.binary
+
+ private def withCLIServiceClient(f: ThriftCLIServiceClient => Unit): Unit = {
+ // Transport creation logics below mimics HiveConnection.createBinaryTransport
+ val rawTransport = new TSocket("localhost", serverPort)
+ val user = System.getProperty("user.name")
+ val transport = PlainSaslHelper.getPlainTransport(user, "anonymous", rawTransport)
+ val protocol = new TBinaryProtocol(transport)
+ val client = new ThriftCLIServiceClient(new Client(protocol))
+
+ transport.open()
+ try f(client) finally transport.close()
+ }
+
+ test("GetInfo Thrift API") {
+ withCLIServiceClient { client =>
+ val user = System.getProperty("user.name")
+ val sessionHandle = client.openSession(user, "")
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_DBMS_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_NAME).getStringValue
+ }
+
+ assertResult("Spark SQL", "Wrong GetInfo(CLI_SERVER_NAME) result") {
+ client.getInfo(sessionHandle, GetInfoType.CLI_SERVER_NAME).getStringValue
+ }
+
+ assertResult(true, "Spark version shouldn't be \"Unknown\"") {
+ val version = client.getInfo(sessionHandle, GetInfoType.CLI_DBMS_VER).getStringValue
+ logInfo(s"Spark version: $version")
+ version != "Unknown"
+ }
+ }
+ }
+
+ test("JDBC query execution") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "SET spark.sql.shuffle.partitions=3",
+ "DROP TABLE IF EXISTS test",
+ "CREATE TABLE test(key INT, val STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+ "CACHE TABLE test")
+
+ queries.foreach(statement.execute)
+
+ assertResult(5, "Row count mismatch") {
+ val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+ resultSet.next()
+ resultSet.getInt(1)
+ }
+ }
+ }
+
+ test("Checks Hive version") {
+ withJdbcStatement { statement =>
+ val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+ resultSet.next()
+ assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ }
+ }
+
+ test("SPARK-3004 regression: result set containing NULL") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "DROP TABLE IF EXISTS test_null",
+ "CREATE TABLE test_null(key INT, val STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKvWithNull}' OVERWRITE INTO TABLE test_null")
+
+ queries.foreach(statement.execute)
+
+ val resultSet = statement.executeQuery("SELECT * FROM test_null WHERE key IS NULL")
+
+ (0 until 5).foreach { _ =>
+ resultSet.next()
+ assert(resultSet.getInt(1) === 0)
+ assert(resultSet.wasNull())
+ }
+
+ assert(!resultSet.next())
+ }
+ }
+
+ test("SPARK-4292 regression: result set iterator issue") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "DROP TABLE IF EXISTS test_4292",
+ "CREATE TABLE test_4292(key INT, val STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_4292")
+
+ queries.foreach(statement.execute)
+
+ val resultSet = statement.executeQuery("SELECT key FROM test_4292")
+
+ Seq(238, 86, 311, 27, 165).foreach { key =>
+ resultSet.next()
+ assert(resultSet.getInt(1) === key)
+ }
+
+ statement.executeQuery("DROP TABLE IF EXISTS test_4292")
+ }
+ }
+
+ test("SPARK-4309 regression: Date type support") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "DROP TABLE IF EXISTS test_date",
+ "CREATE TABLE test_date(key INT, value STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_date")
+
+ queries.foreach(statement.execute)
+
+ assertResult(Date.valueOf("2011-01-01")) {
+ val resultSet = statement.executeQuery(
+ "SELECT CAST('2011-01-01' as date) FROM test_date LIMIT 1")
+ resultSet.next()
+ resultSet.getDate(1)
+ }
+ }
+ }
+
+ test("SPARK-4407 regression: Complex type support") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "DROP TABLE IF EXISTS test_map",
+ "CREATE TABLE test_map(key INT, value STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map")
+
+ queries.foreach(statement.execute)
+
+ assertResult("""{238:"val_238"}""") {
+ val resultSet = statement.executeQuery("SELECT MAP(key, value) FROM test_map LIMIT 1")
+ resultSet.next()
+ resultSet.getString(1)
+ }
+
+ assertResult("""["238","val_238"]""") {
+ val resultSet = statement.executeQuery(
+ "SELECT ARRAY(CAST(key AS STRING), value) FROM test_map LIMIT 1")
+ resultSet.next()
+ resultSet.getString(1)
+ }
+ }
+ }
+}
+
+class HiveThriftHttpServerSuite extends HiveThriftJdbcTest {
+ override def mode = ServerMode.http
+
+ test("JDBC query execution") {
+ withJdbcStatement { statement =>
+ val queries = Seq(
+ "SET spark.sql.shuffle.partitions=3",
+ "DROP TABLE IF EXISTS test",
+ "CREATE TABLE test(key INT, val STRING)",
+ s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test",
+ "CACHE TABLE test")
+
+ queries.foreach(statement.execute)
+
+ assertResult(5, "Row count mismatch") {
+ val resultSet = statement.executeQuery("SELECT COUNT(*) FROM test")
+ resultSet.next()
+ resultSet.getInt(1)
+ }
+ }
+ }
+
+ test("Checks Hive version") {
+ withJdbcStatement { statement =>
+ val resultSet = statement.executeQuery("SET spark.sql.hive.version")
+ resultSet.next()
+ assert(resultSet.getString(1) === s"spark.sql.hive.version=${HiveShim.version}")
+ }
+ }
+}
+
+object ServerMode extends Enumeration {
+ val binary, http = Value
+}
+
+abstract class HiveThriftJdbcTest extends HiveThriftServer2Test {
+ Class.forName(classOf[HiveDriver].getCanonicalName)
+
+ private def jdbcUri = if (mode == ServerMode.http) {
+ s"""jdbc:hive2://localhost:$serverPort/
+ |default?
+ |hive.server2.transport.mode=http;
+ |hive.server2.thrift.http.path=cliservice
+ """.stripMargin.split("\n").mkString.trim
+ } else {
+ s"jdbc:hive2://localhost:$serverPort/"
+ }
+
+ protected def withJdbcStatement(f: Statement => Unit): Unit = {
+ val connection = DriverManager.getConnection(jdbcUri, user, "")
+ val statement = connection.createStatement()
+
+ try f(statement) finally {
+ statement.close()
+ connection.close()
+ }
+ }
+}
+
+abstract class HiveThriftServer2Test extends FunSuite with BeforeAndAfterAll with Logging {
+ def mode: ServerMode.Value
+
+ private val CLASS_NAME = HiveThriftServer2.getClass.getCanonicalName.stripSuffix("$")
+ private val LOG_FILE_MARK = s"starting $CLASS_NAME, logging to "
+
+ private val startScript = "../../sbin/start-thriftserver.sh".split("/").mkString(File.separator)
+ private val stopScript = "../../sbin/stop-thriftserver.sh".split("/").mkString(File.separator)
+
+ private var listeningPort: Int = _
+ protected def serverPort: Int = listeningPort
+
+ protected def user = System.getProperty("user.name")
+
+ private var warehousePath: File = _
+ private var metastorePath: File = _
+ private def metastoreJdbcUri = s"jdbc:derby:;databaseName=$metastorePath;create=true"
+
+ private val pidDir: File = Utils.createTempDir("thriftserver-pid")
+ private var logPath: File = _
+ private var logTailingProcess: Process = _
+ private var diagnosisBuffer: ArrayBuffer[String] = ArrayBuffer.empty[String]
+
+ private def serverStartCommand(port: Int) = {
+ val portConf = if (mode == ServerMode.binary) {
+ ConfVars.HIVE_SERVER2_THRIFT_PORT
+ } else {
+ ConfVars.HIVE_SERVER2_THRIFT_HTTP_PORT
+ }
+
+ s"""$startScript
+ | --master local
+ | --hiveconf hive.root.logger=INFO,console
+ | --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$metastoreJdbcUri
+ | --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
+ | --hiveconf ${ConfVars.HIVE_SERVER2_THRIFT_BIND_HOST}=localhost
+ | --hiveconf ${ConfVars.HIVE_SERVER2_TRANSPORT_MODE}=$mode
+ | --hiveconf $portConf=$port
+ | --driver-class-path ${sys.props("java.class.path")}
+ | --conf spark.ui.enabled=false
+ """.stripMargin.split("\\s+").toSeq
+ }
+
+ private def startThriftServer(port: Int, attempt: Int) = {
+ warehousePath = util.getTempFilePath("warehouse")
+ metastorePath = util.getTempFilePath("metastore")
+ logPath = null
+ logTailingProcess = null
+
+ val command = serverStartCommand(port)
+
+ diagnosisBuffer ++=
+ s"""
+ |### Attempt $attempt ###
+ |HiveThriftServer2 command line: $command
+ |Listening port: $port
+ |System user: $user
+ """.stripMargin.split("\n")
+
+ logInfo(s"Trying to start HiveThriftServer2: port=$port, mode=$mode, attempt=$attempt")
+
+ val env = Seq(
+ // Disables SPARK_TESTING to exclude log4j.properties in test directories.
+ "SPARK_TESTING" -> "0",
+ // Points SPARK_PID_DIR to a dedicated temporary directory, otherwise only one Thrift server
+ // instance can be started at a time, which is not Jenkins friendly.
+ "SPARK_PID_DIR" -> pidDir.getCanonicalPath)
+
+ logPath = Process(command, None, env: _*).lines.collectFirst {
+ case line if line.contains(LOG_FILE_MARK) => new File(line.drop(LOG_FILE_MARK.length))
+ }.getOrElse {
+ throw new RuntimeException("Failed to find HiveThriftServer2 log file.")
+ }
+
+ val serverStarted = Promise[Unit]()
+
+ // Ensures that the following "tail" command won't fail.
+ logPath.createNewFile()
+ logTailingProcess =
+ // Using "-n +0" to make sure all lines in the log file are checked.
+ Process(s"/usr/bin/env tail -n +0 -f ${logPath.getCanonicalPath}").run(ProcessLogger(
+ (line: String) => {
+ diagnosisBuffer += line
+
+ if (line.contains("ThriftBinaryCLIService listening on") ||
+ line.contains("Started ThriftHttpCLIService in http")) {
+ serverStarted.trySuccess(())
+ } else if (line.contains("HiveServer2 is stopped")) {
+ // This log line appears when the server fails to start and terminates gracefully (e.g.
+ // because of port contention).
+ serverStarted.tryFailure(new RuntimeException("Failed to start HiveThriftServer2"))
+ }
+ }))
+
+ Await.result(serverStarted.future, 2.minute)
+ }
+
+ private def stopThriftServer(): Unit = {
+ // The `spark-daemon.sh` script uses kill, which is not synchronous, so we have to wait for a while.
+ Process(stopScript, None, "SPARK_PID_DIR" -> pidDir.getCanonicalPath).run().exitValue()
+ Thread.sleep(3.seconds.toMillis)
+
+ warehousePath.delete()
+ warehousePath = null
+
+ metastorePath.delete()
+ metastorePath = null
+
+ Option(logPath).foreach(_.delete())
+ logPath = null
+
+ Option(logTailingProcess).foreach(_.destroy())
+ logTailingProcess = null
+ }
+
+ private def dumpLogs(): Unit = {
+ logError(
+ s"""
+ |=====================================
+ |HiveThriftServer2Suite failure output
+ |=====================================
+ |${diagnosisBuffer.mkString("\n")}
+ |=========================================
+ |End HiveThriftServer2Suite failure output
+ |=========================================
+ """.stripMargin)
+ }
+
+ override protected def beforeAll(): Unit = {
+ // Chooses a random port between 10000 and 19999
+ listeningPort = 10000 + Random.nextInt(10000)
+ diagnosisBuffer.clear()
+
+ // Retries up to 3 times with different port numbers if the server fails to start
+ (1 to 3).foldLeft(Try(startThriftServer(listeningPort, 0))) { case (started, attempt) =>
+ started.orElse {
+ listeningPort += 1
+ stopThriftServer()
+ Try(startThriftServer(listeningPort, attempt))
+ }
+ }.recover {
+ case cause: Throwable =>
+ dumpLogs()
+ throw cause
+ }.get
+
+ logInfo(s"HiveThriftServer2 started successfully")
+ }
+
+ override protected def afterAll(): Unit = {
+ stopThriftServer()
+ logInfo("HiveThriftServer2 stopped")
+ }
+}
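
The beforeAll above retries server startup by folding attempts over Try and bumping the port whenever a start fails. The same control flow in isolation, with a dummy start method that rejects even ports purely for demonstration:

import scala.util.{Failure, Success, Try}

object RetryWithNewPort {
  // Stand-in for a server start that only succeeds on certain ports.
  private def start(port: Int): Unit =
    if (port % 2 == 0) throw new RuntimeException(s"port $port unavailable")

  def main(args: Array[String]): Unit = {
    var port = 10000
    // Same shape as the suite above: fold retries over Try, bump the port on each failure.
    val result = (1 to 3).foldLeft(Try(start(port))) { case (started, _) =>
      started.orElse {
        port += 1
        Try(start(port))
      }
    }
    result match {
      case Success(_) => println(s"started on port $port")
      case Failure(e) => println(s"gave up: ${e.getMessage}")
    }
  }
}
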
diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
index ea9d61d8d0f5e..13116b40bb259 100644
--- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
+++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala
@@ -185,6 +185,10 @@ private[hive] class SparkExecuteStatementOperation(
def run(): Unit = {
logInfo(s"Running query '$statement'")
setState(OperationState.RUNNING)
+ hiveContext.sparkContext.setJobDescription(statement)
+ sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
+ hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
+ }
try {
result = hiveContext.sql(statement)
logDebug(result.queryExecution.toString())
@@ -194,10 +198,6 @@ private[hive] class SparkExecuteStatementOperation(
logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
case _ =>
}
- hiveContext.sparkContext.setJobDescription(statement)
- sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
- hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
- }
iter = {
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
index 71e3954b2c7ac..9b8faeff94eab 100644
--- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
+++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala
@@ -156,6 +156,10 @@ private[hive] class SparkExecuteStatementOperation(
def run(): Unit = {
logInfo(s"Running query '$statement'")
setState(OperationState.RUNNING)
+ hiveContext.sparkContext.setJobDescription(statement)
+ sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
+ hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
+ }
try {
result = hiveContext.sql(statement)
logDebug(result.queryExecution.toString())
@@ -165,10 +169,6 @@ private[hive] class SparkExecuteStatementOperation(
logInfo(s"Setting spark.scheduler.pool=$value for future statements in this session.")
case _ =>
}
- hiveContext.sparkContext.setJobDescription(statement)
- sessionToActivePool.get(parentSession.getSessionHandle).foreach { pool =>
- hiveContext.sparkContext.setLocalProperty("spark.scheduler.pool", pool)
- }
iter = {
val useIncrementalCollect =
hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index a6266f611c219..c6ead4562d51e 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -225,6 +225,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// Needs constant object inspectors
"udf_round",
+ // The table src(key INT, value STRING) is not the same as in the Hive unit tests, where
+ // it is src(key STRING, value STRING). In reflect.q this makes the Integer.valueOf call
+ // fail, because it expects its first argument to be a STRING, not an INT.
+ "udf_reflect",
+
// Sort with Limit clause causes failure.
"ctas",
"ctas_hadoop20",
@@ -518,6 +523,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"inputddl2",
"inputddl3",
"inputddl4",
+ "inputddl5",
"inputddl6",
"inputddl7",
"inputddl8",
@@ -639,6 +645,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"nonblock_op_deduplicate",
"notable_alias1",
"notable_alias2",
+ "nullformatCTAS",
"nullgroup",
"nullgroup2",
"nullgroup3",
@@ -884,6 +891,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_power",
"udf_radians",
"udf_rand",
+ "udf_reflect2",
"udf_regexp",
"udf_regexp_extract",
"udf_regexp_replace",
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 58b0722464be8..72c474d66055c 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -84,6 +84,11 @@
<artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 2c00659496972..c439dfe0a71f8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -22,26 +22,24 @@ import java.sql.Timestamp
import scala.collection.JavaConversions._
import scala.language.implicitConversions
-import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.hive.conf.HiveConf
import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.metadata.Table
-import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.parse.VariableSubstitution
+import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable}
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateAnalysisOperators, OverrideCatalog, OverrideFunctionRegistry}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, EliminateSubQueries, OverrideCatalog, OverrideFunctionRegistry}
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand, QueryExecutionException}
-import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DescribeHiveTableCommand}
-import org.apache.spark.sql.sources.{CreateTableUsing, DataSourceStrategy}
+import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, QueryExecutionException, SetCommand}
+import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand}
+import org.apache.spark.sql.sources.{DDLParser, DataSourceStrategy}
import org.apache.spark.sql.types._
/**
@@ -63,34 +61,40 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
protected[sql] def convertMetastoreParquet: Boolean =
getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true"
+ /**
+ * When true, a table created by a Hive CTAS statement (no USING clause) will be
+ * converted to a data source table, using the data source set by spark.sql.sources.default.
+ * The table created by the CTAS statement will be converted when it meets any of the following conditions:
+ * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or
+ * a Storage Handler (STORED BY), and the value of hive.default.fileformat in hive-site.xml
+ * is either TextFile or SequenceFile.
+ * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe
+ * is specified (no ROW FORMAT SERDE clause).
+ * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format
+ * and no SerDe is specified (no ROW FORMAT SERDE clause).
+ */
+ protected[sql] def convertCTAS: Boolean =
+ getConf("spark.sql.hive.convertCTAS", "false").toBoolean
+
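For illustration, a minimal usage sketch of the convertCTAS flag introduced above (not part of the patch; assumes a local build of this branch and that the Hive test table `src` exists):

    // Hedged sketch: with spark.sql.hive.convertCTAS enabled, a plain CTAS (no SerDe,
    // STORED AS or STORED BY clause) is written through the default data source.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.hive.HiveContext

    object ConvertCtasSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("convertCTAS-sketch").setMaster("local"))
        val hiveContext = new HiveContext(sc)
        hiveContext.setConf("spark.sql.hive.convertCTAS", "true")
        // Planned as CreateTableUsingAsSelect over the default data source instead of a
        // Hive CreateTableAsSelect.
        hiveContext.sql("CREATE TABLE ctas_target AS SELECT key, value FROM src")
        sc.stop()
      }
    }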
override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution =
new this.QueryExecution(plan)
+ @transient
+ protected[sql] val ddlParserWithHiveQL = new DDLParser(HiveQl.parseSql(_))
+
override def sql(sqlText: String): DataFrame = {
val substituted = new VariableSubstitution().substitute(hiveconf, sqlText)
// TODO: Create a framework for registering parsers instead of just hardcoding if statements.
if (conf.dialect == "sql") {
super.sql(substituted)
} else if (conf.dialect == "hiveql") {
- DataFrame(this,
- ddlParser(sqlText, exceptionOnError = false).getOrElse(HiveQl.parseSql(substituted)))
+ val ddlPlan = ddlParserWithHiveQL(sqlText, exceptionOnError = false)
+ DataFrame(this, ddlPlan.getOrElse(HiveQl.parseSql(substituted)))
} else {
sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'")
}
}
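The dispatch above first offers the statement to a DDLParser constructed with HiveQl.parseSql as its fallback, and only hands the full statement to HiveQl if the DDL parser declines it. A toy, self-contained sketch of that try-then-fall-back pattern (plain Scala; the names and the string "plans" are purely illustrative):

    object FallbackParserSketch extends App {
      type Plan = String

      // Stand-in for ddlParserWithHiveQL(sql, exceptionOnError = false): recognizes only DDL.
      def ddlParse(sql: String): Option[Plan] =
        if (sql.trim.toUpperCase.startsWith("CREATE TEMPORARY TABLE")) Some(s"DDL($sql)") else None

      // Stand-in for HiveQl.parseSql.
      def hiveQlParse(sql: String): Plan = s"HiveQL($sql)"

      def parse(sql: String): Plan = ddlParse(sql).getOrElse(hiveQlParse(sql))

      println(parse("CREATE TEMPORARY TABLE t USING json OPTIONS (path '/tmp/x')")) // DDL(...)
      println(parse("SELECT key FROM src"))                                         // HiveQL(...)
    }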
- /**
- * Creates a table using the schema of the given class.
- *
- * @param tableName The name of the table to create.
- * @param allowExisting When false, an exception will be thrown if the table already exists.
- * @tparam A A case class that is used to describe the schema of the table to be created.
- */
- @Deprecated
- def createTable[A <: Product : TypeTag](tableName: String, allowExisting: Boolean = true) {
- catalog.createTable("default", tableName, ScalaReflection.attributesFor[A], allowExisting)
- }
-
/**
* Invalidate and refresh all the cached metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
@@ -107,70 +111,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.invalidateTable("default", tableName)
}
- @Experimental
- def createTable(tableName: String, path: String, allowExisting: Boolean): Unit = {
- val dataSourceName = conf.defaultDataSourceName
- createTable(tableName, dataSourceName, allowExisting, ("path", path))
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = None,
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- option: (String, String),
- options: (String, String)*): Unit = {
- val cmd =
- CreateTableUsing(
- tableName,
- userSpecifiedSchema = Some(schema),
- dataSourceName,
- temporary = false,
- (option +: options).toMap,
- allowExisting)
- executePlan(cmd).toRdd
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, allowExisting, opts.head, opts.tail:_*)
- }
-
- @Experimental
- def createTable(
- tableName: String,
- dataSourceName: String,
- schema: StructType,
- allowExisting: Boolean,
- options: java.util.Map[String, String]): Unit = {
- val opts = options.toSeq
- createTable(tableName, dataSourceName, schema, allowExisting, opts.head, opts.tail:_*)
- }
-
/**
* Analyzes the given table in the current database to generate statistics, which will be
* used in query optimizations.
@@ -180,7 +120,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
*/
@Experimental
def analyze(tableName: String) {
- val relation = EliminateAnalysisOperators(catalog.lookupRelation(Seq(tableName)))
+ val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName)))
relation match {
case relation: MetastoreRelation =>
@@ -282,22 +222,25 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
* SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be
* set in the SQLConf *as well as* in the HiveConf.
*/
- @transient protected[hive] lazy val (hiveconf, sessionState) =
- Option(SessionState.get())
- .orElse {
- val newState = new SessionState(new HiveConf(classOf[SessionState]))
- // Only starts newly created `SessionState` instance. Any existing `SessionState` instance
- // returned by `SessionState.get()` must be the most recently started one.
- SessionState.start(newState)
- Some(newState)
- }
- .map { state =>
- setConf(state.getConf.getAllProperties)
- if (state.out == null) state.out = new PrintStream(outputBuffer, true, "UTF-8")
- if (state.err == null) state.err = new PrintStream(outputBuffer, true, "UTF-8")
- (state.getConf, state)
- }
- .get
+ @transient protected[hive] lazy val sessionState: SessionState = {
+ var state = SessionState.get()
+ if (state == null) {
+ state = new SessionState(new HiveConf(classOf[SessionState]))
+ SessionState.start(state)
+ }
+ if (state.out == null) {
+ state.out = new PrintStream(outputBuffer, true, "UTF-8")
+ }
+ if (state.err == null) {
+ state.err = new PrintStream(outputBuffer, true, "UTF-8")
+ }
+ state
+ }
+
+ @transient protected[hive] lazy val hiveconf: HiveConf = {
+ setConf(sessionState.getConf.getAllProperties)
+ sessionState.getConf
+ }
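The refactoring above splits the former `(hiveconf, sessionState)` tuple into two lazy vals, so touching `hiveconf` transparently starts the `SessionState` first. A toy sketch of that dependent lazy-initialization pattern (plain Scala, no Spark required):

    object LazyInitSketch extends App {
      class Session { val conf = new java.util.Properties() }

      lazy val session: Session = {
        println("starting session")            // runs exactly once, on first access
        new Session()
      }

      // Forcing `conf` forces `session` first, mirroring hiveconf -> sessionState.
      lazy val conf: java.util.Properties = session.conf

      println(conf.isEmpty)                    // prints "starting session", then "true"
    }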
override def setConf(key: String, value: String): Unit = {
super.setConf(key, value)
@@ -319,7 +262,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
@transient
override protected[sql] lazy val analyzer =
new Analyzer(catalog, functionRegistry, caseSensitive = false) {
- override val extendedRules =
+ override val extendedResolutionRules =
+ catalog.ParquetConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index c78369d12cf55..d3ad364328265 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -20,26 +20,26 @@ package org.apache.spark.sql.hive
import java.io.IOException
import java.util.{List => JList}
-import com.google.common.cache.{LoadingCache, CacheLoader, CacheBuilder}
-
-import org.apache.hadoop.util.ReflectionUtils
-import org.apache.hadoop.hive.metastore.{Warehouse, TableType}
-import org.apache.hadoop.hive.metastore.api.{Table => TTable, Partition => TPartition, FieldSchema}
+import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
+import org.apache.hadoop.hive.metastore.api.{FieldSchema, Partition => TPartition, Table => TTable}
+import org.apache.hadoop.hive.metastore.{TableType, Warehouse}
import org.apache.hadoop.hive.ql.metadata._
import org.apache.hadoop.hive.ql.plan.CreateTableDesc
import org.apache.hadoop.hive.serde.serdeConstants
-import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
+import org.apache.hadoop.hive.serde2.{Deserializer, SerDeException}
+import org.apache.hadoop.util.ReflectionUtils
import org.apache.spark.Logging
-import org.apache.spark.sql.SQLContext
-import org.apache.spark.sql.catalyst.analysis.{Catalog, OverrideCatalog}
+import org.apache.spark.sql.{SaveMode, AnalysisException, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, Catalog, OverrideCatalog}
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.parquet.ParquetRelation2
-import org.apache.spark.sql.sources.{DDLParser, LogicalRelation, ResolvedDataSource}
+import org.apache.spark.sql.parquet.{ParquetRelation2, Partition => ParquetPartition, PartitionSpec}
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, DDLParser, LogicalRelation, ResolvedDataSource}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -52,6 +52,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
/** Connection to hive metastore. Usages should lock on `this`. */
protected[hive] val client = Hive.get(hive.hiveconf)
+ /** Usages should lock on `this`. */
protected[hive] lazy val hiveWarehouse = new Warehouse(hive.hiveconf)
// TODO: Use this everywhere instead of tuples or databaseName, tableName,.
@@ -65,14 +66,26 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val cacheLoader = new CacheLoader[QualifiedTableName, LogicalPlan]() {
override def load(in: QualifiedTableName): LogicalPlan = {
logDebug(s"Creating new cached data source for $in")
- val table = client.getTable(in.database, in.name)
- val schemaString = table.getProperty("spark.sql.sources.schema")
+ val table = synchronized {
+ client.getTable(in.database, in.name)
+ }
val userSpecifiedSchema =
- if (schemaString == null) {
- None
- } else {
- Some(DataType.fromJson(schemaString).asInstanceOf[StructType])
+ Option(table.getProperty("spark.sql.sources.schema.numParts")).map { numParts =>
+ val parts = (0 until numParts.toInt).map { index =>
+ val part = table.getProperty(s"spark.sql.sources.schema.part.${index}")
+ if (part == null) {
+ throw new AnalysisException(
+ s"Could not read schema from the metastore because it is corrupted " +
+ s"(missing part ${index} of the schema).")
+ }
+
+ part
+ }
+ // Stick all parts back to a single schema string in the JSON representation
+ // and convert it back to a StructType.
+ DataType.fromJson(parts.mkString).asInstanceOf[StructType]
}
+
// It does not appear that the ql client for the metastore has a way to enumerate all the
// SerDe properties directly...
val options = table.getTTable.getSd.getSerdeInfo.getParameters.toMap
@@ -91,7 +104,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
CacheBuilder.newBuilder().maximumSize(1000).build(cacheLoader)
}
- def refreshTable(databaseName: String, tableName: String): Unit = {
+ override def refreshTable(databaseName: String, tableName: String): Unit = {
cachedDataSourceTables.refresh(QualifiedTableName(databaseName, tableName).toLowerCase)
}
@@ -101,16 +114,10 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val caseSensitive: Boolean = false
- /** *
- * Creates a data source table (a table created with USING clause) in Hive's metastore.
- * Returns true when the table has been created. Otherwise, false.
- * @param tableName
- * @param userSpecifiedSchema
- * @param provider
- * @param options
- * @param isExternal
- * @return
- */
+ /**
+ * Creates a data source table (a table created with USING clause) in Hive's metastore.
+ * Returns true when the table has been created. Otherwise, false.
+ */
def createDataSourceTable(
tableName: String,
userSpecifiedSchema: Option[StructType],
@@ -122,7 +129,14 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
tbl.setProperty("spark.sql.sources.provider", provider)
if (userSpecifiedSchema.isDefined) {
- tbl.setProperty("spark.sql.sources.schema", userSpecifiedSchema.get.json)
+ val threshold = hive.conf.schemaStringLengthThreshold
+ val schemaJsonString = userSpecifiedSchema.get.json
+ // Split the JSON string.
+ val parts = schemaJsonString.grouped(threshold).toSeq
+ tbl.setProperty("spark.sql.sources.schema.numParts", parts.size.toString)
+ parts.zipWithIndex.foreach { case (part, index) =>
+ tbl.setProperty(s"spark.sql.sources.schema.part.${index}", part)
+ }
}
options.foreach { case (key, value) => tbl.setSerdeParam(key, value) }
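The write path above chunks the schema JSON into `spark.sql.sources.schema.part.N` properties so that no single metastore property exceeds the length threshold, and the read path earlier in this file stitches the parts back together in index order. A self-contained sketch of the split/rejoin scheme (plain Scala; the threshold here is illustrative, the real value comes from `schemaStringLengthThreshold`):

    object SchemaPartsSketch extends App {
      val schemaJson = """{"type":"struct","fields":[{"name":"a","type":"integer"}]}"""
      val threshold  = 16
      val parts      = schemaJson.grouped(threshold).toSeq       // write side: one property per chunk
      val props = ("numParts" -> parts.size.toString) +:
        parts.zipWithIndex.map { case (p, i) => s"part.$i" -> p }

      // read side: reassemble in index order and compare with the original
      val numParts = props.head._2.toInt
      val rejoined = (0 until numParts).map { i =>
        props.collectFirst { case (k, v) if k == s"part.$i" => v }.get
      }.mkString
      assert(rejoined == schemaJson)
      println(s"numParts = $numParts, rejoined = $rejoined")
    }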
@@ -140,15 +154,18 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}
- def hiveDefaultTableFilePath(tableName: String): String = {
- val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase())
+ def hiveDefaultTableFilePath(tableName: String): String = synchronized {
+ val currentDatabase = client.getDatabase(hive.sessionState.getCurrentDatabase)
+
hiveWarehouse.getTablePath(currentDatabase, tableName).toString
}
- def tableExists(tableIdentifier: Seq[String]): Boolean = {
+ def tableExists(tableIdentifier: Seq[String]): Boolean = synchronized {
val tableIdent = processTableIdentifier(tableIdentifier)
- val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
- hive.sessionState.getCurrentDatabase)
+ val databaseName =
+ tableIdent
+ .lift(tableIdent.size - 2)
+ .getOrElse(hive.sessionState.getCurrentDatabase)
val tblName = tableIdent.last
client.getTable(databaseName, tblName, false) != null
}
@@ -160,10 +177,21 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val databaseName = tableIdent.lift(tableIdent.size - 2).getOrElse(
hive.sessionState.getCurrentDatabase)
val tblName = tableIdent.last
- val table = client.getTable(databaseName, tblName)
+ val table = try client.getTable(databaseName, tblName) catch {
+ case te: org.apache.hadoop.hive.ql.metadata.InvalidTableException =>
+ throw new NoSuchTableException
+ }
if (table.getProperty("spark.sql.sources.provider") != null) {
- cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase)
+ val dataSourceTable =
+ cachedDataSourceTables(QualifiedTableName(databaseName, tblName).toLowerCase)
+ // Then, if an alias is specified, wrap the table with a Subquery using the alias.
+ // Otherwise, wrap the table with a Subquery using the table name.
+ val withAlias =
+ alias.map(a => Subquery(a, dataSourceTable)).getOrElse(
+ Subquery(tableIdent.last, dataSourceTable))
+
+ withAlias
} else if (table.isView) {
// if the unresolved relation is from hive view
// parse the text into logic node.
@@ -176,26 +204,53 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
Nil
}
- val relation = MetastoreRelation(
- databaseName, tblName, alias)(
- table.getTTable, partitions.map(part => part.getTPartition))(hive)
-
- if (hive.convertMetastoreParquet &&
- hive.conf.parquetUseDataSourceApi &&
- relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet")) {
- val metastoreSchema = StructType.fromAttributes(relation.output)
- val paths = if (relation.hiveQlTable.isPartitioned) {
- relation.hiveQlPartitions.map(p => p.getLocation)
- } else {
- Seq(relation.hiveQlTable.getDataLocation.toString)
- }
+ MetastoreRelation(databaseName, tblName, alias)(
+ table.getTTable, partitions.map(part => part.getTPartition))(hive)
+ }
+ }
- LogicalRelation(ParquetRelation2(
- paths, Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive))
- } else {
- relation
+ private def convertToParquetRelation(metastoreRelation: MetastoreRelation): LogicalRelation = {
+ val metastoreSchema = StructType.fromAttributes(metastoreRelation.output)
+
+ // NOTE: Instead of passing Metastore schema directly to `ParquetRelation2`, we have to
+ // serialize the Metastore schema to JSON and pass it as a data source option because of the
+ // evil case insensitivity issue, which is reconciled within `ParquetRelation2`.
+ if (metastoreRelation.hiveQlTable.isPartitioned) {
+ val partitionSchema = StructType.fromAttributes(metastoreRelation.partitionKeys)
+ val partitionColumnDataTypes = partitionSchema.map(_.dataType)
+ val partitions = metastoreRelation.hiveQlPartitions.map { p =>
+ val location = p.getLocation
+ val values = Row.fromSeq(p.getValues.zip(partitionColumnDataTypes).map {
+ case (rawValue, dataType) => Cast(Literal(rawValue), dataType).eval(null)
+ })
+ ParquetPartition(values, location)
}
+ val partitionSpec = PartitionSpec(partitionSchema, partitions)
+ val paths = partitions.map(_.path)
+ LogicalRelation(
+ ParquetRelation2(
+ paths,
+ Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json),
+ None,
+ Some(partitionSpec))(hive))
+ } else {
+ val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString)
+ LogicalRelation(
+ ParquetRelation2(
+ paths,
+ Map(ParquetRelation2.METASTORE_SCHEMA -> metastoreSchema.json))(hive))
+ }
+ }
+
+ override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = synchronized {
+ val dbName = if (!caseSensitive) {
+ if (databaseName.isDefined) Some(databaseName.get.toLowerCase) else None
+ } else {
+ databaseName
}
+ val db = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
+
+ client.getAllTables(db).map(tableName => (tableName, false))
}
/**
@@ -225,7 +280,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
val hiveSchema: JList[FieldSchema] = if (schema == null || schema.isEmpty) {
crtTbl.getCols
} else {
- schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), ""))
+ schema.map(attr => new FieldSchema(attr.name, toMetastoreType(attr.dataType), null))
}
tbl.setFields(hiveSchema)
@@ -256,9 +311,9 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
logInfo(s"Default to LazySimpleSerDe for table $dbName.$tblName")
tbl.setSerializationLib(classOf[LazySimpleSerDe].getName())
- import org.apache.hadoop.mapred.TextInputFormat
import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat
import org.apache.hadoop.io.Text
+ import org.apache.hadoop.mapred.TextInputFormat
tbl.setInputFormatClass(classOf[TextInputFormat])
tbl.setOutputFormatClass(classOf[HiveIgnoreKeyTextOutputFormat[Text, Text]])
@@ -299,6 +354,7 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
if (crtTbl != null && crtTbl.getLineDelim() != null) {
tbl.setSerdeParam(serdeConstants.LINE_DELIM, crtTbl.getLineDelim())
}
+ HiveShim.setTblNullFormat(crtTbl, tbl)
if (crtTbl != null && crtTbl.getSerdeProps() != null) {
val iter = crtTbl.getSerdeProps().entrySet().iterator()
@@ -380,13 +436,86 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
}
}
+ /**
+ * When scanning or writing to non-partitioned Metastore Parquet tables, convert them to Parquet
+ * data source relations for better performance.
+ *
+ * This rule can be considered as [[HiveStrategies.ParquetConversion]] done right.
+ */
+ object ParquetConversions extends Rule[LogicalPlan] {
+ override def apply(plan: LogicalPlan): LogicalPlan = {
+ // Collects all `MetastoreRelation`s which should be replaced
+ val toBeReplaced = plan.collect {
+ // Write path
+ case InsertIntoTable(relation: MetastoreRelation, _, _, _)
+ // Inserting into a partitioned table is not supported by the Parquet data source (yet).
+ if !relation.hiveQlTable.isPartitioned &&
+ hive.convertMetastoreParquet &&
+ hive.conf.parquetUseDataSourceApi &&
+ relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
+ val parquetRelation = convertToParquetRelation(relation)
+ val attributedRewrites = relation.output.zip(parquetRelation.output)
+ (relation, parquetRelation, attributedRewrites)
+
+ // Write path
+ case InsertIntoHiveTable(relation: MetastoreRelation, _, _, _)
+ // Inserting into a partitioned table is not supported by the Parquet data source (yet).
+ if !relation.hiveQlTable.isPartitioned &&
+ hive.convertMetastoreParquet &&
+ hive.conf.parquetUseDataSourceApi &&
+ relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
+ val parquetRelation = convertToParquetRelation(relation)
+ val attributedRewrites = relation.output.zip(parquetRelation.output)
+ (relation, parquetRelation, attributedRewrites)
+
+ // Read path
+ case p @ PhysicalOperation(_, _, relation: MetastoreRelation)
+ if hive.convertMetastoreParquet &&
+ hive.conf.parquetUseDataSourceApi &&
+ relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") =>
+ val parquetRelation = convertToParquetRelation(relation)
+ val attributedRewrites = relation.output.zip(parquetRelation.output)
+ (relation, parquetRelation, attributedRewrites)
+ }
+
+ val relationMap = toBeReplaced.map(r => (r._1, r._2)).toMap
+ val attributedRewrites = AttributeMap(toBeReplaced.map(_._3).fold(Nil)(_ ++: _))
+
+ // Replaces all `MetastoreRelation`s with corresponding `ParquetRelation2`s, and fixes
+ // attribute IDs referenced in other nodes.
+ plan.transformUp {
+ case r: MetastoreRelation if relationMap.contains(r) => {
+ val parquetRelation = relationMap(r)
+ val withAlias =
+ r.alias.map(a => Subquery(a, parquetRelation)).getOrElse(
+ Subquery(r.tableName, parquetRelation))
+
+ withAlias
+ }
+ case InsertIntoTable(r: MetastoreRelation, partition, child, overwrite)
+ if relationMap.contains(r) => {
+ val parquetRelation = relationMap(r)
+ InsertIntoTable(parquetRelation, partition, child, overwrite)
+ }
+ case InsertIntoHiveTable(r: MetastoreRelation, partition, child, overwrite)
+ if relationMap.contains(r) => {
+ val parquetRelation = relationMap(r)
+ InsertIntoTable(parquetRelation, partition, child, overwrite)
+ }
+ case other => other.transformExpressions {
+ case a: Attribute if a.resolved => attributedRewrites.getOrElse(a, a)
+ }
+ }
+ }
+ }
+
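ParquetConversions first collects every MetastoreRelation to replace, together with an attribute-rewrite map, and then walks the plan bottom-up substituting the converted relations and patching attribute references elsewhere. A toy model of that collect-then-rewrite pattern on a miniature plan tree (plain Scala; the types and integer "attribute ids" are purely illustrative):

    object RewriteSketch extends App {
      sealed trait Plan
      case class Table(name: String, outputId: Int) extends Plan
      case class Project(refId: Int, child: Plan) extends Plan

      // Replace marked tables and fix up references using the id-rewrite map.
      def rewrite(plan: Plan, replaced: Map[String, Table], idRewrites: Map[Int, Int]): Plan =
        plan match {
          case t: Table if replaced.contains(t.name) => replaced(t.name)
          case Project(ref, child) =>
            Project(idRewrites.getOrElse(ref, ref), rewrite(child, replaced, idRewrites))
          case other => other
        }

      val original  = Project(refId = 1, child = Table("src", outputId = 1))
      val converted = Table("src_parquet", outputId = 2)
      // Prints Project(2,Table(src_parquet,2)): the reference now points at the new relation.
      println(rewrite(original, Map("src" -> converted), Map(1 -> 2)))
    }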
/**
* Creates any tables required for query execution.
* For example, because of a CREATE TABLE X AS statement.
*/
object CreateTables extends Rule[LogicalPlan] {
import org.apache.hadoop.hive.ql.Context
- import org.apache.hadoop.hive.ql.parse.{QB, ASTNode, SemanticAnalyzer}
+ import org.apache.hadoop.hive.ql.parse.{ASTNode, QB, SemanticAnalyzer}
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
// Wait until children are resolved.
@@ -417,24 +546,69 @@ private[hive] class HiveMetastoreCatalog(hive: HiveContext) extends Catalog with
Some(sa.getQB().getTableDesc)
}
- execution.CreateTableAsSelect(
- databaseName,
- tableName,
- child,
- allowExisting,
- desc)
+ // Check if the query specifies file format or storage handler.
+ val hasStorageSpec = desc match {
+ case Some(crtTbl) =>
+ crtTbl != null && (crtTbl.getSerName != null || crtTbl.getStorageHandler != null)
+ case None => false
+ }
+
+ if (hive.convertCTAS && !hasStorageSpec) {
+ // Do the conversion when spark.sql.hive.convertCTAS is true and the query
+ // does not specify any storage format (file format and storage handler).
+ if (dbName.isDefined) {
+ throw new AnalysisException(
+ "Cannot specify database name in a CTAS statement " +
+ "when spark.sql.hive.convertCTAS is set to true.")
+ }
+
+ val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
+ CreateTableUsingAsSelect(
+ tblName,
+ hive.conf.defaultDataSourceName,
+ temporary = false,
+ mode,
+ options = Map.empty[String, String],
+ child
+ )
+ } else {
+ execution.CreateTableAsSelect(
+ databaseName,
+ tableName,
+ child,
+ allowExisting,
+ desc)
+ }
case p: LogicalPlan if p.resolved => p
case p @ CreateTableAsSelect(db, tableName, child, allowExisting, None) =>
val (dbName, tblName) = processDatabaseAndTableName(db, tableName)
- val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
- execution.CreateTableAsSelect(
- databaseName,
- tableName,
- child,
- allowExisting,
- None)
+ if (hive.convertCTAS) {
+ if (dbName.isDefined) {
+ throw new AnalysisException(
+ "Cannot specify database name in a CTAS statement " +
+ "when spark.sql.hive.convertCTAS is set to true.")
+ }
+
+ val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists
+ CreateTableUsingAsSelect(
+ tblName,
+ hive.conf.defaultDataSourceName,
+ temporary = false,
+ mode,
+ options = Map.empty[String, String],
+ child
+ )
+ } else {
+ val databaseName = dbName.getOrElse(hive.sessionState.getCurrentDatabase)
+ execution.CreateTableAsSelect(
+ databaseName,
+ tableName,
+ child,
+ allowExisting,
+ None)
+ }
}
}
@@ -601,7 +775,7 @@ private[hive] case class MetastoreRelation
}
object HiveMetastoreTypes {
- protected val ddlParser = new DDLParser
+ protected val ddlParser = new DDLParser(HiveQl.parseSql(_))
def toDataType(metastoreType: String): DataType = synchronized {
ddlParser.parseType(metastoreType)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index f51af62d3340b..98263f602e9ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -18,6 +18,8 @@
package org.apache.spark.sql.hive
import java.sql.Date
+
+
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.hive.conf.HiveConf
@@ -26,13 +28,14 @@ import org.apache.hadoop.hive.ql.lib.Node
import org.apache.hadoop.hive.ql.metadata.Table
import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan.PlanUtils
-import org.apache.spark.sql.SparkSQLParser
+import org.apache.spark.sql.{AnalysisException, SparkSQLParser}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.execution.ExplainCommand
import org.apache.spark.sql.sources.DescribeCommand
import org.apache.spark.sql.hive.execution.{HiveNativeCommand, DropTable, AnalyzeTable, HiveScriptIOSchema}
@@ -61,7 +64,6 @@ private[hive] object HiveQl {
"TOK_SHOWINDEXES",
"TOK_SHOWINDEXES",
"TOK_SHOWPARTITIONS",
- "TOK_SHOWTABLES",
"TOK_SHOW_TBLPROPERTIES",
"TOK_LOCKTABLE",
@@ -76,6 +78,7 @@ private[hive] object HiveQl {
"TOK_REVOKE",
"TOK_SHOW_GRANT",
"TOK_SHOW_ROLE_GRANT",
+ "TOK_SHOW_SET_ROLE",
"TOK_CREATEFUNCTION",
"TOK_DROPFUNCTION",
@@ -125,6 +128,7 @@ private[hive] object HiveQl {
// Commands that we do not need to explain.
protected val noExplainCommands = Seq(
"TOK_DESCTABLE",
+ "TOK_SHOWTABLES",
"TOK_TRUNCATETABLE" // truncate table" is a NativeCommand, does not need to explain.
) ++ nativeCommands
@@ -209,12 +213,6 @@ private[hive] object HiveQl {
}
}
- class ParseException(sql: String, cause: Throwable)
- extends Exception(s"Failed to parse: $sql", cause)
-
- class SemanticException(msg: String)
- extends Exception(s"Error in semantic analysis: $msg")
-
/**
* Returns the AST for the given SQL string.
*/
@@ -234,8 +232,10 @@ private[hive] object HiveQl {
/** Returns a LogicalPlan for a given HiveQL string. */
def parseSql(sql: String): LogicalPlan = hqlParser(sql)
+ val errorRegEx = "line (\\d+):(\\d+) (.*)".r
+
/** Creates LogicalPlan for a given HiveQL string. */
- def createPlan(sql: String) = {
+ def createPlan(sql: String): LogicalPlan = {
try {
val tree = getAst(sql)
if (nativeCommands contains tree.getText) {
@@ -247,14 +247,23 @@ private[hive] object HiveQl {
}
}
} catch {
- case e: Exception => throw new ParseException(sql, e)
- case e: NotImplementedError => sys.error(
- s"""
- |Unsupported language features in query: $sql
- |${dumpTree(getAst(sql))}
- |$e
- |${e.getStackTrace.head}
- """.stripMargin)
+ case pe: org.apache.hadoop.hive.ql.parse.ParseException =>
+ pe.getMessage match {
+ case errorRegEx(line, start, message) =>
+ throw new AnalysisException(message, Some(line.toInt), Some(start.toInt))
+ case otherMessage =>
+ throw new AnalysisException(otherMessage)
+ }
+ case e: Exception =>
+ throw new AnalysisException(e.getMessage)
+ case e: NotImplementedError =>
+ throw new AnalysisException(
+ s"""
+ |Unsupported language features in query: $sql
+ |${dumpTree(getAst(sql))}
+ |$e
+ |${e.getStackTrace.head}
+ """.stripMargin)
}
}
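createPlan now turns Hive parser failures of the form `line L:C message` into exceptions that carry the position. A self-contained sketch of the message-to-position extraction (plain Scala; the error class is a stand-in for AnalysisException):

    object ParseErrorSketch extends App {
      case class PositionedError(message: String, line: Option[Int], startPosition: Option[Int])

      val errorRegEx = "line (\\d+):(\\d+) (.*)".r

      def toError(hiveMessage: String): PositionedError = hiveMessage match {
        case errorRegEx(line, start, message) =>
          PositionedError(message, Some(line.toInt), Some(start.toInt))
        case other =>
          PositionedError(other, None, None)                 // no position in the message
      }

      println(toError("line 3:12 cannot recognize input near 'SELEC'"))
      println(toError("generic failure with no position"))
    }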
@@ -290,6 +299,7 @@ private[hive] object HiveQl {
/** @return matches of the form (tokenName, children). */
def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match {
case t: ASTNode =>
+ CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine)
Some((t.getText,
Option(t.getChildren).map(_.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]]))
case _ => None
@@ -464,23 +474,21 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
// Just fake explain for any of the native commands.
case Token("TOK_EXPLAIN", explainArgs)
if noExplainCommands.contains(explainArgs.head.getText) =>
- ExplainCommand(NoRelation, Seq(AttributeReference("plan", StringType, nullable = false)()))
+ ExplainCommand(NoRelation)
case Token("TOK_EXPLAIN", explainArgs)
if "TOK_CREATETABLE" == explainArgs.head.getText =>
val Some(crtTbl) :: _ :: extended :: Nil =
getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs)
ExplainCommand(
nodeToPlan(crtTbl),
- Seq(AttributeReference("plan", StringType,nullable = false)()),
- extended != None)
+ extended = extended.isDefined)
case Token("TOK_EXPLAIN", explainArgs) =>
// Ignore FORMATTED if present.
val Some(query) :: _ :: extended :: Nil =
getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs)
ExplainCommand(
nodeToPlan(query),
- Seq(AttributeReference("plan", StringType, nullable = false)()),
- extended != None)
+ extended = extended.isDefined)
case Token("TOK_DESCTABLE", describeArgs) =>
// Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
@@ -1099,7 +1107,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
Cast(nodeToExpr(arg), DateType)
/* Arithmetic */
- case Token("+", child :: Nil) => Add(Literal(0), nodeToExpr(child))
+ case Token("+", child :: Nil) => nodeToExpr(child)
case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
@@ -1237,6 +1245,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL =>
Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1)))
+ case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL =>
+ Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText))
+
case a: ASTNode =>
throw new NotImplementedError(
s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} :
@@ -1275,7 +1286,12 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0)
: StringBuilder = {
node match {
- case a: ASTNode => builder.append((" " * indent) + a.getText + "\n")
+ case a: ASTNode => builder.append(
+ (" " * indent) + a.getText + " " +
+ a.getLine + ", " +
+ a.getTokenStartIndex + "," +
+ a.getTokenStopIndex + ", " +
+ a.getCharPositionInLine + "\n")
case other => sys.error(s"Non ASTNode encountered: $other")
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
index 95abc363ae767..e63cea60457d9 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeComman
import org.apache.spark.sql.execution._
import org.apache.spark.sql.hive.execution._
import org.apache.spark.sql.parquet.ParquetRelation
-import org.apache.spark.sql.sources.{CreateTableUsingAsLogicalPlan, CreateTableUsingAsSelect, CreateTableUsing}
+import org.apache.spark.sql.sources.{CreateTableUsingAsSelect, CreateTableUsing}
import org.apache.spark.sql.types.StringType
@@ -139,15 +139,20 @@ private[hive] trait HiveStrategies {
val partitionLocations = partitions.map(_.getLocation)
- hiveContext
- .parquetFile(partitionLocations.head, partitionLocations.tail: _*)
- .addPartitioningAttributes(relation.partitionKeys)
- .lowerCase
- .where(unresolvedOtherPredicates)
- .select(unresolvedProjection: _*)
- .queryExecution
- .executedPlan
- .fakeOutput(projectList.map(_.toAttribute)) :: Nil
+ if (partitionLocations.isEmpty) {
+ PhysicalRDD(plan.output, sparkContext.emptyRDD[Row]) :: Nil
+ } else {
+ hiveContext
+ .parquetFile(partitionLocations: _*)
+ .addPartitioningAttributes(relation.partitionKeys)
+ .lowerCase
+ .where(unresolvedOtherPredicates)
+ .select(unresolvedProjection: _*)
+ .queryExecution
+ .executedPlan
+ .fakeOutput(projectList.map(_.toAttribute)) :: Nil
+ }
+
} else {
hiveContext
.parquetFile(relation.hiveQlTable.getDataLocation.toString)
@@ -216,20 +221,15 @@ private[hive] trait HiveStrategies {
object HiveDDLStrategy extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case CreateTableUsing(tableName, userSpecifiedSchema, provider, false, opts, allowExisting) =>
+ case CreateTableUsing(
+ tableName, userSpecifiedSchema, provider, false, opts, allowExisting, managedIfNoPath) =>
ExecutedCommand(
CreateMetastoreDataSource(
- tableName, userSpecifiedSchema, provider, opts, allowExisting)) :: Nil
-
- case CreateTableUsingAsSelect(tableName, provider, false, opts, allowExisting, query) =>
- val logicalPlan = hiveContext.parseSql(query)
- val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, logicalPlan)
- ExecutedCommand(cmd) :: Nil
+ tableName, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath)) :: Nil
- case CreateTableUsingAsLogicalPlan(tableName, provider, false, opts, allowExisting, query) =>
+ case CreateTableUsingAsSelect(tableName, provider, false, mode, opts, query) =>
val cmd =
- CreateMetastoreDataSourceAsSelect(tableName, provider, opts, allowExisting, query)
+ CreateMetastoreDataSourceAsSelect(tableName, provider, mode, opts, query)
ExecutedCommand(cmd) :: Nil
case _ => Nil
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
index bfacc51ef57ab..07b5a84fb6602 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala
@@ -29,9 +29,9 @@ import org.apache.spark.sql.hive.HiveShim
import org.apache.spark.sql.SQLContext
/**
- * Implementation for "describe [extended] table".
- *
* :: DeveloperApi ::
+ *
+ * Implementation for "describe [extended] table".
*/
@DeveloperApi
case class DescribeHiveTableCommand(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
index 95dcaccefdc54..9934a5d3c30a2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala
@@ -18,8 +18,11 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.sources.ResolvedDataSource
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.sources._
+import org.apache.spark.sql.{SaveMode, DataFrame, SQLContext}
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.RunnableCommand
@@ -60,10 +63,10 @@ case class DropTable(
} catch {
// This table's metadata is not in
case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException =>
- // Other exceptions can be caused by users providing wrong parameters in OPTIONS
+ // Other Throwables can be caused by users providing wrong parameters in OPTIONS
// (e.g. invalid paths). We catch it and log a warning message.
- // Users should be able to drop such kinds of tables regardless if there is an exception.
- case e: Exception => log.warn(s"${e.getMessage}")
+ // Users should be able to drop such kinds of tables regardless of whether there is an error.
+ case e: Throwable => log.warn(s"${e.getMessage}")
}
hiveContext.invalidateTable(tableName)
hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName")
@@ -100,12 +103,17 @@ case class AddFile(path: String) extends RunnableCommand {
}
}
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
case class CreateMetastoreDataSource(
tableName: String,
userSpecifiedSchema: Option[StructType],
provider: String,
options: Map[String, String],
- allowExisting: Boolean) extends RunnableCommand {
+ allowExisting: Boolean,
+ managedIfNoPath: Boolean) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
@@ -114,13 +122,13 @@ case class CreateMetastoreDataSource(
if (allowExisting) {
return Seq.empty[Row]
} else {
- sys.error(s"Table $tableName already exists.")
+ throw new AnalysisException(s"Table $tableName already exists.")
}
}
var isExternal = true
val optionsWithPath =
- if (!options.contains("path")) {
+ if (!options.contains("path") && managedIfNoPath) {
isExternal = false
options + ("path" -> hiveContext.catalog.hiveDefaultTableFilePath(tableName))
} else {
@@ -138,25 +146,20 @@ case class CreateMetastoreDataSource(
}
}
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
case class CreateMetastoreDataSourceAsSelect(
tableName: String,
provider: String,
+ mode: SaveMode,
options: Map[String, String],
- allowExisting: Boolean,
query: LogicalPlan) extends RunnableCommand {
override def run(sqlContext: SQLContext): Seq[Row] = {
val hiveContext = sqlContext.asInstanceOf[HiveContext]
-
- if (hiveContext.catalog.tableExists(tableName :: Nil)) {
- if (allowExisting) {
- return Seq.empty[Row]
- } else {
- sys.error(s"Table $tableName already exists.")
- }
- }
-
- val df = DataFrame(hiveContext, query)
+ var createMetastoreTable = false
var isExternal = true
val optionsWithPath =
if (!options.contains("path")) {
@@ -166,15 +169,84 @@ case class CreateMetastoreDataSourceAsSelect(
options
}
- // Create the relation based on the data of df.
- ResolvedDataSource(sqlContext, provider, optionsWithPath, df)
+ if (sqlContext.catalog.tableExists(Seq(tableName))) {
+ // Check if we need to throw an exception or just return.
+ mode match {
+ case SaveMode.ErrorIfExists =>
+ throw new AnalysisException(s"Table $tableName already exists. " +
+ s"If you are using saveAsTable, you can set SaveMode to SaveMode.Append to " +
+ s"insert data into the table or set SaveMode to SaveMode.Overwrite to overwrite" +
+ s"the existing data. " +
+ s"Or, if you are using SQL CREATE TABLE, you need to drop $tableName first.")
+ case SaveMode.Ignore =>
+ // Since the table already exists and the save mode is Ignore, we will just return.
+ return Seq.empty[Row]
+ case SaveMode.Append =>
+ // Check if the specified data source match the data source of the existing table.
+ val resolved =
+ ResolvedDataSource(sqlContext, Some(query.schema), provider, optionsWithPath)
+ val createdRelation = LogicalRelation(resolved.relation)
+ EliminateSubQueries(sqlContext.table(tableName).logicalPlan) match {
+ case l @ LogicalRelation(i: InsertableRelation) =>
+ if (l.schema != createdRelation.schema) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the schema of this " +
+ s"DataFrame does not match the schema of table $tableName."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Schemas ==
+ |${sideBySide(
+ s"== Expected Schema ==" +:
+ l.schema.treeString.split("\\\n"),
+ s"== Actual Schema ==" +:
+ createdRelation.schema.treeString.split("\\\n")).mkString("\n")}
+ """.stripMargin
+ throw new AnalysisException(errorMessage)
+ } else if (i != createdRelation.relation) {
+ val errorDescription =
+ s"Cannot append to table $tableName because the resolved relation does not " +
+ s"match the existing relation of $tableName. " +
+ s"You can use insertInto($tableName, false) to append this DataFrame to the " +
+ s"table $tableName and using its data source and options."
+ val errorMessage =
+ s"""
+ |$errorDescription
+ |== Relations ==
+ |${sideBySide(
+ s"== Expected Relation ==" ::
+ l.toString :: Nil,
+ s"== Actual Relation ==" ::
+ createdRelation.toString :: Nil).mkString("\n")}
+ """.stripMargin
+ throw new AnalysisException(errorMessage)
+ }
+ case o =>
+ throw new AnalysisException(s"Saving data in ${o.toString} is not supported.")
+ }
+ case SaveMode.Overwrite =>
+ hiveContext.sql(s"DROP TABLE IF EXISTS $tableName")
+ // Need to create the table again.
+ createMetastoreTable = true
+ }
+ } else {
+ // The table does not exist. We need to create it in metastore.
+ createMetastoreTable = true
+ }
- hiveContext.catalog.createDataSourceTable(
- tableName,
- None,
- provider,
- optionsWithPath,
- isExternal)
+ val df = DataFrame(hiveContext, query)
+
+ // Create the relation based on the data of df.
+ ResolvedDataSource(sqlContext, provider, mode, optionsWithPath, df)
+
+ if (createMetastoreTable) {
+ hiveContext.catalog.createDataSourceTable(
+ tableName,
+ Some(df.schema),
+ provider,
+ optionsWithPath,
+ isExternal)
+ }
Seq.empty[Row]
}
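CreateMetastoreDataSourceAsSelect now keys its behaviour off SaveMode: ErrorIfExists fails, Ignore returns, Append verifies the resolved relation and schema, and Overwrite drops and recreates the table. A hedged usage sketch driving those branches through saveAsTable (Spark 1.3-era API, mirroring the Java suite further below; the Scala Map overload and the /tmp path are assumptions):

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SaveMode
    import org.apache.spark.sql.hive.HiveContext

    object SaveModeSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("savemode-sketch").setMaster("local"))
        val hiveContext = new HiveContext(sc)
        val df = hiveContext.jsonRDD(sc.parallelize((1 to 10).map(i => s"""{"a":$i,"b":"str$i"}""")))

        val opts = Map("path" -> "/tmp/savemode-sketch")                                // assumed scratch path
        df.saveAsTable("t", "org.apache.spark.sql.json", SaveMode.ErrorIfExists, opts)  // creates the table
        df.saveAsTable("t", "org.apache.spark.sql.json", SaveMode.Ignore, opts)         // table exists: no-op
        df.saveAsTable("t", "org.apache.spark.sql.json", SaveMode.Append, opts)         // schema/relation checked
        df.saveAsTable("t", "org.apache.spark.sql.json", SaveMode.Overwrite, opts)      // drop and recreate
        sc.stop()
      }
    }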
diff --git a/core/src/test/scala/org/apache/spark/util/FakeClock.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/package.scala
similarity index 73%
rename from core/src/test/scala/org/apache/spark/util/FakeClock.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/package.scala
index 0a45917b08dd2..4989c42e964ec 100644
--- a/core/src/test/scala/org/apache/spark/util/FakeClock.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/package.scala
@@ -15,12 +15,11 @@
* limitations under the License.
*/
-package org.apache.spark.util
+package org.apache.spark.sql.hive
-class FakeClock extends Clock {
- private var time = 0L
-
- def advance(millis: Long): Unit = time += millis
-
- def getTime(): Long = time
-}
+/**
+ * Physical execution operators used for running queries against data stored in Hive. These
+ * are not intended for use by users, but are documented so that it is easier to understand
+ * the output of EXPLAIN queries.
+ */
+package object execution
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index aae175e426ade..f136e43acc8f2 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -30,6 +30,7 @@ import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat}
import org.apache.hadoop.hive.ql.plan.{PlanUtils, TableDesc}
import org.apache.hadoop.io.Writable
import org.apache.hadoop.mapred._
+import org.apache.hadoop.hive.common.FileUtils
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql.Row
@@ -212,9 +213,14 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
.zip(row.toSeq.takeRight(dynamicPartColNames.length))
.map { case (col, rawVal) =>
val string = if (rawVal == null) null else String.valueOf(rawVal)
- s"/$col=${if (string == null || string.isEmpty) defaultPartName else string}"
- }
- .mkString
+ val colString =
+ if (string == null || string.isEmpty) {
+ defaultPartName
+ } else {
+ FileUtils.escapePathName(string)
+ }
+ s"/$col=$colString"
+ }.mkString
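Dynamic partition path segments are now built from escaped values, with null or empty values falling back to the default partition name. A small sketch of that path construction (assumes hive-common, which provides org.apache.hadoop.hive.common.FileUtils, is on the classpath; column names and values are made up):

    import org.apache.hadoop.hive.common.FileUtils

    object PartitionPathSketch extends App {
      val defaultPartName = "__HIVE_DEFAULT_PARTITION__"
      val cols   = Seq("ds", "country", "note")
      val values = Seq("2015-02-17", null, "a/b")          // a null value and a '/' both need handling

      val path = cols.zip(values).map { case (col, raw) =>
        val colString =
          if (raw == null || raw.isEmpty) defaultPartName
          else FileUtils.escapePathName(raw)               // escapes path-hostile characters
        s"/$col=$colString"
      }.mkString

      println(path)   // e.g. /ds=2015-02-17/country=__HIVE_DEFAULT_PARTITION__/note=a%2Fb
    }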
def newWriter = {
val newFileSinkDesc = new FileSinkDesc(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala
index a6c8ed4f7e866..db074361ef03c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package.scala
@@ -17,4 +17,14 @@
package org.apache.spark.sql
+/**
+ * Support for running Spark SQL queries using functionality from Apache Hive (does not require an
+ * existing Hive installation). Supported Hive features include:
+ * - Using HiveQL to express queries.
+ * - Reading metadata from the Hive Metastore using HiveSerDes.
+ * - Hive UDFs, UDAs, UDTs
+ *
+ * Users that would like access to this functionality should create a
+ * [[hive.HiveContext HiveContext]] instead of a [[SQLContext]].
+ */
package object hive
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala
deleted file mode 100644
index 2a16c9d1a27c9..0000000000000
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/parquet/FakeParquetSerDe.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.hive.parquet
-
-import java.util.Properties
-
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector.Category
-import org.apache.hadoop.hive.serde2.{SerDeStats, SerDe}
-import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector
-import org.apache.hadoop.io.Writable
-
-/**
- * A placeholder that allows Spark SQL users to create metastore tables that are stored as
- * parquet files. It is only intended to pass the checks that the serde is valid and exists
- * when a CREATE TABLE is run. The actual work of decoding will be done by ParquetTableScan
- * when "spark.sql.hive.convertMetastoreParquet" is set to true.
- */
-@deprecated("No code should depend on FakeParquetHiveSerDe as it is only intended as a " +
- "placeholder in the Hive MetaStore", "1.2.0")
-class FakeParquetSerDe extends SerDe {
- override def getObjectInspector: ObjectInspector = new ObjectInspector {
- override def getCategory: Category = Category.PRIMITIVE
-
- override def getTypeName: String = "string"
- }
-
- override def deserialize(p1: Writable): AnyRef = throwError
-
- override def initialize(p1: Configuration, p2: Properties): Unit = {}
-
- override def getSerializedClass: Class[_ <: Writable] = throwError
-
- override def getSerDeStats: SerDeStats = throwError
-
- override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = throwError
-
- private def throwError =
- sys.error(
- "spark.sql.hive.convertMetastoreParquet must be set to true to use FakeParquetSerDe")
-}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
similarity index 99%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 7c1d1133c3425..a2d99f1f4b28d 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -20,9 +20,6 @@ package org.apache.spark.sql.hive.test
import java.io.File
import java.util.{Set => JavaSet}
-import scala.collection.mutable
-import scala.language.implicitConversions
-
import org.apache.hadoop.hive.ql.exec.FunctionRegistry
import org.apache.hadoop.hive.ql.io.avro.{AvroContainerInputFormat, AvroContainerOutputFormat}
import org.apache.hadoop.hive.ql.metadata.Table
@@ -30,16 +27,18 @@ import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.serde2.RegexSerDe
import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.hadoop.hive.serde2.avro.AvroSerDe
-
-import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
-import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.hive.execution.HiveNativeCommand
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkContext}
+
+import scala.collection.mutable
+import scala.language.implicitConversions
/* Implicit conversions */
import scala.collection.JavaConversions._
@@ -197,6 +196,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
// The test tables that are defined in the Hive QTestUtil.
// /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java
+ // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql
val hiveQTestUtilTables = Seq(
TestTable("src",
"CREATE TABLE src (key INT, value STRING)".cmd,
@@ -224,11 +224,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
}
}),
TestTable("src_thrift", () => {
- import org.apache.thrift.protocol.TBinaryProtocol
- import org.apache.hadoop.hive.serde2.thrift.test.Complex
import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
- import org.apache.hadoop.mapred.SequenceFileInputFormat
- import org.apache.hadoop.mapred.SequenceFileOutputFormat
+ import org.apache.hadoop.hive.serde2.thrift.test.Complex
+ import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
+ import org.apache.thrift.protocol.TBinaryProtocol
val srcThrift = new Table("default", "src_thrift")
srcThrift.setFields(Nil)
diff --git a/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
new file mode 100644
index 0000000000000..53ddecf57958b
--- /dev/null
+++ b/sql/hive/src/test/java/org/apache/spark/sql/hive/JavaMetastoreDataSourcesSuite.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.spark.sql.SaveMode;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.QueryTest$;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.hive.test.TestHive$;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.StructField;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.util.Utils;
+
+public class JavaMetastoreDataSourcesSuite {
+ private transient JavaSparkContext sc;
+ private transient HiveContext sqlContext;
+
+ String originalDefaultSource;
+ File path;
+ Path hiveManagedPath;
+ FileSystem fs;
+ DataFrame df;
+
+ private void checkAnswer(DataFrame actual, List expected) {
+ String errorMessage = QueryTest$.MODULE$.checkAnswer(actual, expected);
+ if (errorMessage != null) {
+ Assert.fail(errorMessage);
+ }
+ }
+
+ @Before
+ public void setUp() throws IOException {
+ sqlContext = TestHive$.MODULE$;
+ sc = new JavaSparkContext(sqlContext.sparkContext());
+
+ originalDefaultSource = sqlContext.conf().defaultDataSourceName();
+ path =
+ Utils.createTempDir(System.getProperty("java.io.tmpdir"), "datasource").getCanonicalFile();
+ if (path.exists()) {
+ path.delete();
+ }
+ hiveManagedPath = new Path(sqlContext.catalog().hiveDefaultTableFilePath("javaSavedTable"));
+ fs = hiveManagedPath.getFileSystem(sc.hadoopConfiguration());
+ if (fs.exists(hiveManagedPath)){
+ fs.delete(hiveManagedPath, true);
+ }
+
+ List jsonObjects = new ArrayList(10);
+ for (int i = 0; i < 10; i++) {
+ jsonObjects.add("{\"a\":" + i + ", \"b\":\"str" + i + "\"}");
+ }
+ JavaRDD rdd = sc.parallelize(jsonObjects);
+ df = sqlContext.jsonRDD(rdd);
+ df.registerTempTable("jsonTable");
+ }
+
+ @After
+ public void tearDown() throws IOException {
+ // Clean up tables.
+ sqlContext.sql("DROP TABLE IF EXISTS javaSavedTable");
+ sqlContext.sql("DROP TABLE IF EXISTS externalTable");
+ }
+
+ @Test
+ public void saveExternalTableAndQueryIt() {
+ Map options = new HashMap();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", options);
+
+ checkAnswer(loadedDF, df.collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ df.collectAsList());
+ }
+
+ @Test
+ public void saveExternalTableWithSchemaAndQueryIt() {
+ Map options = new HashMap();
+ options.put("path", path.toString());
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+
+ List fields = new ArrayList();
+ fields.add(DataTypes.createStructField("b", DataTypes.StringType, true));
+ StructType schema = DataTypes.createStructType(fields);
+ DataFrame loadedDF =
+ sqlContext.createExternalTable("externalTable", "org.apache.spark.sql.json", schema, options);
+
+ checkAnswer(
+ loadedDF,
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM externalTable"),
+ sqlContext.sql("SELECT b FROM javaSavedTable").collectAsList());
+ }
+
+ @Test
+ public void saveTableAndQueryIt() {
+ Map options = new HashMap();
+ df.saveAsTable("javaSavedTable", "org.apache.spark.sql.json", SaveMode.Append, options);
+
+ checkAnswer(
+ sqlContext.sql("SELECT * FROM javaSavedTable"),
+ df.collectAsList());
+ }
+}
diff --git a/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7 b/sql/hive/src/test/resources/golden/inputddl5-0-ebbf2aec5f76af7225c2efaf870b8ba7
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75 b/sql/hive/src/test/resources/golden/inputddl5-1-2691407ccdc5c848a4ba2aecb6dbad75
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b
new file mode 100644
index 0000000000000..518a70918b2c7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-2-ca2faacf63dc4785f8bfd2ecc397e69b
@@ -0,0 +1 @@
+name string
diff --git a/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce
new file mode 100644
index 0000000000000..33398360345d7
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-3-4f28c7412a05cff89c0bd86b65aa7ce
@@ -0,0 +1 @@
+邵铮
diff --git a/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783 b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/inputddl5-4-bd7e25cff73f470d2e2336876342b783
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e b/sql/hive/src/test/resources/golden/nullformatCTAS-0-36f9196395758cebfed837a1c391a1e
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b b/sql/hive/src/test/resources/golden/nullformatCTAS-1-b5a31d4cb34218b8de1ac3fed59fa75b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b
new file mode 100644
index 0000000000000..e74deff51c9ba
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-10-7f4f04b87c7ef9653b4646949b24cf0b
@@ -0,0 +1,10 @@
+1.0 1
+1.0 1
+1.0 1
+1.0 1
+1.0 1
+NULL 1
+NULL NULL
+1.0 NULL
+1.0 1
+1.0 1
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281 b/sql/hive/src/test/resources/golden/nullformatCTAS-11-4a4c16b53c612d00012d338c97bf5281
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b
new file mode 100644
index 0000000000000..00ebb521970dd
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-12-7f4f04b87c7ef9653b4646949b24cf0b
@@ -0,0 +1,10 @@
+1.0 1
+1.0 1
+1.0 1
+1.0 1
+1.0 1
+fooNull 1
+fooNull fooNull
+1.0 fooNull
+1.0 1
+1.0 1
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0 b/sql/hive/src/test/resources/golden/nullformatCTAS-13-2e59caa113585495d8684fee69d88bc0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12 b/sql/hive/src/test/resources/golden/nullformatCTAS-14-ad9fe9d68c2cf492259af4f6167c1b12
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d b/sql/hive/src/test/resources/golden/nullformatCTAS-2-aa2bdbd93668dceae43d1a02f2ede68d
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273 b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273
new file mode 100644
index 0000000000000..b00bcb3624532
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-3-b0057150f237050f38c1efa1f2d6b273
@@ -0,0 +1,6 @@
+a string
+b string
+c string
+d string
+
+Detailed Table Information Table(tableName:base_tab, dbName:default, owner:animal, createTime:1423973915, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null), FieldSchema(name:c, type:string, comment:null), FieldSchema(name:d, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/base_tab, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973915, COLUMN_STATS_ACCURATE=true, totalSize=130, numRows=0, rawDataSize=0}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE)
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd b/sql/hive/src/test/resources/golden/nullformatCTAS-4-16c7086f39d6458b6c5cf2479f0473bd
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9 b/sql/hive/src/test/resources/golden/nullformatCTAS-5-183d77b734ce6a373de5b3ebe1cd04c9
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833 b/sql/hive/src/test/resources/golden/nullformatCTAS-6-159fff36b548e00ee952d1df8ef19833
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65 b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65
new file mode 100644
index 0000000000000..264c973ff7af1
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-7-46900b082b02ce3e58087d1f41128f65
@@ -0,0 +1,4 @@
+a string
+b string
+
+Detailed Table Information Table(tableName:null_tab3, dbName:default, owner:animal, createTime:1423973928, lastAccessTime:0, retention:0, sd:StorageDescriptor(cols:[FieldSchema(name:a, type:string, comment:null), FieldSchema(name:b, type:string, comment:null)], location:file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3, inputFormat:org.apache.hadoop.mapred.TextInputFormat, outputFormat:org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat, compressed:false, numBuckets:-1, serdeInfo:SerDeInfo(name:null, serializationLib:org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, parameters:{serialization.null.format=fooNull, serialization.format=1}), bucketCols:[], sortCols:[], parameters:{}, skewedInfo:SkewedInfo(skewedColNames:[], skewedColValues:[], skewedColValueLocationMaps:{}), storedAsSubDirectories:false), partitionKeys:[], parameters:{numFiles=1, transient_lastDdlTime=1423973928, COLUMN_STATS_ACCURATE=true, totalSize=80, numRows=10, rawDataSize=70}, viewOriginalText:null, viewExpandedText:null, tableType:MANAGED_TABLE)
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8 b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8
new file mode 100644
index 0000000000000..881917bcf1c69
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-8-7f26cbd6be5631a3acce26f667d1c5d8
@@ -0,0 +1,18 @@
+CREATE TABLE `null_tab3`(
+ `a` string,
+ `b` string)
+ROW FORMAT DELIMITED
+ NULL DEFINED AS 'fooNull'
+STORED AS INPUTFORMAT
+ 'org.apache.hadoop.mapred.TextInputFormat'
+OUTPUTFORMAT
+ 'org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat'
+LOCATION
+ 'file:/tmp/sparkHiveWarehouse2573474017665704744/null_tab3'
+TBLPROPERTIES (
+ 'numFiles'='1',
+ 'transient_lastDdlTime'='1423973928',
+ 'COLUMN_STATS_ACCURATE'='true',
+ 'totalSize'='80',
+ 'numRows'='10',
+ 'rawDataSize'='70')
diff --git a/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47 b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47
new file mode 100644
index 0000000000000..3a2e3f4984a0e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/nullformatCTAS-9-22e1b3899de7087b39c24d9d8f643b47
@@ -0,0 +1 @@
+-1
diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-0-c6d02549aec166e16bfc44d5905fa33a
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-1-a8987ff8c7b9ca95bf8b32314694ed1f
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-2-26f54240cf5b909086fc34a34d7fdb56
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-3-d08d5280027adea681001ad82a5a6974
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749 b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749
new file mode 100644
index 0000000000000..d00491fd7e5bb
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/test ambiguousReferences resolved as hive-4-22eb25b5be6daf72a6649adfe5041749
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_reflect2-0-50131c0ba7b7a6b65c789a5a8497bada
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184 b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184
new file mode 100644
index 0000000000000..cd35e5b290db5
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_reflect2-1-7bec330c7bc6f71cbaf9bf1883d1b184
@@ -0,0 +1 @@
+reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection
diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19 b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19
new file mode 100644
index 0000000000000..48ef97292ab62
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_reflect2-2-c5a05379f482215a5a484bed0299bf19
@@ -0,0 +1,3 @@
+reflect2(arg0,method[,arg1[,arg2..]]) calls method of arg0 with reflection
+Use this UDF to call Java methods by matching the argument signature
+
diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca b/sql/hive/src/test/resources/golden/udf_reflect2-3-effc057c78c00b0af26a4ac0f5f116ca
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4 b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4
new file mode 100644
index 0000000000000..176ea0358d7ea
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/udf_reflect2-4-73d466e70e96e9e5f0cd373b37d4e1f4
@@ -0,0 +1,5 @@
+238 -18 238 238 238 238.0 238.0 238 val_238 val_238_concat false true false false false val_238 -1 -1 VALUE_238 al_238 al_2 VAL_238 val_238 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000
+86 86 86 86 86 86.0 86.0 86 val_86 val_86_concat true true true true true val_86 -1 -1 VALUE_86 al_86 al_8 VAL_86 val_86 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000
+311 55 311 311 311 311.0 311.0 311 val_311 val_311_concat false true false false false val_311 5 6 VALUE_311 al_311 al_3 VAL_311 val_311 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000
+27 27 27 27 27 27.0 27.0 27 val_27 val_27_concat false true false false false val_27 -1 -1 VALUE_27 al_27 al_2 VAL_27 val_27 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000
+165 -91 165 165 165 165.0 165.0 165 val_165 val_165_concat false true false false false val_165 4 4 VALUE_165 al_165 al_1 VAL_165 val_165 2013-02-15 19:41:20 113 1 5 19 41 20 1360986080000
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ba391293884bd..0270e63557963 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -17,10 +17,8 @@
package org.apache.spark.sql
-import org.scalatest.FunSuite
+import scala.collection.JavaConversions._
-import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.util._
@@ -55,9 +53,36 @@ class QueryTest extends PlanTest {
/**
* Runs the plan and makes sure the answer matches the expected result.
* @param rdd the [[DataFrame]] to be executed
- * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ].
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
*/
protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = {
+ QueryTest.checkAnswer(rdd, expectedAnswer) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
+ checkAnswer(rdd, Seq(expectedAnswer))
+ }
+
+ def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
+ test(sqlString) {
+ checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ }
+ }
+}
+
+object QueryTest {
+ /**
+ * Runs the plan and makes sure the answer matches the expected result.
+ * If an exception was thrown during execution, or the contents of the DataFrame do not
+ * match the expected result, an error message will be returned. Otherwise, [[None]] will
+ * be returned.
+ * @param rdd the [[DataFrame]] to be executed
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ */
+ def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Option[String] = {
val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
@@ -73,18 +98,20 @@ class QueryTest extends PlanTest {
}
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
- fail(
+ val errorMessage =
s"""
|Exception thrown while executing query:
|${rdd.queryExecution}
|== Exception ==
|$e
|${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin)
+ """.stripMargin
+ return Some(errorMessage)
}
if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- fail(s"""
+ val errorMessage =
+ s"""
|Results do not match for query:
|${rdd.logicalPlan}
|== Analyzed Plan ==
@@ -93,22 +120,21 @@ class QueryTest extends PlanTest {
|${rdd.queryExecution.executedPlan}
|== Results ==
|${sideBySide(
- s"== Correct Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString),
- s"== Spark Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
- """.stripMargin)
+ s"== Correct Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString),
+ s"== Spark Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString)).mkString("\n")}
+ """.stripMargin
+ return Some(errorMessage)
}
- }
- protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = {
- checkAnswer(rdd, Seq(expectedAnswer))
+ return None
}
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext): Unit = {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
+ def checkAnswer(rdd: DataFrame, expectedAnswer: java.util.List[Row]): String = {
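+ // Java-friendly variant: returns the error message on a mismatch, or null when the answer matches.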
+ checkAnswer(rdd, expectedAnswer.toSeq) match {
+ case Some(errorMessage) => errorMessage
+ case None => null
}
}
-
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
index 081d94b6fc020..44ee5ab5975fb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala
@@ -35,11 +35,9 @@ class PlanTest extends FunSuite {
* we must normalize them to check if two different queries are identical.
*/
protected def normalizeExprIds(plan: LogicalPlan) = {
- val list = plan.flatMap(_.expressions.flatMap(_.references).map(_.exprId.id))
- val minId = if (list.isEmpty) 0 else list.min
plan transformAllExpressions {
case a: AttributeReference =>
- AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(a.exprId.id - minId))
+ AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0))
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
index 7c8b5205e239e..44d24273e722a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
-import org.apache.spark.sql.{DataFrame, QueryTest}
+import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest}
import org.apache.spark.storage.RDDBlockId
class CachedTableSuite extends QueryTest {
@@ -96,7 +96,7 @@ class CachedTableSuite extends QueryTest {
cacheTable("test")
sql("SELECT * FROM test").collect()
sql("DROP TABLE test")
- intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] {
+ intercept[AnalysisException] {
sql("SELECT * FROM test").collect()
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
new file mode 100644
index 0000000000000..f04437c595bf6
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import java.io.{OutputStream, PrintStream}
+
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+
+import scala.util.Try
+
+class ErrorPositionSuite extends QueryTest {
+
+ positionTest("unresolved attribute 1",
+ "SELECT x FROM src", "x")
+
+ positionTest("unresolved attribute 2",
+ "SELECT x FROM src", "x")
+
+ positionTest("unresolved attribute 3",
+ "SELECT key, x FROM src", "x")
+
+ positionTest("unresolved attribute 4",
+ """SELECT key,
+ |x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 5",
+ """SELECT key,
+ | x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 6",
+ """SELECT key,
+ |
+ | 1 + x FROM src
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute 7",
+ """SELECT key,
+ |
+ | 1 + x + 1 FROM src
+ """.stripMargin, "x")
+
+ positionTest("multi-char unresolved attribute",
+ """SELECT key,
+ |
+ | 1 + abcd + 1 FROM src
+ """.stripMargin, "abcd")
+
+ positionTest("unresolved attribute group by",
+ """SELECT key FROM src GROUP BY
+ |x
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute order by",
+ """SELECT key FROM src ORDER BY
+ |x
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute where",
+ """SELECT key FROM src
+ |WHERE x = true
+ """.stripMargin, "x")
+
+ positionTest("unresolved attribute backticks",
+ "SELECT `x` FROM src", "`x`")
+
+ positionTest("parse error",
+ "SELECT WHERE", "WHERE")
+
+ positionTest("bad relation",
+ "SELECT * FROM badTable", "badTable")
+
+ ignore("other expressions") {
+ positionTest("bad addition",
+ "SELECT 1 + array(1)", "1 + array")
+ }
+
+ /** Hive can be very noisy, messing up the output of our tests. */
+ private def quietly[A](f: => A): A = {
+ val origErr = System.err
+ val origOut = System.out
+ try {
+ System.setErr(new PrintStream(new OutputStream {
+ def write(b: Int) = {}
+ }))
+ System.setOut(new PrintStream(new OutputStream {
+ def write(b: Int) = {}
+ }))
+
+ f
+ } finally {
+ System.setErr(origErr)
+ System.setOut(origOut)
+ }
+ }
+
+ /**
+ * Creates a test that checks to see if the error thrown when analyzing a given query includes
+ * the location of the given token in the query string.
+ *
+ * @param name the name of the test
+ * @param query the query to analyze
+ * @param token a unique token in the string that should be indicated by the exception
+ */
+ def positionTest(name: String, query: String, token: String) = {
+ def parseTree =
+ Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("")
+
+ test(name) {
+ val error = intercept[AnalysisException] {
+ quietly(sql(query))
+ }
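+ // Find the first query line containing the token; its 1-based index is the line the error should report.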
+ val (line, expectedLineNum) = query.split("\n").zipWithIndex.collect {
+ case (l, i) if l.contains(token) => (l, i + 1)
+ }.headOption.getOrElse(sys.error(s"Invalid test. Token $token not in $query"))
+ val actualLine = error.line.getOrElse {
+ fail(
+ s"line not returned for error '${error.getMessage}' on token $token\n$parseTree"
+ )
+ }
+ assert(actualLine === expectedLineNum, "wrong line")
+
+ val expectedStart = line.indexOf(token)
+ val actualStart = error.startPosition.getOrElse {
+ fail(
+ s"start not returned for error on token $token\n" +
+ HiveQl.dumpTree(HiveQl.getAst(query))
+ )
+ }
+ assert(expectedStart === actualStart,
+ s"""Incorrect start position.
+ |== QUERY ==
+ |$query
+ |
+ |== AST ==
+ |$parseTree
+ |
+ |Actual: $actualStart, Expected: $expectedStart
+ |$line
+ |${" " * actualStart}^
+ |0123456789 123456789 1234567890
+ | 2 3
+ """.stripMargin)
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
index 869d01eb398c5..d4b175fa443a4 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala
@@ -19,7 +19,11 @@ package org.apache.spark.sql.hive
import java.io.File
+import org.scalatest.BeforeAndAfter
+
import com.google.common.io.Files
+
+import org.apache.spark.sql.execution.QueryExecutionException
import org.apache.spark.sql.{QueryTest, _}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.types._
@@ -29,15 +33,22 @@ import org.apache.spark.sql.hive.test.TestHive._
case class TestData(key: Int, value: String)
-class InsertIntoHiveTableSuite extends QueryTest {
+class InsertIntoHiveTableSuite extends QueryTest with BeforeAndAfter {
import org.apache.spark.sql.hive.test.TestHive.implicits._
val testData = TestHive.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString)))
- testData.registerTempTable("testData")
+ (1 to 100).map(i => TestData(i, i.toString))).toDF()
+
+ before {
+ // Since these tests exercise DDL statements,
+ // it is better to reset before every test.
+ TestHive.reset()
+ // Register the testData, which will be used in every test.
+ testData.registerTempTable("testData")
+ }
test("insertInto() HiveTable") {
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE createAndInsertTest (key int, value string)")
// Add some data.
testData.insertInto("createAndInsertTest")
@@ -45,7 +56,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
// Add more data.
@@ -54,7 +65,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the table has been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq
+ testData.toDF().collect().toSeq ++ testData.toDF().collect().toSeq
)
// Now overwrite.
@@ -63,28 +74,30 @@ class InsertIntoHiveTableSuite extends QueryTest {
// Make sure the registered table has also been updated.
checkAnswer(
sql("SELECT * FROM createAndInsertTest"),
- testData.collect().toSeq.map(Row.fromTuple)
+ testData.collect().toSeq
)
}
test("Double create fails when allowExisting = false") {
- createTable[TestData]("doubleCreateAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
- intercept[org.apache.hadoop.hive.ql.metadata.HiveException] {
- createTable[TestData]("doubleCreateAndInsertTest", allowExisting = false)
- }
+ val message = intercept[QueryExecutionException] {
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ }.getMessage
+
+ println("message!!!!" + message)
}
test("Double create does not fail when allowExisting = true") {
- createTable[TestData]("createAndInsertTest")
- createTable[TestData]("createAndInsertTest")
+ sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)")
+ sql("CREATE TABLE IF NOT EXISTS doubleCreateAndInsertTest (key int, value string)")
}
test("SPARK-4052: scala.collection.Map as value type of MapType") {
val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil)
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m MAP )")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -98,7 +111,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
}
test("SPARK-4203:random partition directory order") {
- createTable[TestData]("tmp_table")
+ sql("CREATE TABLE tmp_table (key int, value string)")
val tmpDir = Files.createTempDir()
sql(s"CREATE TABLE table_with_partition(c1 string) PARTITIONED by (p1 string,p2 string,p3 string,p4 string,p5 string) location '${tmpDir.toURI.toString}' ")
sql("INSERT OVERWRITE TABLE table_with_partition partition (p1='a',p2='b',p3='c',p4='c',p5='1') SELECT 'blarr' FROM tmp_table")
@@ -129,7 +142,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
val schema = StructType(Seq(
StructField("a", ArrayType(StringType, containsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithArrayValue")
sql("CREATE TABLE hiveTableWithArrayValue(a Array )")
sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue")
@@ -146,7 +159,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
StructField("m", MapType(StringType, StringType, valueContainsNull = false))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Map(s"key$i" -> s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithMapValue")
sql("CREATE TABLE hiveTableWithMapValue(m Map )")
sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue")
@@ -163,7 +176,7 @@ class InsertIntoHiveTableSuite extends QueryTest {
StructField("s", StructType(Seq(StructField("f", StringType, nullable = false))))))
val rowRDD = TestHive.sparkContext.parallelize(
(1 to 100).map(i => Row(Row(s"value$i"))))
- val df = applySchema(rowRDD, schema)
+ val df = TestHive.createDataFrame(rowRDD, schema)
df.registerTempTable("tableWithStructValue")
sql("CREATE TABLE hiveTableWithStructValue(s Struct )")
sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
new file mode 100644
index 0000000000000..e12a6c21ccac4
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala
@@ -0,0 +1,81 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql.hive.test.TestHive
+import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.Row
+
+class ListTablesSuite extends QueryTest with BeforeAndAfterAll {
+
+ import org.apache.spark.sql.hive.test.TestHive.implicits._
+
+ val df =
+ sparkContext.parallelize((1 to 10).map(i => (i,s"str$i"))).toDF("key", "value")
+
+ override def beforeAll(): Unit = {
+ // The catalog in HiveContext is a case insensitive one.
+ catalog.registerTable(Seq("ListTablesSuiteTable"), df.logicalPlan)
+ catalog.registerTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"), df.logicalPlan)
+ sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)")
+ sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB")
+ sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)")
+ }
+
+ override def afterAll(): Unit = {
+ catalog.unregisterTable(Seq("ListTablesSuiteTable"))
+ catalog.unregisterTable(Seq("ListTablesSuiteDB", "InDBListTablesSuiteTable"))
+ sql("DROP TABLE IF EXISTS HiveListTablesSuiteTable")
+ sql("DROP TABLE IF EXISTS ListTablesSuiteDB.HiveInDBListTablesSuiteTable")
+ sql("DROP DATABASE IF EXISTS ListTablesSuiteDB")
+ }
+
+ test("get all tables of current database") {
+ Seq(tables(), sql("SHOW TABLes")).foreach {
+ case allTables =>
+ // We are using default DB.
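+ // Each row of tables() is (tableName, isTemporary).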
+ checkAnswer(
+ allTables.filter("tableName = 'listtablessuitetable'"),
+ Row("listtablessuitetable", true))
+ assert(allTables.filter("tableName = 'indblisttablessuitetable'").count() === 0)
+ checkAnswer(
+ allTables.filter("tableName = 'hivelisttablessuitetable'"),
+ Row("hivelisttablessuitetable", false))
+ assert(allTables.filter("tableName = 'hiveindblisttablessuitetable'").count() === 0)
+ }
+ }
+
+ test("getting all tables with a database name") {
+ Seq(tables("listtablessuiteDb"), sql("SHOW TABLes in listTablesSuitedb")).foreach {
+ case allTables =>
+ checkAnswer(
+ allTables.filter("tableName = 'listtablessuitetable'"),
+ Row("listtablessuitetable", true))
+ checkAnswer(
+ allTables.filter("tableName = 'indblisttablessuitetable'"),
+ Row("indblisttablessuitetable", true))
+ assert(allTables.filter("tableName = 'hivelisttablessuitetable'").count() === 0)
+ checkAnswer(
+ allTables.filter("tableName = 'hiveindblisttablessuitetable'"),
+ Row("hiveindblisttablessuitetable", false))
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 036efa84d7c85..00306f1cd7f86 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -23,29 +23,29 @@ import org.scalatest.BeforeAndAfterEach
import org.apache.commons.io.FileUtils
import org.apache.hadoop.fs.Path
+import org.apache.hadoop.mapred.InvalidInputException
import org.apache.spark.sql.catalyst.util
import org.apache.spark.sql._
import org.apache.spark.util.Utils
import org.apache.spark.sql.types._
-
-/* Implicits */
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.parquet.ParquetRelation2
+import org.apache.spark.sql.sources.LogicalRelation
/**
* Tests for persisting tables created though the data sources API into the metastore.
*/
class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
override def afterEach(): Unit = {
reset()
- if (ctasPath.exists()) Utils.deleteRecursively(ctasPath)
+ if (tempPath.exists()) Utils.deleteRecursively(tempPath)
}
val filePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
- var ctasPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
+ var tempPath: File = util.getTempFilePath("jsonCTAS").getCanonicalFile
test ("persistent JSON table") {
sql(
@@ -154,7 +154,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("check change without refresh") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF()
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -170,7 +171,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a1", "b1", "c1") :: Nil).toDF()
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
// Schema is cached so the new column does not show. The updated values in existing columns
// will show.
@@ -178,7 +180,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
sql("SELECT * FROM jsonTable"),
Row("a1", "b1"))
- refreshTable("jsonTable")
+ sql("REFRESH TABLE jsonTable")
// Check that the refresh worked
checkAnswer(
@@ -190,7 +192,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
test("drop, change, recreate") {
val tempDir = File.createTempFile("sparksql", "json")
tempDir.delete()
- sparkContext.parallelize(("a", "b") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b") :: Nil).toDF()
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql(
s"""
@@ -206,7 +209,8 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
Row("a", "b"))
FileUtils.deleteDirectory(tempDir)
- sparkContext.parallelize(("a", "b", "c") :: Nil).toJSON.saveAsTextFile(tempDir.getCanonicalPath)
+ sparkContext.parallelize(("a", "b", "c") :: Nil).toDF()
+ .toJSON.saveAsTextFile(tempDir.getCanonicalPath)
sql("DROP TABLE jsonTable")
@@ -270,7 +274,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -297,19 +301,19 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
- // Create the table again should trigger a AlreadyExistsException.
- val message = intercept[RuntimeException] {
+ // Creating the table again should trigger an AnalysisException.
+ val message = intercept[AnalysisException] {
sql(
s"""
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -325,7 +329,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
|CREATE TABLE IF NOT EXISTS ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
|OPTIONS (
- | path '${ctasPath}'
+ | path '${tempPath}'
|) AS
|SELECT a FROM jsonTable
""".stripMargin)
@@ -361,9 +365,7 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
s"""
|CREATE TABLE ctasJsonTable
|USING org.apache.spark.sql.json.DefaultSource
- |OPTIONS (
- |
- |) AS
+ |AS
|SELECT * FROM jsonTable
""".stripMargin)
@@ -402,38 +404,212 @@ class MetastoreDataSourcesSuite extends QueryTest with BeforeAndAfterEach {
sql("DROP TABLE jsonTable").collect().foreach(println)
}
- test("save and load table") {
+ test("SPARK-5839 HiveMetastoreCatalog does not recognize table aliases of data source tables.") {
+ val originalDefaultSource = conf.defaultDataSourceName
+
+ val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ val df = jsonRDD(rdd)
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ // Save the df as a managed table (by not specifying the path).
+ df.saveAsTable("savedJsonTable")
+
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"),
+ (1 to 4).map(i => Row(i, s"str${i}")))
+
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"),
+ (6 to 10).map(i => Row(i, s"str${i}")))
+
+ invalidateTable("savedJsonTable")
+
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"),
+ (1 to 4).map(i => Row(i, s"str${i}")))
+
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"),
+ (6 to 10).map(i => Row(i, s"str${i}")))
+
+ // Dropping the table will also delete the data.
+ sql("DROP TABLE savedJsonTable")
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ }
+
+ test("save table") {
val originalDefaultSource = conf.defaultDataSourceName
- conf.setConf("spark.sql.default.datasource", "org.apache.spark.sql.json")
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
val df = jsonRDD(rdd)
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ // Save the df as a managed table (by not specifying the path).
df.saveAsTable("savedJsonTable")
checkAnswer(
sql("SELECT * FROM savedJsonTable"),
df.collect())
- createTable("createdJsonTable", catalog.hiveDefaultTableFilePath("savedJsonTable"), false)
+ // Right now, we cannot append to an existing JSON table.
+ intercept[RuntimeException] {
+ df.saveAsTable("savedJsonTable", SaveMode.Append)
+ }
+
+ // We can overwrite it.
+ df.saveAsTable("savedJsonTable", SaveMode.Overwrite)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // When the save mode is Ignore, we do nothing if the table already exists.
+ df.select("b").saveAsTable("savedJsonTable", SaveMode.Ignore)
+ assert(df.schema === table("savedJsonTable").schema)
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Dropping the table will also delete the data.
+ sql("DROP TABLE savedJsonTable")
+ intercept[InvalidInputException] {
+ jsonFile(catalog.hiveDefaultTableFilePath("savedJsonTable"))
+ }
+
+ // Create an external table by specifying the path.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM savedJsonTable"),
+ df.collect())
+
+ // Data should not be deleted after we drop the table.
+ sql("DROP TABLE savedJsonTable")
+ checkAnswer(
+ jsonFile(tempPath.toString),
+ df.collect())
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ }
+
+ test("create external table") {
+ val originalDefaultSource = conf.defaultDataSourceName
+
+ val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ val df = jsonRDD(rdd)
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ df.saveAsTable(
+ "savedJsonTable",
+ "org.apache.spark.sql.json",
+ SaveMode.Append,
+ Map("path" -> tempPath.toString))
+
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "org.apache.spark.sql.json")
+ createExternalTable("createdJsonTable", tempPath.toString)
assert(table("createdJsonTable").schema === df.schema)
checkAnswer(
sql("SELECT * FROM createdJsonTable"),
df.collect())
- val message = intercept[RuntimeException] {
- createTable("createdJsonTable", filePath.toString, false)
+ var message = intercept[AnalysisException] {
+ createExternalTable("createdJsonTable", filePath.toString)
}.getMessage
assert(message.contains("Table createdJsonTable already exists."),
"We should complain that ctasJsonTable already exists")
- createTable("createdJsonTable", filePath.toString, true)
- // createdJsonTable should be not changed.
- assert(table("createdJsonTable").schema === df.schema)
+ // Data should not be deleted.
+ sql("DROP TABLE createdJsonTable")
checkAnswer(
- sql("SELECT * FROM createdJsonTable"),
+ jsonFile(tempPath.toString),
df.collect())
- conf.setConf("spark.sql.default.datasource", originalDefaultSource)
+ // Try to specify the schema.
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, "not a source name")
+ val schema = StructType(StructField("b", StringType, true) :: Nil)
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map("path" -> tempPath.toString))
+ checkAnswer(
+ sql("SELECT * FROM createdJsonTable"),
+ sql("SELECT b FROM savedJsonTable").collect())
+
+ sql("DROP TABLE createdJsonTable")
+
+ message = intercept[RuntimeException] {
+ createExternalTable(
+ "createdJsonTable",
+ "org.apache.spark.sql.json",
+ schema,
+ Map.empty[String, String])
+ }.getMessage
+ assert(
+ message.contains("'path' must be specified for json data."),
+ "We should complain that path is not specified.")
+
+ sql("DROP TABLE savedJsonTable")
+ conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ }
+
+ if (HiveShim.version == "0.13.1") {
+ test("scan a parquet table created through a CTAS statement") {
+ val originalConvertMetastore = getConf("spark.sql.hive.convertMetastoreParquet", "true")
+ val originalUseDataSource = getConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ setConf("spark.sql.hive.convertMetastoreParquet", "true")
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+
+ val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
+ jsonRDD(rdd).registerTempTable("jt")
+ sql(
+ """
+ |create table test_parquet_ctas STORED AS parquET
+ |AS select tmp.a from jt tmp where tmp.a < 5
+ """.stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT a FROM test_parquet_ctas WHERE a > 2 "),
+ Row(3) :: Row(4) :: Nil
+ )
+
+ table("test_parquet_ctas").queryExecution.analyzed match {
+ case LogicalRelation(p: ParquetRelation2) => // OK
+ case _ =>
+ fail(
+ s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}")
+ }
+
+ // Clean up and reset confs.
+ sql("DROP TABLE IF EXISTS jt")
+ sql("DROP TABLE IF EXISTS test_parquet_ctas")
+ setConf("spark.sql.hive.convertMetastoreParquet", originalConvertMetastore)
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalUseDataSource)
+ }
+ }
+
+ test("SPARK-6024 wide schema support") {
+ // We will need 80 splits for this schema if the threshold is 4000.
+ val schema = StructType((1 to 5000).map(i => StructField(s"c_${i}", StringType, true)))
+ assert(
+ schema.json.size > conf.schemaStringLengthThreshold,
+ "To correctly test the fix of SPARK-6024, the value of " +
+ s"spark.sql.sources.schemaStringLengthThreshold needs to be less than ${schema.json.size}")
+ // Manually create a metastore data source table.
+ catalog.createDataSourceTable(
+ tableName = "wide_schema",
+ userSpecifiedSchema = Some(schema),
+ provider = "json",
+ options = Map("path" -> "just a dummy path"),
+ isExternal = false)
+
+ invalidateTable("wide_schema")
+
+ val actualSchema = table("wide_schema").schema
+ assert(schema === actualSchema)
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
new file mode 100644
index 0000000000000..d6ddd539d159d
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SerializationSuite.scala
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.sql.hive.test.TestHive
+
+class SerializationSuite extends FunSuite {
+
+ test("[SPARK-5840] HiveContext should be serializable") {
+ val hiveContext = new HiveContext(TestHive.sparkContext)
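+ // Touch hiveconf so the lazy Hive configuration is initialized before serialization.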
+ hiveContext.hiveconf
+ new JavaSerializer(new SparkConf()).newInstance().serialize(hiveContext)
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 6f07fd5a879c0..1e05a024b8807 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -127,11 +127,11 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
}
test("estimates the size of a test MetastoreRelation") {
- val rdd = sql("""SELECT * FROM src""")
- val sizes = rdd.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
+ val df = sql("""SELECT * FROM src""")
+ val sizes = df.queryExecution.analyzed.collect { case mr: MetastoreRelation =>
mr.statistics.sizeInBytes
}
- assert(sizes.size === 1, s"Size wrong for:\n ${rdd.queryExecution}")
+ assert(sizes.size === 1, s"Size wrong for:\n ${df.queryExecution}")
assert(sizes(0).equals(BigInt(5812)),
s"expected exact size 5812 for test table 'src', got: ${sizes(0)}")
}
@@ -145,10 +145,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
ct: ClassTag[_]) = {
before()
- var rdd = sql(query)
+ var df = sql(query)
// Assert src has a size smaller than the threshold.
- val sizes = rdd.queryExecution.analyzed.collect {
+ val sizes = df.queryExecution.analyzed.collect {
case r if ct.runtimeClass.isAssignableFrom(r.getClass) => r.statistics.sizeInBytes
}
assert(sizes.size === 2 && sizes(0) <= conf.autoBroadcastJoinThreshold
@@ -157,21 +157,21 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
// Using `sparkPlan` because for relevant patterns in HashJoin to be
// matched, other strategies need to be applied.
- var bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
+ var bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
assert(bhj.size === 1,
- s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")
+ s"actual query plans do not contain broadcast join: ${df.queryExecution}")
- checkAnswer(rdd, expectedAnswer) // check correctness of output
+ checkAnswer(df, expectedAnswer) // check correctness of output
TestHive.conf.settings.synchronized {
val tmp = conf.autoBroadcastJoinThreshold
sql(s"""SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1""")
- rdd = sql(query)
- bhj = rdd.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
+ df = sql(query)
+ bhj = df.queryExecution.sparkPlan.collect { case j: BroadcastHashJoin => j }
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
- val shj = rdd.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j }
+ val shj = df.queryExecution.sparkPlan.collect { case j: ShuffledHashJoin => j }
assert(shj.size === 1,
"ShuffledHashJoin should be planned when BroadcastHashJoin is turned off")
@@ -199,10 +199,10 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
|left semi JOIN src b ON a.key=86 and a.key = b.key""".stripMargin
val answer = Row(86, "val_86")
- var rdd = sql(leftSemiJoinQuery)
+ var df = sql(leftSemiJoinQuery)
// Assert src has a size smaller than the threshold.
- val sizes = rdd.queryExecution.analyzed.collect {
+ val sizes = df.queryExecution.analyzed.collect {
case r if implicitly[ClassTag[MetastoreRelation]].runtimeClass
.isAssignableFrom(r.getClass) =>
r.statistics.sizeInBytes
@@ -213,25 +213,25 @@ class StatisticsSuite extends QueryTest with BeforeAndAfterAll {
// Using `sparkPlan` because for relevant patterns in HashJoin to be
// matched, other strategies need to be applied.
- var bhj = rdd.queryExecution.sparkPlan.collect {
+ var bhj = df.queryExecution.sparkPlan.collect {
case j: BroadcastLeftSemiJoinHash => j
}
assert(bhj.size === 1,
- s"actual query plans do not contain broadcast join: ${rdd.queryExecution}")
+ s"actual query plans do not contain broadcast join: ${df.queryExecution}")
- checkAnswer(rdd, answer) // check correctness of output
+ checkAnswer(df, answer) // check correctness of output
TestHive.conf.settings.synchronized {
val tmp = conf.autoBroadcastJoinThreshold
sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1")
- rdd = sql(leftSemiJoinQuery)
- bhj = rdd.queryExecution.sparkPlan.collect {
+ df = sql(leftSemiJoinQuery)
+ bhj = df.queryExecution.sparkPlan.collect {
case j: BroadcastLeftSemiJoinHash => j
}
assert(bhj.isEmpty, "BroadcastHashJoin still planned even though it is switched off")
- val shj = rdd.queryExecution.sparkPlan.collect {
+ val shj = df.queryExecution.sparkPlan.collect {
case j: LeftSemiJoinHash => j
}
assert(shj.size === 1,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 27047ce4b1b0b..bb0a67dc03e1d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -29,7 +29,7 @@ import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
@@ -62,7 +62,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-4908: concurrent hive native commands") {
(1 to 100).par.map { _ =>
sql("USE default")
- sql("SHOW TABLES")
+ sql("SHOW DATABASES")
}
}
@@ -429,7 +429,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
|'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe' WITH SERDEPROPERTIES
|('serialization.last.column.takes.rest'='true') FROM src;
""".stripMargin.replaceAll("\n", " "))
-
+
createQueryTest("LIKE",
"SELECT * FROM src WHERE value LIKE '%1%'")
@@ -567,7 +567,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(2, "str2") :: Nil)
- testData.registerTempTable("REGisteredTABle")
+ testData.toDF().registerTempTable("REGisteredTABle")
assertResult(Array(Row(2, "str2"))) {
sql("SELECT tablealias.A, TABLEALIAS.b FROM reGisteredTABle TableAlias " +
@@ -583,8 +583,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-1704: Explain commands as a DataFrame") {
sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
- val rdd = sql("explain select key, count(value) from src group by key")
- assert(isExplanation(rdd))
+ val df = sql("explain select key, count(value) from src group by key")
+ assert(isExplanation(df))
TestHive.reset()
}
@@ -592,7 +592,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
test("SPARK-2180: HAVING support in GROUP BY clauses (positive)") {
val fixture = List(("foo", 2), ("bar", 1), ("foo", 4), ("bar", 3))
.zipWithIndex.map {case Pair(Pair(value, attr), key) => HavingRow(key, value, attr)}
- TestHive.sparkContext.parallelize(fixture).registerTempTable("having_test")
+ TestHive.sparkContext.parallelize(fixture).toDF().registerTempTable("having_test")
val results =
sql("SELECT value, max(attr) AS attr FROM having_test GROUP BY value HAVING attr > 3")
.collect()
@@ -630,24 +630,24 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
test("Query Hive native command execution result") {
- val tableName = "test_native_commands"
+ val databaseName = "test_native_commands"
assertResult(0) {
- sql(s"DROP TABLE IF EXISTS $tableName").count()
+ sql(s"DROP DATABASE IF EXISTS $databaseName").count()
}
assertResult(0) {
- sql(s"CREATE TABLE $tableName(key INT, value STRING)").count()
+ sql(s"CREATE DATABASE $databaseName").count()
}
assert(
- sql("SHOW TABLES")
+ sql("SHOW DATABASES")
.select('result)
.collect()
.map(_.getString(0))
- .contains(tableName))
+ .contains(databaseName))
- assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM $tableName GROUP BY key")))
+ assert(isExplanation(sql(s"EXPLAIN SELECT key, COUNT(*) FROM src GROUP BY key")))
TestHive.reset()
}
@@ -740,7 +740,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
TestHive.sparkContext.parallelize(
TestData(1, "str1") ::
TestData(1, "str2") :: Nil)
- testData.registerTempTable("test_describe_commands2")
+ testData.toDF().registerTempTable("test_describe_commands2")
assertResult(
Array(
@@ -859,6 +859,22 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
+ test("SPARK-5592: get java.net.URISyntaxException when dynamic partitioning") {
+ sql("""
+ |create table sc as select *
+ |from (select '2011-01-11', '2011-01-11+14:18:26' from src tablesample (1 rows)
+ |union all
+ |select '2011-01-11', '2011-01-11+15:18:26' from src tablesample (1 rows)
+ |union all
+ |select '2011-01-11', '2011-01-11+16:18:26' from src tablesample (1 rows) ) s
+ """.stripMargin)
+ sql("create table sc_part (key string) partitioned by (ts string) stored as rcfile")
+ sql("set hive.exec.dynamic.partition=true")
+ sql("set hive.exec.dynamic.partition.mode=nonstrict")
+ sql("insert overwrite table sc_part partition(ts) select * from sc")
+ sql("drop table sc_part")
+ }
+
test("Partition spec validation") {
sql("DROP TABLE IF EXISTS dp_test")
sql("CREATE TABLE dp_test(key INT, value STRING) PARTITIONED BY (dp INT, sp INT)")
@@ -884,8 +900,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
test("SPARK-3414 regression: should store analyzed logical plan when registering a temp table") {
- sparkContext.makeRDD(Seq.empty[LogEntry]).registerTempTable("rawLogs")
- sparkContext.makeRDD(Seq.empty[LogFile]).registerTempTable("logFiles")
+ sparkContext.makeRDD(Seq.empty[LogEntry]).toDF().registerTempTable("rawLogs")
+ sparkContext.makeRDD(Seq.empty[LogFile]).toDF().registerTempTable("logFiles")
sql(
"""
@@ -963,8 +979,8 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
val testVal = "test.val.0"
val nonexistentKey = "nonexistent"
val KV = "([^=]+)=([^=]*)".r
- def collectResults(rdd: DataFrame): Set[(String, String)] =
- rdd.collect().map {
+ def collectResults(df: DataFrame): Set[(String, String)] =
+ df.collect().map {
case Row(key: String, value: String) => key -> value
case Row(KV(key, value)) => key -> value
}.toSet
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
index ff8130ae5f6bc..f4440e5b7846a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.hive.test.TestHive.{sparkContext, jsonRDD, sql}
import org.apache.spark.sql.hive.test.TestHive.implicits._
@@ -40,7 +41,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
"""{"a": [{"b": 1, "B": 2}]}""" :: Nil)).registerTempTable("nested")
 // there are 2 fields matching field name "b"; we should report an "Ambiguous reference" error
- val exception = intercept[RuntimeException] {
+ val exception = intercept[AnalysisException] {
sql("SELECT a[0].b from nested").queryExecution.analyzed
}
assert(exception.getMessage.contains("Ambiguous reference to fields"))
@@ -76,7 +77,7 @@ class HiveResolutionSuite extends HiveComparisonTest {
test("case insensitivity with scala reflection") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF().registerTempTable("caseSensitivityTest")
val query = sql("SELECT a, b, A, B, n.a, n.b, n.A, n.B FROM caseSensitivityTest")
assert(query.schema.fields.map(_.name) === Seq("a", "b", "A", "B", "a", "b", "A", "B"),
@@ -87,17 +88,26 @@ class HiveResolutionSuite extends HiveComparisonTest {
ignore("case insensitivity with scala reflection joins") {
// Test resolution with Scala Reflection
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("caseSensitivityTest")
+ .toDF().registerTempTable("caseSensitivityTest")
sql("SELECT * FROM casesensitivitytest a JOIN casesensitivitytest b ON a.a = b.a").collect()
}
test("nested repeated resolution") {
sparkContext.parallelize(Data(1, 2, Nested(1,2), Seq(Nested(1,2))) :: Nil)
- .registerTempTable("nestedRepeatedTest")
+ .toDF().registerTempTable("nestedRepeatedTest")
assert(sql("SELECT nestedArray[0].a FROM nestedRepeatedTest").collect().head(0) === 1)
}
+ createQueryTest("test ambiguousReferences resolved as hive",
+ """
+ |CREATE TABLE t1(x INT);
+ |CREATE TABLE t2(a STRUCT<x: INT>, k INT);
+ |INSERT OVERWRITE TABLE t1 SELECT 1 FROM src LIMIT 1;
+ |INSERT OVERWRITE TABLE t2 SELECT named_struct("x",1),1 FROM src LIMIT 1;
+ |SELECT a.x FROM t1 a JOIN t2 b ON a.x = b.k;
+ """.stripMargin)
+
/**
* Negative examples. Currently only left here for documentation purposes.
* TODO(marmbrus): Test that catalyst fails on these queries.
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
index 8fb5e050a237a..ab53c6309e089 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala
@@ -18,9 +18,10 @@
package org.apache.spark.sql.hive.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.util.Utils
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index 1e99003d3e9b5..cb405f56bf53d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -62,7 +62,7 @@ class HiveUdfSuite extends QueryTest {
| getStruct(1).f5 FROM src LIMIT 1
""".stripMargin).head() === Row(1, 2, 3, 4, 5))
}
-
+
test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") {
checkAnswer(
sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"),
@@ -96,7 +96,7 @@ class HiveUdfSuite extends QueryTest {
test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)
-
+
checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"),
sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
}
@@ -104,14 +104,14 @@ class HiveUdfSuite extends QueryTest {
test("Generic UDAF aggregates") {
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
-
+
checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"),
sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq)
}
-
+
test("UDFIntegerToString") {
val testData = TestHive.sparkContext.parallelize(
- IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil)
+ IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF()
testData.registerTempTable("integerTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '${classOf[UDFIntegerToString].getName}'")
@@ -127,7 +127,7 @@ class HiveUdfSuite extends QueryTest {
val testData = TestHive.sparkContext.parallelize(
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
- ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil)
+ ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF()
testData.registerTempTable("listListIntTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'")
@@ -142,7 +142,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFListString") {
val testData = TestHive.sparkContext.parallelize(
ListStringCaseClass(Seq("a", "b", "c")) ::
- ListStringCaseClass(Seq("d", "e")) :: Nil)
+ ListStringCaseClass(Seq("d", "e")) :: Nil).toDF()
testData.registerTempTable("listStringTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'")
@@ -156,7 +156,7 @@ class HiveUdfSuite extends QueryTest {
test("UDFStringString") {
val testData = TestHive.sparkContext.parallelize(
- StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil)
+ StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF()
testData.registerTempTable("stringTable")
sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'")
@@ -173,7 +173,7 @@ class HiveUdfSuite extends QueryTest {
ListListIntCaseClass(Nil) ::
ListListIntCaseClass(Seq((1, 2, 3))) ::
ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) ::
- Nil)
+ Nil).toDF()
testData.registerTempTable("TwoListTable")
sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 49fe79d989259..f2bc73bf3bdf9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -17,10 +17,15 @@
package org.apache.spark.sql.hive.execution
-import org.apache.spark.sql.hive.HiveShim
+import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries
+import org.apache.spark.sql.hive.{MetastoreRelation, HiveShim}
+import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.parquet.ParquetRelation2
+import org.apache.spark.sql.sources.LogicalRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SQLConf}
case class Nested1(f1: Nested2)
case class Nested2(f2: Nested3)
@@ -33,8 +38,6 @@ case class Nested3(f3: Int)
*/
class SQLQuerySuite extends QueryTest {
- import org.apache.spark.sql.hive.test.TestHive.implicits._
-
test("SPARK-4512 Fix attribute reference resolution error when using SORT BY") {
checkAnswer(
sql("SELECT * FROM (SELECT key + key AS a FROM src SORT BY value) t ORDER BY t.a"),
@@ -42,6 +45,73 @@ class SQLQuerySuite extends QueryTest {
)
}
+ test("CTAS without serde") {
+ def checkRelation(tableName: String, isDataSourceParquet: Boolean): Unit = {
+ val relation = EliminateSubQueries(catalog.lookupRelation(Seq(tableName)))
+ relation match {
+ case LogicalRelation(r: ParquetRelation2) =>
+ if (!isDataSourceParquet) {
+ fail(
+ s"${classOf[MetastoreRelation].getCanonicalName} is expected, but found " +
+ s"${ParquetRelation2.getClass.getCanonicalName}.")
+ }
+
+ case r: MetastoreRelation =>
+ if (isDataSourceParquet) {
+ fail(
+ s"${ParquetRelation2.getClass.getCanonicalName} is expected, but found " +
+ s"${classOf[MetastoreRelation].getCanonicalName}.")
+ }
+ }
+ }
+
+ val originalConf = getConf("spark.sql.hive.convertCTAS", "false")
+
+ setConf("spark.sql.hive.convertCTAS", "true")
+
+ sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
+ sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
+ var message = intercept[AnalysisException] {
+ sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
+ }.getMessage
+ assert(message.contains("Table ctas1 already exists"))
+ checkRelation("ctas1", true)
+ sql("DROP TABLE ctas1")
+
+ // Specifying a database name for a CTAS statement that would be converted to the
+ // data source write path is not allowed right now.
+ message = intercept[AnalysisException] {
+ sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
+ }.getMessage
+ assert(
+ message.contains("Cannot specify database name in a CTAS statement"),
+ "When spark.sql.hive.convertCTAS is true, we should not allow " +
+ "database name specified.")
+
+ sql("CREATE TABLE ctas1 stored as textfile AS SELECT key k, value FROM src ORDER BY k, value")
+ checkRelation("ctas1", true)
+ sql("DROP TABLE ctas1")
+
+ sql(
+ "CREATE TABLE ctas1 stored as sequencefile AS SELECT key k, value FROM src ORDER BY k, value")
+ checkRelation("ctas1", true)
+ sql("DROP TABLE ctas1")
+
+ sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value")
+ checkRelation("ctas1", false)
+ sql("DROP TABLE ctas1")
+
+ sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value")
+ checkRelation("ctas1", false)
+ sql("DROP TABLE ctas1")
+
+ sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value")
+ checkRelation("ctas1", false)
+ sql("DROP TABLE ctas1")
+
+ setConf("spark.sql.hive.convertCTAS", originalConf)
+ }
+
test("CTAS with serde") {
sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value").collect()
sql(
@@ -174,7 +244,8 @@ class SQLQuerySuite extends QueryTest {
}
test("double nested data") {
- sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil).registerTempTable("nested")
+ sparkContext.parallelize(Nested1(Nested2(Nested3(1))) :: Nil)
+ .toDF().registerTempTable("nested")
checkAnswer(
sql("SELECT f1.f2.f3 FROM nested"),
Row(1))
@@ -184,7 +255,7 @@ class SQLQuerySuite extends QueryTest {
sql("SELECT * FROM test_ctas_1234"),
sql("SELECT * FROM nested").collect().toSeq)
- intercept[org.apache.hadoop.hive.ql.metadata.InvalidTableException] {
+ intercept[AnalysisException] {
sql("CREATE TABLE test_ctas_12345 AS SELECT * from notexists").collect()
}
}
@@ -197,7 +268,7 @@ class SQLQuerySuite extends QueryTest {
}
test("SPARK-4825 save join to table") {
- val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString))
+ val testData = sparkContext.parallelize(1 to 10).map(i => TestData(i, i.toString)).toDF()
sql("CREATE TABLE test1 (key INT, value STRING)")
testData.insertInto("test1")
sql("CREATE TABLE test2 (key INT, value STRING)")
@@ -277,7 +348,7 @@ class SQLQuerySuite extends QueryTest {
val rowRdd = sparkContext.parallelize(row :: Nil)
- applySchema(rowRdd, schema).registerTempTable("testTable")
+ TestHive.createDataFrame(rowRdd, schema).registerTempTable("testTable")
sql(
"""CREATE TABLE nullValuesInInnerComplexTypes
@@ -315,4 +386,34 @@ class SQLQuerySuite extends QueryTest {
dropTempTable("data")
}
+
+ test("logical.Project should not be resolved if it contains aggregates or generators") {
+ // This test verifies the fix for SPARK-5875.
+ // The original issue was that a Project's resolved flag was true even when it contained
+ // AggregateExpressions or Generators. In that case, the Project is not in a valid state
+ // (it cannot be executed). Because of this bug, the PreInsertionCasts analysis rule would
+ // start to work before ImplicitGenerate and then generate an invalid query plan.
+ val rdd = sparkContext.makeRDD((1 to 5).map(i => s"""{"a":[$i, ${i+1}]}"""))
+ jsonRDD(rdd).registerTempTable("data")
+ val originalConf = getConf("spark.sql.hive.convertCTAS", "false")
+ setConf("spark.sql.hive.convertCTAS", "false")
+
+ sql("CREATE TABLE explodeTest (key bigInt)")
+ table("explodeTest").queryExecution.analyzed match {
+ case metastoreRelation: MetastoreRelation => // OK
+ case _ =>
+ fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation")
+ }
+
+ sql(s"INSERT OVERWRITE TABLE explodeTest SELECT explode(a) AS val FROM data")
+ checkAnswer(
+ sql("SELECT key from explodeTest"),
+ (1 to 5).flatMap(i => Row(i) :: Row(i + 1) :: Nil)
+ )
+
+ sql("DROP TABLE explodeTest")
+ dropTempTable("data")
+ setConf("spark.sql.hive.convertCTAS", originalConf)
+ }
}
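
Both new tests capture spark.sql.hive.convertCTAS, override it, and restore it at the end. A small sketch, not part of the patch, of the same idea with try/finally so the original value comes back even when an assertion in the body fails; it assumes the helper sits inside this suite, where getConf/setConf and sql from TestHive._ are in scope, and the name withConvertCTAS is made up here.

// Illustrative helper: run `body` with spark.sql.hive.convertCTAS overridden, then
// restore the previous value even if the body throws.
def withConvertCTAS[T](enabled: Boolean)(body: => T): T = {
  val original = getConf("spark.sql.hive.convertCTAS", "false")
  setConf("spark.sql.hive.convertCTAS", enabled.toString)
  try body finally setConf("spark.sql.hive.convertCTAS", original)
}

// Usage sketch:
// withConvertCTAS(enabled = true) {
//   sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value")
//   sql("DROP TABLE ctas1")
// }
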
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
index a7479a5b95864..c8da8eea4e646 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/parquet/parquetSuites.scala
@@ -20,25 +20,42 @@ package org.apache.spark.sql.parquet
import java.io.File
-import org.apache.spark.sql.catalyst.expressions.Row
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{SQLConf, QueryTest}
-import org.apache.spark.sql.execution.PhysicalRDD
-import org.apache.spark.sql.hive.execution.HiveTableScan
+import org.apache.spark.sql.catalyst.expressions.Row
+import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD}
+import org.apache.spark.sql.hive.execution.{InsertIntoHiveTable, HiveTableScan}
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive.implicits._
+import org.apache.spark.sql.sources.{InsertIntoDataSource, LogicalRelation}
+import org.apache.spark.sql.SaveMode
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(p: Int, intField: Int, stringField: String)
+case class StructContainer(intStructField: Int, stringStructField: String)
+
+case class ParquetDataWithComplexTypes(
+ intField: Int,
+ stringField: String,
+ structField: StructContainer,
+ arrayField: Seq[Int])
+
+case class ParquetDataWithKeyAndComplexTypes(
+ p: Int,
+ intField: Int,
+ stringField: String,
+ structField: StructContainer,
+ arrayField: Seq[Int])
/**
* A suite to test the automatic conversion of metastore tables with parquet data to use the
* built in parquet support.
*/
-class ParquetMetastoreSuite extends ParquetPartitioningTest {
+class ParquetMetastoreSuiteBase extends ParquetPartitioningTest {
override def beforeAll(): Unit = {
super.beforeAll()
@@ -83,6 +100,38 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
location '${new File(normalTableDir, "normal").getCanonicalPath}'
""")
+ sql(s"""
+ CREATE EXTERNAL TABLE partitioned_parquet_with_complextypes
+ (
+ intField INT,
+ stringField STRING,
+ structField STRUCT<intStructField: INT, stringStructField: STRING>,
+ arrayField ARRAY<INT>
+ )
+ PARTITIONED BY (p int)
+ ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ STORED AS
+ INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ LOCATION '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
+ """)
+
+ sql(s"""
+ CREATE EXTERNAL TABLE partitioned_parquet_with_key_and_complextypes
+ (
+ intField INT,
+ stringField STRING,
+ structField STRUCT<intStructField: INT, stringStructField: STRING>,
+ arrayField ARRAY<INT>
+ )
+ PARTITIONED BY (p int)
+ ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ STORED AS
+ INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ LOCATION '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
+ """)
+
(1 to 10).foreach { p =>
sql(s"ALTER TABLE partitioned_parquet ADD PARTITION (p=$p)")
}
@@ -91,10 +140,30 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
sql(s"ALTER TABLE partitioned_parquet_with_key ADD PARTITION (p=$p)")
}
+ (1 to 10).foreach { p =>
+ sql(s"ALTER TABLE partitioned_parquet_with_key_and_complextypes ADD PARTITION (p=$p)")
+ }
+
+ (1 to 10).foreach { p =>
+ sql(s"ALTER TABLE partitioned_parquet_with_complextypes ADD PARTITION (p=$p)")
+ }
+
+ val rdd1 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
+ jsonRDD(rdd1).registerTempTable("jt")
+ val rdd2 = sparkContext.parallelize((1 to 10).map(i => s"""{"a":[$i, null]}"""))
+ jsonRDD(rdd2).registerTempTable("jt_array")
+
setConf("spark.sql.hive.convertMetastoreParquet", "true")
}
override def afterAll(): Unit = {
+ sql("DROP TABLE partitioned_parquet")
+ sql("DROP TABLE partitioned_parquet_with_key")
+ sql("DROP TABLE partitioned_parquet_with_complextypes")
+ sql("DROP TABLE partitioned_parquet_with_key_and_complextypes")
+ sql("DROP TABLE normal_parquet")
+ sql("DROP TABLE IF EXISTS jt")
+ sql("DROP TABLE IF EXISTS jt_array")
setConf("spark.sql.hive.convertMetastoreParquet", "false")
}
@@ -111,10 +180,265 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
}
}
+class ParquetDataSourceOnMetastoreSuite extends ParquetMetastoreSuiteBase {
+ val originalConf = conf.parquetUseDataSourceApi
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ sql(
+ """
+ |create table test_parquet
+ |(
+ | intField INT,
+ | stringField STRING
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ sql("DROP TABLE IF EXISTS test_parquet")
+
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+
+ test("scan an empty parquet table") {
+ checkAnswer(sql("SELECT count(*) FROM test_parquet"), Row(0))
+ }
+
+ test("scan an empty parquet table with upper case") {
+ checkAnswer(sql("SELECT count(INTFIELD) FROM TEST_parquet"), Row(0))
+ }
+
+ test("insert into an empty parquet table") {
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | intField INT,
+ | stringField STRING
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ // Insert into an empty table.
+ sql("insert into table test_insert_parquet select a, b from jt where jt.a > 5")
+ checkAnswer(
+ sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField < 8"),
+ Row(6, "str6") :: Row(7, "str7") :: Nil
+ )
+ // Insert overwrite.
+ sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5")
+ checkAnswer(
+ sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"),
+ Row(3, "str3") :: Row(4, "str4") :: Nil
+ )
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+
+ // Create it again.
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | intField INT,
+ | stringField STRING
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+ // Insert overwrite an empty table.
+ sql("insert overwrite table test_insert_parquet select a, b from jt where jt.a < 5")
+ checkAnswer(
+ sql(s"SELECT intField, stringField FROM test_insert_parquet WHERE intField > 2"),
+ Row(3, "str3") :: Row(4, "str4") :: Nil
+ )
+ // Insert into the table.
+ sql("insert into table test_insert_parquet select a, b from jt")
+ checkAnswer(
+ sql(s"SELECT intField, stringField FROM test_insert_parquet"),
+ (1 to 10).map(i => Row(i, s"str$i")) ++ (1 to 4).map(i => Row(i, s"str$i"))
+ )
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+ }
+
+ test("scan a parquet table created through a CTAS statement") {
+ sql(
+ """
+ |create table test_parquet_ctas ROW FORMAT
+ |SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ |AS select * from jt
+ """.stripMargin)
+
+ checkAnswer(
+ sql(s"SELECT a, b FROM test_parquet_ctas WHERE a = 1"),
+ Seq(Row(1, "str1"))
+ )
+
+ table("test_parquet_ctas").queryExecution.analyzed match {
+ case LogicalRelation(p: ParquetRelation2) => // OK
+ case _ =>
+ fail(
+ s"test_parquet_ctas should be converted to ${classOf[ParquetRelation2].getCanonicalName}")
+ }
+
+ sql("DROP TABLE IF EXISTS test_parquet_ctas")
+ }
+
+ test("MetastoreRelation in InsertIntoTable will be converted") {
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | intField INT
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
+ df.queryExecution.executedPlan match {
+ case ExecutedCommand(
+ InsertIntoDataSource(
+ LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK
+ case o => fail("test_insert_parquet should be converted to a " +
+ s"${classOf[ParquetRelation2].getCanonicalName} and " +
+ s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +
+ s"However, found a ${o.toString} ")
+ }
+
+ checkAnswer(
+ sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"),
+ sql("SELECT a FROM jt WHERE jt.a > 5").collect()
+ )
+
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+ }
+
+ test("MetastoreRelation in InsertIntoHiveTable will be converted") {
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | int_array array<int>
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
+ df.queryExecution.executedPlan match {
+ case ExecutedCommand(
+ InsertIntoDataSource(
+ LogicalRelation(r: ParquetRelation2), query, overwrite)) => // OK
+ case o => fail("test_insert_parquet should be converted to a " +
+ s"${classOf[ParquetRelation2].getCanonicalName} and " +
+ s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." +
+ s"However, found a ${o.toString} ")
+ }
+
+ checkAnswer(
+ sql("SELECT int_array FROM test_insert_parquet"),
+ sql("SELECT a FROM jt_array").collect()
+ )
+
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+ }
+}
+
+class ParquetDataSourceOffMetastoreSuite extends ParquetMetastoreSuiteBase {
+ val originalConf = conf.parquetUseDataSourceApi
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+
+ test("MetastoreRelation in InsertIntoTable will not be converted") {
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | intField INT
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt")
+ df.queryExecution.executedPlan match {
+ case insert: InsertIntoHiveTable => // OK
+ case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " +
+ s"However, found ${o.toString}.")
+ }
+
+ checkAnswer(
+ sql("SELECT intField FROM test_insert_parquet WHERE test_insert_parquet.intField > 5"),
+ sql("SELECT a FROM jt WHERE jt.a > 5").collect()
+ )
+
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+ }
+
+ // TODO: re-enable this after SPARK-5950 is fixed.
+ ignore("MetastoreRelation in InsertIntoHiveTable will not be converted") {
+ sql(
+ """
+ |create table test_insert_parquet
+ |(
+ | int_array array<int>
+ |)
+ |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe'
+ |STORED AS
+ | INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat'
+ | OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat'
+ """.stripMargin)
+
+ val df = sql("INSERT INTO TABLE test_insert_parquet SELECT a FROM jt_array")
+ df.queryExecution.executedPlan match {
+ case insert: InsertIntoHiveTable => // OK
+ case o => fail(s"The SparkPlan should be ${classOf[InsertIntoHiveTable].getCanonicalName}. " +
+ s"However, found ${o.toString}.")
+ }
+
+ checkAnswer(
+ sql("SELECT int_array FROM test_insert_parquet"),
+ sql("SELECT a FROM jt_array").collect()
+ )
+
+ sql("DROP TABLE IF EXISTS test_insert_parquet")
+ }
+}
+
/**
* A suite of tests for the Parquet support through the data sources API.
*/
-class ParquetSourceSuite extends ParquetPartitioningTest {
+class ParquetSourceSuiteBase extends ParquetPartitioningTest {
override def beforeAll(): Unit = {
super.beforeAll()
@@ -141,6 +465,76 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
path '${new File(partitionedTableDir, "p=1").getCanonicalPath}'
)
""")
+
+ sql( s"""
+ CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes
+ USING org.apache.spark.sql.parquet
+ OPTIONS (
+ path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
+ )
+ """)
+
+ sql( s"""
+ CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes
+ USING org.apache.spark.sql.parquet
+ OPTIONS (
+ path '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
+ )
+ """)
+ }
+
+ test("SPARK-6016 make sure to use the latest footers") {
+ sql("drop table if exists spark_6016_fix")
+
+ // Create a DataFrame with two partitions, so the created table will have two parquet files.
+ val df1 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i}"""), 2))
+ df1.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite)
+ checkAnswer(
+ sql("select * from spark_6016_fix"),
+ (1 to 10).map(i => Row(i))
+ )
+
+ // Create a DataFrame with four partitions, so the created table will have four parquet files.
+ val df2 = jsonRDD(sparkContext.parallelize((1 to 10).map(i => s"""{"b":$i}"""), 4))
+ df2.saveAsTable("spark_6016_fix", "parquet", SaveMode.Overwrite)
+ // Before the fix for SPARK-6016, two outdated footers from df1 were still cached. Since
+ // the new table has four parquet files, new footers were read from only two of them and
+ // then merged with the cached metadata (two outdated footers plus two up-to-date ones),
+ // which caused an error.
+ checkAnswer(
+ sql("select * from spark_6016_fix"),
+ (1 to 10).map(i => Row(i))
+ )
+
+ sql("drop table spark_6016_fix")
+ }
+}
+
+class ParquetDataSourceOnSourceSuite extends ParquetSourceSuiteBase {
+ val originalConf = conf.parquetUseDataSourceApi
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
+ }
+}
+
+class ParquetDataSourceOffSourceSuite extends ParquetSourceSuiteBase {
+ val originalConf = conf.parquetUseDataSourceApi
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
+ }
+
+ override def afterAll(): Unit = {
+ super.afterAll()
+ setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
}
}
@@ -151,8 +545,8 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
var partitionedTableDir: File = null
var normalTableDir: File = null
var partitionedTableDirWithKey: File = null
-
- import org.apache.spark.sql.hive.test.TestHive.implicits._
+ var partitionedTableDirWithComplexTypes: File = null
+ var partitionedTableDirWithKeyAndComplexTypes: File = null
override def beforeAll(): Unit = {
partitionedTableDir = File.createTempFile("parquettests", "sparksql")
@@ -167,12 +561,14 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDir, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
sparkContext
.makeRDD(1 to 10)
.map(i => ParquetData(i, s"part-1"))
+ .toDF()
.saveAsParquetFile(new File(normalTableDir, "normal").getCanonicalPath)
partitionedTableDirWithKey = File.createTempFile("parquettests", "sparksql")
@@ -183,111 +579,159 @@ abstract class ParquetPartitioningTest extends QueryTest with BeforeAndAfterAll
val partDir = new File(partitionedTableDirWithKey, s"p=$p")
sparkContext.makeRDD(1 to 10)
.map(i => ParquetDataWithKey(p, i, s"part-$p"))
+ .toDF()
.saveAsParquetFile(partDir.getCanonicalPath)
}
+
+ partitionedTableDirWithKeyAndComplexTypes = File.createTempFile("parquettests", "sparksql")
+ partitionedTableDirWithKeyAndComplexTypes.delete()
+ partitionedTableDirWithKeyAndComplexTypes.mkdir()
+
+ (1 to 10).foreach { p =>
+ val partDir = new File(partitionedTableDirWithKeyAndComplexTypes, s"p=$p")
+ sparkContext.makeRDD(1 to 10).map { i =>
+ ParquetDataWithKeyAndComplexTypes(
+ p, i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
+ }.toDF().saveAsParquetFile(partDir.getCanonicalPath)
+ }
+
+ partitionedTableDirWithComplexTypes = File.createTempFile("parquettests", "sparksql")
+ partitionedTableDirWithComplexTypes.delete()
+ partitionedTableDirWithComplexTypes.mkdir()
+
+ (1 to 10).foreach { p =>
+ val partDir = new File(partitionedTableDirWithComplexTypes, s"p=$p")
+ sparkContext.makeRDD(1 to 10).map { i =>
+ ParquetDataWithComplexTypes(i, s"part-$p", StructContainer(i, f"${i}_string"), 1 to i)
+ }.toDF().saveAsParquetFile(partDir.getCanonicalPath)
+ }
}
- def run(prefix: String): Unit = {
- Seq("partitioned_parquet", "partitioned_parquet_with_key").foreach { table =>
- test(s"$prefix: ordering of the partitioning columns $table") {
- checkAnswer(
- sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
- Seq.fill(10)(Row(1, "part-1"))
- )
-
- checkAnswer(
- sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
- Seq.fill(10)(Row("part-1", 1))
- )
- }
-
- test(s"$prefix: project the partitioning column $table") {
- checkAnswer(
- sql(s"SELECT p, count(*) FROM $table group by p"),
- Row(1, 10) ::
- Row(2, 10) ::
- Row(3, 10) ::
- Row(4, 10) ::
- Row(5, 10) ::
- Row(6, 10) ::
- Row(7, 10) ::
- Row(8, 10) ::
- Row(9, 10) ::
- Row(10, 10) :: Nil
- )
- }
-
- test(s"$prefix: project partitioning and non-partitioning columns $table") {
- checkAnswer(
- sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
- Row("part-1", 1, 10) ::
- Row("part-2", 2, 10) ::
- Row("part-3", 3, 10) ::
- Row("part-4", 4, 10) ::
- Row("part-5", 5, 10) ::
- Row("part-6", 6, 10) ::
- Row("part-7", 7, 10) ::
- Row("part-8", 8, 10) ::
- Row("part-9", 9, 10) ::
- Row("part-10", 10, 10) :: Nil
- )
- }
-
- test(s"$prefix: simple count $table") {
- checkAnswer(
- sql(s"SELECT COUNT(*) FROM $table"),
- Row(100))
- }
-
- test(s"$prefix: pruned count $table") {
- checkAnswer(
- sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
- Row(10))
- }
-
- test(s"$prefix: non-existent partition $table") {
- checkAnswer(
- sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
- Row(0))
- }
-
- test(s"$prefix: multi-partition pruned count $table") {
- checkAnswer(
- sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
- Row(30))
- }
-
- test(s"$prefix: non-partition predicates $table") {
- checkAnswer(
- sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
- Row(30))
- }
-
- test(s"$prefix: sum $table") {
- checkAnswer(
- sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
- Row(1 + 2 + 3))
- }
-
- test(s"$prefix: hive udfs $table") {
- checkAnswer(
- sql(s"SELECT concat(stringField, stringField) FROM $table"),
- sql(s"SELECT stringField FROM $table").map {
- case Row(s: String) => Row(s + s)
- }.collect().toSeq)
- }
+ override protected def afterAll(): Unit = {
+ partitionedTableDir.delete()
+ normalTableDir.delete()
+ partitionedTableDirWithKey.delete()
+ partitionedTableDirWithComplexTypes.delete()
+ partitionedTableDirWithKeyAndComplexTypes.delete()
+ }
+
+ Seq(
+ "partitioned_parquet",
+ "partitioned_parquet_with_key",
+ "partitioned_parquet_with_complextypes",
+ "partitioned_parquet_with_key_and_complextypes").foreach { table =>
+
+ test(s"ordering of the partitioning columns $table") {
+ checkAnswer(
+ sql(s"SELECT p, stringField FROM $table WHERE p = 1"),
+ Seq.fill(10)(Row(1, "part-1"))
+ )
+
+ checkAnswer(
+ sql(s"SELECT stringField, p FROM $table WHERE p = 1"),
+ Seq.fill(10)(Row("part-1", 1))
+ )
+ }
+
+ test(s"project the partitioning column $table") {
+ checkAnswer(
+ sql(s"SELECT p, count(*) FROM $table group by p"),
+ Row(1, 10) ::
+ Row(2, 10) ::
+ Row(3, 10) ::
+ Row(4, 10) ::
+ Row(5, 10) ::
+ Row(6, 10) ::
+ Row(7, 10) ::
+ Row(8, 10) ::
+ Row(9, 10) ::
+ Row(10, 10) :: Nil
+ )
}
- test(s"$prefix: $prefix: non-part select(*)") {
+ test(s"project partitioning and non-partitioning columns $table") {
checkAnswer(
- sql("SELECT COUNT(*) FROM normal_parquet"),
+ sql(s"SELECT stringField, p, count(intField) FROM $table GROUP BY p, stringField"),
+ Row("part-1", 1, 10) ::
+ Row("part-2", 2, 10) ::
+ Row("part-3", 3, 10) ::
+ Row("part-4", 4, 10) ::
+ Row("part-5", 5, 10) ::
+ Row("part-6", 6, 10) ::
+ Row("part-7", 7, 10) ::
+ Row("part-8", 8, 10) ::
+ Row("part-9", 9, 10) ::
+ Row("part-10", 10, 10) :: Nil
+ )
+ }
+
+ test(s"simple count $table") {
+ checkAnswer(
+ sql(s"SELECT COUNT(*) FROM $table"),
+ Row(100))
+ }
+
+ test(s"pruned count $table") {
+ checkAnswer(
+ sql(s"SELECT COUNT(*) FROM $table WHERE p = 1"),
Row(10))
}
+
+ test(s"non-existent partition $table") {
+ checkAnswer(
+ sql(s"SELECT COUNT(*) FROM $table WHERE p = 1000"),
+ Row(0))
+ }
+
+ test(s"multi-partition pruned count $table") {
+ checkAnswer(
+ sql(s"SELECT COUNT(*) FROM $table WHERE p IN (1,2,3)"),
+ Row(30))
+ }
+
+ test(s"non-partition predicates $table") {
+ checkAnswer(
+ sql(s"SELECT COUNT(*) FROM $table WHERE intField IN (1,2,3)"),
+ Row(30))
+ }
+
+ test(s"sum $table") {
+ checkAnswer(
+ sql(s"SELECT SUM(intField) FROM $table WHERE intField IN (1,2,3) AND p = 1"),
+ Row(1 + 2 + 3))
+ }
+
+ test(s"hive udfs $table") {
+ checkAnswer(
+ sql(s"SELECT concat(stringField, stringField) FROM $table"),
+ sql(s"SELECT stringField FROM $table").map {
+ case Row(s: String) => Row(s + s)
+ }.collect().toSeq)
+ }
+ }
+
+ Seq(
+ "partitioned_parquet_with_key_and_complextypes",
+ "partitioned_parquet_with_complextypes").foreach { table =>
+
+ test(s"SPARK-5775 read struct from $table") {
+ checkAnswer(
+ sql(s"SELECT p, structField.intStructField, structField.stringStructField FROM $table WHERE p = 1"),
+ (1 to 10).map(i => Row(1, i, f"${i}_string")))
+ }
+
+ // Re-enable this after SPARK-5508 is fixed
+ ignore(s"SPARK-5775 read array from $table") {
+ checkAnswer(
+ sql(s"SELECT arrayField, p FROM $table WHERE p = 1"),
+ (1 to 10).map(i => Row(1 to i, 1)))
+ }
}
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "false")
- run("Parquet data source enabled")
- setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
- run("Parquet data source disabled")
+ test("non-part select(*)") {
+ checkAnswer(
+ sql("SELECT COUNT(*) FROM normal_parquet"),
+ Row(10))
+ }
}
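
The removed run(prefix) helper flipped spark.sql.parquet.useDataSourceApi inline between two passes (with the labels swapped, as the deleted lines show); the rewrite moves the shared tests into *SuiteBase classes and lets thin On/Off subclasses toggle the flag for the whole suite in beforeAll/afterAll. A generic skeleton of that pattern, with illustrative class names; it assumes the classes live in the same package as the suites above so the package-private conf handle exposed via TestHive._ is accessible.

package org.apache.spark.sql.parquet  // same package as the suites above (assumption)

import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.{QueryTest, SQLConf}
import org.apache.spark.sql.hive.test.TestHive._

// Shared tests go in the base class; subclasses only decide the flag value.
abstract class FlagSuiteBase extends QueryTest with BeforeAndAfterAll {
  // test(...) definitions shared by both variants go here
}

class FlagOnSuite extends FlagSuiteBase {
  private val originalConf = conf.parquetUseDataSourceApi

  override def beforeAll(): Unit = {
    super.beforeAll()
    conf.setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, "true")
  }

  override def afterAll(): Unit = {
    super.afterAll()
    setConf(SQLConf.PARQUET_USE_DATA_SOURCE_API, originalConf.toString)
  }
}
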
diff --git a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
index b5a0754ff61f9..30646ddbc29d8 100644
--- a/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
+++ b/sql/hive/v0.12.0/src/main/scala/org/apache/spark/sql/hive/Shim12.scala
@@ -43,7 +43,9 @@ import org.apache.hadoop.mapred.InputFormat
import org.apache.spark.sql.types.{Decimal, DecimalType}
-case class HiveFunctionWrapper(functionClassName: String) extends java.io.Serializable {
+private[hive] case class HiveFunctionWrapper(functionClassName: String)
+ extends java.io.Serializable {
+
// for Serialization
def this() = this(null)
@@ -245,8 +247,13 @@ private[hive] object HiveShim {
def prepareWritable(w: Writable): Writable = {
w
}
+
+ def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {}
}
-class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean)
+private[hive] class ShimFileSinkDesc(
+ var dir: String,
+ var tableInfo: TableDesc,
+ var compressed: Boolean)
extends FileSinkDesc(dir, tableInfo, compressed) {
}
diff --git a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
index e4c1809c8bb21..f9fcbdae15745 100644
--- a/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
+++ b/sql/hive/v0.13.1/src/main/scala/org/apache/spark/sql/hive/Shim13.scala
@@ -35,6 +35,7 @@ import org.apache.hadoop.hive.ql.Context
import org.apache.hadoop.hive.ql.metadata.{Table, Hive, Partition}
import org.apache.hadoop.hive.ql.plan.{CreateTableDesc, FileSinkDesc, TableDesc}
import org.apache.hadoop.hive.ql.processors.CommandProcessorFactory
+import org.apache.hadoop.hive.serde.serdeConstants
import org.apache.hadoop.hive.serde2.typeinfo.{TypeInfo, DecimalTypeInfo, TypeInfoFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.{HiveDecimalObjectInspector, PrimitiveObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.{PrimitiveObjectInspector, ObjectInspector}
@@ -55,7 +56,9 @@ import org.apache.spark.sql.types.{Decimal, DecimalType}
*
* @param functionClassName UDF class name
*/
-case class HiveFunctionWrapper(var functionClassName: String) extends java.io.Externalizable {
+private[hive] case class HiveFunctionWrapper(var functionClassName: String)
+ extends java.io.Externalizable {
+
// for Serialization
def this() = this(null)
@@ -410,13 +413,22 @@ private[hive] object HiveShim {
}
w
}
+
+ def setTblNullFormat(crtTbl: CreateTableDesc, tbl: Table) = {
+ if (crtTbl != null && crtTbl.getNullFormat() != null) {
+ tbl.setSerdeParam(serdeConstants.SERIALIZATION_NULL_FORMAT, crtTbl.getNullFormat())
+ }
+ }
}
/*
 * Bug introduced in hive-0.13. FileSinkDesc is serializable, but its member path is not.
* Fix it through wrapper.
*/
-class ShimFileSinkDesc(var dir: String, var tableInfo: TableDesc, var compressed: Boolean)
+private[hive] class ShimFileSinkDesc(
+ var dir: String,
+ var tableInfo: TableDesc,
+ var compressed: Boolean)
extends Serializable with Logging {
var compressCodec: String = _
var compressType: String = _
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
index b780282bdac37..f88a8a0151550 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
@@ -152,7 +152,7 @@ class CheckpointWriter(
// Delete old checkpoint files
val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, fs)
- if (allCheckpointFiles.size > 4) {
+ if (allCheckpointFiles.size > 10) {
allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => {
logInfo("Deleting " + file)
fs.delete(file, true)
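
The guard change only aligns the branch condition with the retention count: the take(allCheckpointFiles.size - 10) above always keeps the 10 most recent checkpoint files, so entering the branch with 5 to 10 files deleted nothing anyway. A standalone sketch of the retention rule; the function name and default are illustrative, and the files are assumed to be ordered oldest-first, as Checkpoint.getCheckpointFiles returns them.

// Keep only the `retain` most recent checkpoint files; everything older is a candidate
// for deletion. With `retain` or fewer files nothing is returned.
def checkpointFilesToDelete(allCheckpointFiles: Seq[String], retain: Int = 10): Seq[String] =
  if (allCheckpointFiles.size > retain) {
    allCheckpointFiles.take(allCheckpointFiles.size - retain)  // oldest files first
  } else {
    Seq.empty
  }
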
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index de124cf40eff1..bd01789b611a4 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -726,7 +726,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsHadoopFiles[F <: OutputFormat[K, V]](prefix: String, suffix: String) {
+ def saveAsHadoopFiles(prefix: String, suffix: String) {
dstream.saveAsHadoopFiles(prefix, suffix)
}
@@ -734,12 +734,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsHadoopFiles(
+ def saveAsHadoopFiles[F <: OutputFormat[_, _]](
prefix: String,
suffix: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: OutputFormat[_, _]]) {
+ outputFormatClass: Class[F]) {
dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass)
}
@@ -747,12 +747,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsHadoopFiles(
+ def saveAsHadoopFiles[F <: OutputFormat[_, _]](
prefix: String,
suffix: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: OutputFormat[_, _]],
+ outputFormatClass: Class[F],
conf: JobConf) {
dstream.saveAsHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
}
@@ -761,7 +761,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[K, V]](prefix: String, suffix: String) {
+ def saveAsNewAPIHadoopFiles(prefix: String, suffix: String) {
dstream.saveAsNewAPIHadoopFiles(prefix, suffix)
}
@@ -769,12 +769,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsNewAPIHadoopFiles(
+ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]](
prefix: String,
suffix: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]]) {
+ outputFormatClass: Class[F]) {
dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass)
}
@@ -782,12 +782,12 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
* Save each RDD in `this` DStream as a Hadoop file. The file name at each batch interval is
* generated based on `prefix` and `suffix`: "prefix-TIME_IN_MS.suffix".
*/
- def saveAsNewAPIHadoopFiles(
+ def saveAsNewAPIHadoopFiles[F <: NewOutputFormat[_, _]](
prefix: String,
suffix: String,
keyClass: Class[_],
valueClass: Class[_],
- outputFormatClass: Class[_ <: NewOutputFormat[_, _]],
+ outputFormatClass: Class[F],
conf: Configuration = new Configuration) {
dstream.saveAsNewAPIHadoopFiles(prefix, suffix, keyClass, valueClass, outputFormatClass, conf)
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
index 6379b88527ec8..22de8c02e63c8 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala
@@ -18,7 +18,6 @@
package org.apache.spark.streaming.dstream
import java.io.{IOException, ObjectInputStream}
-import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable
import scala.reflect.ClassTag
@@ -27,6 +26,7 @@ import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path, PathFilter}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
+import org.apache.spark.SerializableWritable
import org.apache.spark.rdd.{RDD, UnionRDD}
import org.apache.spark.streaming._
import org.apache.spark.util.{TimeStampedHashMap, Utils}
@@ -78,6 +78,8 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
(implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
extends InputDStream[(K, V)](ssc_) {
+ private val serializableConfOpt = conf.map(new SerializableWritable(_))
+
// This is a def so that it works during checkpoint recovery:
private def clock = ssc.scheduler.clock
@@ -86,7 +88,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
// Initial ignore threshold based on which old, existing files in the directory (at the time of
// starting the streaming application) will be ignored or considered
- private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.currentTime() else 0L
+ private val initialModTimeIgnoreThreshold = if (newFilesOnly) clock.getTimeMillis() else 0L
/*
* Make sure that the information of files selected in the last few batches are remembered.
@@ -159,7 +161,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
*/
private def findNewFiles(currentTime: Long): Array[String] = {
try {
- lastNewFileFindingTime = clock.currentTime()
+ lastNewFileFindingTime = clock.getTimeMillis()
// Calculate ignore threshold
val modTimeIgnoreThreshold = math.max(
@@ -172,7 +174,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
def accept(path: Path): Boolean = isNewFile(path, currentTime, modTimeIgnoreThreshold)
}
val newFiles = fs.listStatus(directoryPath, filter).map(_.getPath.toString)
- val timeTaken = clock.currentTime() - lastNewFileFindingTime
+ val timeTaken = clock.getTimeMillis() - lastNewFileFindingTime
logInfo("Finding new files took " + timeTaken + " ms")
logDebug("# cached file times = " + fileToModTime.size)
if (timeTaken > slideDuration.milliseconds) {
@@ -240,7 +242,7 @@ class FileInputDStream[K, V, F <: NewInputFormat[K,V]](
/** Generate one RDD from an array of files */
private def filesToRDD(files: Seq[String]): RDD[(K, V)] = {
val fileRDDs = files.map(file =>{
- val rdd = conf match {
+ val rdd = serializableConfOpt.map(_.value) match {
case Some(config) => context.sparkContext.newAPIHadoopFile(
file,
fm.runtimeClass.asInstanceOf[Class[F]],
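
The new serializableConfOpt field exists because the optional Hadoop Configuration is captured by the DStream, and the DStream graph is serialized for checkpointing; Configuration is a Writable but not java.io.Serializable, so it is wrapped once in SerializableWritable and unwrapped at use time. A minimal sketch of the idiom in isolation, assuming only spark-core and hadoop-common; the holder class is illustrative, not from the patch.

import org.apache.hadoop.conf.Configuration
import org.apache.spark.SerializableWritable

// Illustrative holder: the wrapper, not the raw Configuration, is the field
// that survives Java serialization of the enclosing object.
class ConfHolder(conf: Option[Configuration]) extends Serializable {
  private val serializableConfOpt = conf.map(new SerializableWritable(_))

  // Unwrap at use time, just as filesToRDD does above with serializableConfOpt.map(_.value).
  def currentConf: Option[Configuration] = serializableConfOpt.map(_.value)
}
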
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
index 79263a7183977..ee5e639b26d91 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala
@@ -23,7 +23,8 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.storage.StreamBlockId
-import org.apache.spark.streaming.util.{RecurringTimer, SystemClock}
+import org.apache.spark.streaming.util.RecurringTimer
+import org.apache.spark.util.SystemClock
/** Listener object for BlockGenerator events */
private[streaming] trait BlockGeneratorListener {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
index f7a8ebee8a544..dcdc27d29c270 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala
@@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage._
-import org.apache.spark.streaming.util.{Clock, SystemClock, WriteAheadLogFileSegment, WriteAheadLogManager}
-import org.apache.spark.util.Utils
+import org.apache.spark.streaming.util.{WriteAheadLogFileSegment, WriteAheadLogManager}
+import org.apache.spark.util.{Clock, SystemClock, Utils}
/** Trait that represents the metadata related to storage of blocks */
private[streaming] trait ReceivedBlockStoreResult {
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
index 8632c94349bf9..ac92774a38273 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
@@ -23,7 +23,8 @@ import akka.actor.{ActorRef, Props, Actor}
import org.apache.spark.{SparkEnv, Logging}
import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time}
-import org.apache.spark.streaming.util.{Clock, ManualClock, RecurringTimer}
+import org.apache.spark.streaming.util.RecurringTimer
+import org.apache.spark.util.{Clock, ManualClock}
/** Event classes for JobGenerator */
private[scheduler] sealed trait JobGeneratorEvent
@@ -45,8 +46,14 @@ class JobGenerator(jobScheduler: JobScheduler) extends Logging {
val clock = {
val clockClass = ssc.sc.conf.get(
- "spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
- Class.forName(clockClass).newInstance().asInstanceOf[Clock]
+ "spark.streaming.clock", "org.apache.spark.util.SystemClock")
+ try {
+ Class.forName(clockClass).newInstance().asInstanceOf[Clock]
+ } catch {
+ case e: ClassNotFoundException if clockClass.startsWith("org.apache.spark.streaming") =>
+ val newClockClass = clockClass.replace("org.apache.spark.streaming", "org.apache.spark")
+ Class.forName(newClockClass).newInstance().asInstanceOf[Clock]
+ }
}
private val timer = new RecurringTimer(clock, ssc.graph.batchDuration.milliseconds,
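
The fallback above accepts clock class names under the old org.apache.spark.streaming prefix and rewrites them to the relocated org.apache.spark.util classes, so existing configurations keep working after the move. A short sketch of opting into the manual clock under the new name, mirroring the LocalJavaStreamingContext change further down; sketch only, with an illustrative app name.

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

val conf = new SparkConf()
  .setMaster("local[2]")
  .setAppName("manual-clock-sketch")
  // New location; the fallback also tolerates the old
  // "org.apache.spark.streaming.util.ManualClock" spelling.
  .set("spark.streaming.clock", "org.apache.spark.util.ManualClock")

val ssc = new StreamingContext(conf, Seconds(1))
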
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
index e19ac939f9ac5..200cf4ef4b0f1 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala
@@ -27,8 +27,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.{SparkException, Logging, SparkConf}
import org.apache.spark.streaming.Time
-import org.apache.spark.streaming.util.{Clock, WriteAheadLogManager}
-import org.apache.spark.util.Utils
+import org.apache.spark.streaming.util.WriteAheadLogManager
+import org.apache.spark.util.{Clock, Utils}
/** Trait representing any event in the ReceivedBlockTracker that updates its state. */
private[streaming] sealed trait ReceivedBlockTrackerLogEvent
@@ -150,7 +150,7 @@ private[streaming] class ReceivedBlockTracker(
* returns only after the files are cleaned up.
*/
def cleanupOldBatches(cleanupThreshTime: Time, waitForCompletion: Boolean): Unit = synchronized {
- assert(cleanupThreshTime.milliseconds < clock.currentTime())
+ assert(cleanupThreshTime.milliseconds < clock.getTimeMillis())
val timesToCleanup = timeToAllocatedBlocks.keys.filter { _ < cleanupThreshTime }.toSeq
logInfo("Deleting batches " + timesToCleanup)
writeToLog(BatchCleanupEvent(timesToCleanup))
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
deleted file mode 100644
index d6d96d7ba00fd..0000000000000
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/Clock.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.streaming.util
-
-private[streaming]
-trait Clock {
- def currentTime(): Long
- def waitTillTime(targetTime: Long): Long
-}
-
-private[streaming]
-class SystemClock() extends Clock {
-
- val minPollTime = 25L
-
- def currentTime(): Long = {
- System.currentTimeMillis()
- }
-
- def waitTillTime(targetTime: Long): Long = {
- var currentTime = 0L
- currentTime = System.currentTimeMillis()
-
- var waitTime = targetTime - currentTime
- if (waitTime <= 0) {
- return currentTime
- }
-
- val pollTime = math.max(waitTime / 10.0, minPollTime).toLong
-
- while (true) {
- currentTime = System.currentTimeMillis()
- waitTime = targetTime - currentTime
- if (waitTime <= 0) {
- return currentTime
- }
- val sleepTime = math.min(waitTime, pollTime)
- Thread.sleep(sleepTime)
- }
- -1
- }
-}
-
-private[streaming]
-class ManualClock() extends Clock {
-
- private var time = 0L
-
- def currentTime() = this.synchronized {
- time
- }
-
- def setTime(timeToSet: Long) = {
- this.synchronized {
- time = timeToSet
- this.notifyAll()
- }
- }
-
- def addToTime(timeToAdd: Long) = {
- this.synchronized {
- time += timeToAdd
- this.notifyAll()
- }
- }
- def waitTillTime(targetTime: Long): Long = {
- this.synchronized {
- while (time < targetTime) {
- this.wait(100)
- }
- }
- currentTime()
- }
-}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
index 1a616a0434f2c..c8eef833eb431 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RecurringTimer.scala
@@ -18,6 +18,7 @@
package org.apache.spark.streaming.util
import org.apache.spark.Logging
+import org.apache.spark.util.{Clock, SystemClock}
private[streaming]
class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name: String)
@@ -38,7 +39,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name:
* current system time.
*/
def getStartTime(): Long = {
- (math.floor(clock.currentTime.toDouble / period) + 1).toLong * period
+ (math.floor(clock.getTimeMillis().toDouble / period) + 1).toLong * period
}
/**
@@ -48,7 +49,7 @@ class RecurringTimer(clock: Clock, period: Long, callback: (Long) => Unit, name:
* more than current time.
*/
def getRestartTime(originalStartTime: Long): Long = {
- val gap = clock.currentTime - originalStartTime
+ val gap = clock.getTimeMillis() - originalStartTime
(math.floor(gap.toDouble / period).toLong + 1) * period + originalStartTime
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
index 166661b7496df..985ded9111f74 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogManager.scala
@@ -19,13 +19,12 @@ package org.apache.spark.streaming.util
import java.nio.ByteBuffer
import scala.collection.mutable.ArrayBuffer
-import scala.concurrent.duration.Duration
import scala.concurrent.{Await, ExecutionContext, Future}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
import org.apache.spark.Logging
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Clock, SystemClock, Utils}
import WriteAheadLogManager._
/**
@@ -82,7 +81,7 @@ private[streaming] class WriteAheadLogManager(
var succeeded = false
while (!succeeded && failures < maxFailures) {
try {
- fileSegment = getLogWriter(clock.currentTime).write(byteBuffer)
+ fileSegment = getLogWriter(clock.getTimeMillis()).write(byteBuffer)
succeeded = true
} catch {
case ex: Exception =>
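
The write path above stamps each log segment with clock.getTimeMillis() and retries failed writes up to a bounded number of attempts. A generic sketch of that retry shape, using placeholder names (writeOnce, maxFailures) rather than the manager's actual private members:

// Illustrative retry helper in the spirit of the writeToLog loop above; `writeOnce`
// and `maxFailures` are placeholders, not WriteAheadLogManager's private API.
object RetrySketch {
  def writeWithRetries[T](maxFailures: Int)(writeOnce: () => T): T = {
    var failures = 0
    var result: Option[T] = None
    var lastError: Exception = null
    while (result.isEmpty && failures < maxFailures) {
      try {
        result = Some(writeOnce())
      } catch {
        case e: Exception =>
          failures += 1
          lastError = e
      }
    }
    // Give up after maxFailures attempts and surface the last error seen.
    result.getOrElse(throw lastError)
  }
}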
diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
index 2df8cf6a8a3df..57302ff407183 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java
@@ -1828,4 +1828,22 @@ private List<List<String>> fileTestPrepare(File testDir) throws IOException {
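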
return expected;
}
+
+ // SPARK-5795: no logic assertions, just testing that intended API invocations compile
+ private void compileSaveAsJavaAPI(JavaPairDStream<LongWritable, Text> pds) {
+ pds.saveAsNewAPIHadoopFiles(
+ "", "", LongWritable.class, Text.class,
+ org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
+ pds.saveAsHadoopFiles(
+ "", "", LongWritable.class, Text.class,
+ org.apache.hadoop.mapred.SequenceFileOutputFormat.class);
+ // Checks that a previous common workaround for this API still compiles
+ pds.saveAsNewAPIHadoopFiles(
+ "", "", LongWritable.class, Text.class,
+ (Class) org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat.class);
+ pds.saveAsHadoopFiles(
+ "", "", LongWritable.class, Text.class,
+ (Class) org.apache.hadoop.mapred.SequenceFileOutputFormat.class);
+ }
+
}
diff --git a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
index 1e24da7f5f60c..cfedb5a042a35 100644
--- a/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
+++ b/streaming/src/test/java/org/apache/spark/streaming/LocalJavaStreamingContext.java
@@ -31,7 +31,7 @@ public void setUp() {
SparkConf conf = new SparkConf()
.setMaster("local[2]")
.setAppName("test")
- .set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock");
+ .set("spark.streaming.clock", "org.apache.spark.util.ManualClock");
ssc = new JavaStreamingContext(conf, new Duration(1000));
ssc.checkpoint("checkpoint");
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
index e8f4a7779ec21..cf191715d29d6 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
@@ -22,13 +22,12 @@ import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer}
import scala.language.existentials
import scala.reflect.ClassTag
-import util.ManualClock
-
import org.apache.spark.{SparkConf, SparkException}
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{BlockRDD, RDD}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.streaming.dstream.{DStream, WindowedDStream}
+import org.apache.spark.util.{Clock, ManualClock}
import org.apache.spark.HashPartitioner
class BasicOperationsSuite extends TestSuiteBase {
@@ -586,7 +585,7 @@ class BasicOperationsSuite extends TestSuiteBase {
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(200)
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
collectRddInfo()
}
@@ -637,8 +636,8 @@ class BasicOperationsSuite extends TestSuiteBase {
ssc.graph.getOutputStreams().head.dependencies.head.asInstanceOf[DStream[T]]
if (rememberDuration != null) ssc.remember(rememberDuration)
val output = runStreams[(Int, Int)](ssc, cleanupTestInput.size, numExpectedOutput)
- val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- assert(clock.currentTime() === Seconds(10).milliseconds)
+ val clock = ssc.scheduler.clock.asInstanceOf[Clock]
+ assert(clock.getTimeMillis() === Seconds(10).milliseconds)
assert(output.size === numExpectedOutput)
operatedStream
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
index 8f8bc61437ba5..03c448f1df5f1 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
@@ -32,8 +32,7 @@ import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutput
import org.scalatest.concurrent.Eventually._
import org.apache.spark.streaming.dstream.{DStream, FileInputDStream}
-import org.apache.spark.streaming.util.ManualClock
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Clock, ManualClock, Utils}
/**
* This test suite tests the checkpointing functionality of DStreams -
@@ -61,7 +60,7 @@ class CheckpointSuite extends TestSuiteBase {
assert(batchDuration === Milliseconds(500), "batchDuration for this test must be 500 milliseconds")
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
val stateStreamCheckpointInterval = Seconds(1)
val fs = FileSystem.getLocal(new Configuration())
@@ -324,13 +323,13 @@ class CheckpointSuite extends TestSuiteBase {
* Writes a file named `i` (which contains the number `i`) to the test directory and sets its
* modification time to `clock`'s current time.
*/
- def writeFile(i: Int, clock: ManualClock): Unit = {
+ def writeFile(i: Int, clock: Clock): Unit = {
val file = new File(testDir, i.toString)
Files.write(i + "\n", file, Charsets.UTF_8)
- assert(file.setLastModified(clock.currentTime()))
+ assert(file.setLastModified(clock.getTimeMillis()))
// Check that the file's modification date is actually the value we wrote, since rounding or
// truncation will break the test:
- assert(file.lastModified() === clock.currentTime())
+ assert(file.lastModified() === clock.getTimeMillis())
}
/**
@@ -372,13 +371,13 @@ class CheckpointSuite extends TestSuiteBase {
ssc.start()
// Advance half a batch so that the first file is created after the StreamingContext starts
- clock.addToTime(batchDuration.milliseconds / 2)
+ clock.advance(batchDuration.milliseconds / 2)
// Create files and advance manual clock to process them
for (i <- Seq(1, 2, 3)) {
writeFile(i, clock)
// Advance the clock after creating the file to avoid a race when
// setting its modification time
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
if (i != 3) {
// Since we want to shut down while the 3rd batch is processing
eventually(eventuallyTimeout) {
@@ -386,7 +385,7 @@ class CheckpointSuite extends TestSuiteBase {
}
}
}
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
eventually(eventuallyTimeout) {
// Wait until all files have been recorded and all batches have started
assert(recordedFiles(ssc) === Seq(1, 2, 3) && batchCounter.getNumStartedBatches === 3)
@@ -410,7 +409,7 @@ class CheckpointSuite extends TestSuiteBase {
writeFile(i, clock)
// Advance the clock after creating the file to avoid a race when
// setting its modification time
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
}
// Recover context from checkpoint file and verify whether the files that were
@@ -419,7 +418,7 @@ class CheckpointSuite extends TestSuiteBase {
withStreamingContext(new StreamingContext(checkpointDir)) { ssc =>
// So that the restarted StreamingContext's clock has gone forward in time since failure
ssc.conf.set("spark.streaming.manualClock.jump", (batchDuration * 3).milliseconds.toString)
- val oldClockTime = clock.currentTime()
+ val oldClockTime = clock.getTimeMillis()
clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
val batchCounter = new BatchCounter(ssc)
val outputStream = ssc.graph.getOutputStreams().head.asInstanceOf[TestOutputStream[Int]]
@@ -430,7 +429,7 @@ class CheckpointSuite extends TestSuiteBase {
ssc.start()
// Verify that the clock has traveled forward to the expected time
eventually(eventuallyTimeout) {
- clock.currentTime() === oldClockTime
+ clock.getTimeMillis() === oldClockTime
}
// Wait for pre-failure batch to be recomputed (3 while SSC was down plus last batch)
val numBatchesAfterRestart = 4
@@ -441,12 +440,12 @@ class CheckpointSuite extends TestSuiteBase {
writeFile(i, clock)
// Advance the clock after creating the file to avoid a race when
// setting its modification time
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
eventually(eventuallyTimeout) {
assert(batchCounter.getNumCompletedBatches === index + numBatchesAfterRestart + 1)
}
}
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
logInfo("Output after restart = " + outputStream.output.mkString("[", ", ", "]"))
assert(outputStream.output.size > 0, "No files processed after restart")
ssc.stop()
@@ -521,12 +520,12 @@ class CheckpointSuite extends TestSuiteBase {
*/
def advanceTimeWithRealDelay[V: ClassTag](ssc: StreamingContext, numBatches: Long): Seq[Seq[V]] = {
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- logInfo("Manual clock before advancing = " + clock.currentTime())
+ logInfo("Manual clock before advancing = " + clock.getTimeMillis())
for (i <- 1 to numBatches.toInt) {
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
Thread.sleep(batchDuration.milliseconds)
}
- logInfo("Manual clock after advancing = " + clock.currentTime())
+ logInfo("Manual clock after advancing = " + clock.getTimeMillis())
Thread.sleep(batchDuration.milliseconds)
val outputStream = ssc.graph.getOutputStreams.filter { dstream =>
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
index 01084a457db4f..7ed6320a3d0bc 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
@@ -17,12 +17,8 @@
package org.apache.spark.streaming
-import akka.actor.Actor
-import akka.actor.Props
-import akka.util.ByteString
-
import java.io.{File, BufferedWriter, OutputStreamWriter}
-import java.net.{InetSocketAddress, SocketException, ServerSocket}
+import java.net.{SocketException, ServerSocket}
import java.nio.charset.Charset
import java.util.concurrent.{Executors, TimeUnit, ArrayBlockingQueue}
import java.util.concurrent.atomic.AtomicInteger
@@ -36,9 +32,8 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Logging
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.util.ManualClock
-import org.apache.spark.util.Utils
-import org.apache.spark.streaming.receiver.{ActorHelper, Receiver}
+import org.apache.spark.util.{ManualClock, Utils}
+import org.apache.spark.streaming.receiver.Receiver
import org.apache.spark.rdd.RDD
import org.apache.hadoop.io.{Text, LongWritable}
import org.apache.hadoop.mapreduce.lib.input.TextInputFormat
@@ -69,7 +64,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
for (i <- 0 until input.size) {
testServer.send(input(i).toString + "\n")
Thread.sleep(500)
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping server")
@@ -120,19 +115,19 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
// Advance the clock so that the files are created after StreamingContext starts, but
// not enough to trigger a batch
- clock.addToTime(batchDuration.milliseconds / 2)
+ clock.advance(batchDuration.milliseconds / 2)
val input = Seq(1, 2, 3, 4, 5)
input.foreach { i =>
Thread.sleep(batchDuration.milliseconds)
val file = new File(testDir, i.toString)
Files.write(Array[Byte](i.toByte), file)
- assert(file.setLastModified(clock.currentTime()))
- assert(file.lastModified === clock.currentTime)
+ assert(file.setLastModified(clock.getTimeMillis()))
+ assert(file.lastModified === clock.getTimeMillis())
logInfo("Created file " + file)
// Advance the clock after creating the file to avoid a race when
// setting its modification time
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
eventually(eventuallyTimeout) {
assert(batchCounter.getNumCompletedBatches === i)
}
@@ -179,7 +174,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
while((!MultiThreadTestReceiver.haveAllThreadsFinished || output.sum < numTotalRecords) &&
System.currentTimeMillis() - startTime < 5000) {
Thread.sleep(100)
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping context")
@@ -214,7 +209,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
for (i <- 0 until input.size) {
// Enqueue more than 1 item per tick but they should dequeue one at a time
inputIterator.take(2).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
}
Thread.sleep(1000)
logInfo("Stopping context")
@@ -256,12 +251,12 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
// Enqueue the first 3 items (one by one), they should be merged in the next batch
val inputIterator = input.toIterator
inputIterator.take(3).foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
Thread.sleep(1000)
// Enqueue the remaining items (again one by one), merged in the final batch
inputIterator.foreach(i => queue += ssc.sparkContext.makeRDD(Seq(i)))
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
Thread.sleep(1000)
logInfo("Stopping context")
ssc.stop()
@@ -308,19 +303,19 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter {
// Advance the clock so that the files are created after StreamingContext starts, but
// not enough to trigger a batch
- clock.addToTime(batchDuration.milliseconds / 2)
+ clock.advance(batchDuration.milliseconds / 2)
// Over time, create files in the directory
val input = Seq(1, 2, 3, 4, 5)
input.foreach { i =>
val file = new File(testDir, i.toString)
Files.write(i + "\n", file, Charset.forName("UTF-8"))
- assert(file.setLastModified(clock.currentTime()))
- assert(file.lastModified === clock.currentTime)
+ assert(file.setLastModified(clock.getTimeMillis()))
+ assert(file.lastModified === clock.getTimeMillis())
logInfo("Created file " + file)
// Advance the clock after creating the file to avoid a race when
// setting its modification time
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
eventually(eventuallyTimeout) {
assert(batchCounter.getNumCompletedBatches === i)
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 132ff2443fc0f..818f551dbe996 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -39,7 +39,7 @@ import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.storage._
import org.apache.spark.streaming.receiver._
import org.apache.spark.streaming.util._
-import org.apache.spark.util.AkkaUtils
+import org.apache.spark.util.{AkkaUtils, ManualClock}
import WriteAheadLogBasedBlockHandler._
import WriteAheadLogSuite._
@@ -165,7 +165,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
preCleanupLogFiles.size should be > 1
// this depends on the number of blocks inserted using generateAndStoreData()
- manualClock.currentTime() shouldEqual 5000L
+ manualClock.getTimeMillis() shouldEqual 5000L
val cleanupThreshTime = 3000L
handler.cleanupOldBlocks(cleanupThreshTime)
@@ -243,7 +243,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche
val blockIds = Seq.fill(blocks.size)(generateBlockId())
val storeResults = blocks.zip(blockIds).map {
case (block, id) =>
- manualClock.addToTime(500) // log rolling interval set to 1000 ms through SparkConf
+ manualClock.advance(500) // log rolling interval set to 1000 ms through SparkConf
logDebug("Inserting block " + id)
receivedBlockHandler.storeBlock(id, block)
}.toList
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
index fbb7b0bfebafc..a3a0fd5187403 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockTrackerSuite.scala
@@ -34,9 +34,9 @@ import org.apache.spark.{Logging, SparkConf, SparkException}
import org.apache.spark.storage.StreamBlockId
import org.apache.spark.streaming.receiver.BlockManagerBasedStoreResult
import org.apache.spark.streaming.scheduler._
-import org.apache.spark.streaming.util.{Clock, ManualClock, SystemClock, WriteAheadLogReader}
+import org.apache.spark.streaming.util.WriteAheadLogReader
import org.apache.spark.streaming.util.WriteAheadLogSuite._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{Clock, ManualClock, SystemClock, Utils}
class ReceivedBlockTrackerSuite
extends FunSuite with BeforeAndAfter with Matchers with Logging {
@@ -100,7 +100,7 @@ class ReceivedBlockTrackerSuite
def incrementTime() {
val timeIncrementMillis = 2000L
- manualClock.addToTime(timeIncrementMillis)
+ manualClock.advance(timeIncrementMillis)
}
// Generate and add blocks to the given tracker
@@ -138,13 +138,13 @@ class ReceivedBlockTrackerSuite
tracker2.getUnallocatedBlocks(streamId).toList shouldEqual blockInfos1
// Allocate blocks to batch and verify whether the unallocated blocks got allocated
- val batchTime1 = manualClock.currentTime
+ val batchTime1 = manualClock.getTimeMillis()
tracker2.allocateBlocksToBatch(batchTime1)
tracker2.getBlocksOfBatchAndStream(batchTime1, streamId) shouldEqual blockInfos1
// Add more blocks and allocate to another batch
incrementTime()
- val batchTime2 = manualClock.currentTime
+ val batchTime2 = manualClock.getTimeMillis()
val blockInfos2 = addBlockInfos(tracker2)
tracker2.allocateBlocksToBatch(batchTime2)
tracker2.getBlocksOfBatchAndStream(batchTime2, streamId) shouldEqual blockInfos2
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
index 2aa5e0876b6e0..6a7cd97aa3222 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala
@@ -190,7 +190,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
logInfo("Count = " + count + ", Running count = " + runningCount)
}
ssc.start()
- ssc.awaitTermination(500)
+ ssc.awaitTerminationOrTimeout(500)
ssc.stop(stopSparkContext = false, stopGracefully = true)
logInfo("Running count = " + runningCount)
logInfo("TestReceiver.counter = " + TestReceiver.counter.get())
@@ -223,7 +223,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
logInfo("Count = " + count + ", Running count = " + runningCount)
}
ssc.start()
- ssc.awaitTermination(500)
+ ssc.awaitTerminationOrTimeout(500)
ssc.stop(stopSparkContext = false, stopGracefully = true)
logInfo("Running count = " + runningCount)
assert(runningCount > 0)
@@ -243,7 +243,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
// test whether awaitTerminationOrTimeout() exits after the given amount of time
failAfter(1000 millis) {
- ssc.awaitTermination(500)
+ ssc.awaitTerminationOrTimeout(500)
}
// test whether awaitTermination() does not exit if no time is given
@@ -288,7 +288,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
val exception = intercept[Exception] {
ssc.start()
- ssc.awaitTermination(5000)
+ ssc.awaitTerminationOrTimeout(5000)
}
assert(exception.getMessage.contains("map task"), "Expected exception not thrown")
}
@@ -299,7 +299,7 @@ class StreamingContextSuite extends FunSuite with BeforeAndAfter with Timeouts w
inputStream.transform { rdd => throw new TestException("error in transform"); rdd }.register()
val exception = intercept[TestException] {
ssc.start()
- ssc.awaitTermination(5000)
+ ssc.awaitTerminationOrTimeout(5000)
}
assert(exception.getMessage.contains("transform"), "Expected exception not thrown")
}
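
The timed waits in this suite move from awaitTermination(timeout) to awaitTerminationOrTimeout(timeout). A minimal usage sketch, assuming an already-started StreamingContext named ssc:

// Hedged sketch of the renamed timed wait used by these tests: returns once the
// context stops or the timeout elapses, whichever comes first, replacing the
// awaitTermination(timeout) overload used before this patch.
import org.apache.spark.streaming.StreamingContext

object AwaitTimeoutExample {
  def waitBriefly(ssc: StreamingContext): Unit = {
    ssc.awaitTerminationOrTimeout(500) // wait at most 500 ms
  }
}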
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
index 7d82c3e4aadcf..3565d621e8a6c 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
@@ -31,10 +31,9 @@ import org.scalatest.concurrent.PatienceConfiguration
import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream}
import org.apache.spark.streaming.scheduler.{StreamingListenerBatchStarted, StreamingListenerBatchCompleted, StreamingListener}
-import org.apache.spark.streaming.util.ManualClock
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.rdd.RDD
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ManualClock, Utils}
/**
* This is an input stream just for the test suites. It is equivalent to a checkpointable,
@@ -189,10 +188,10 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
def beforeFunction() {
if (useManualClock) {
logInfo("Using manual clock")
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock")
+ conf.set("spark.streaming.clock", "org.apache.spark.util.ManualClock")
} else {
logInfo("Using real clock")
- conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.SystemClock")
+ conf.set("spark.streaming.clock", "org.apache.spark.util.SystemClock")
}
}
@@ -333,23 +332,23 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging {
// Advance manual clock
val clock = ssc.scheduler.clock.asInstanceOf[ManualClock]
- logInfo("Manual clock before advancing = " + clock.currentTime())
+ logInfo("Manual clock before advancing = " + clock.getTimeMillis())
if (actuallyWait) {
for (i <- 1 to numBatches) {
logInfo("Actually waiting for " + batchDuration)
- clock.addToTime(batchDuration.milliseconds)
+ clock.advance(batchDuration.milliseconds)
Thread.sleep(batchDuration.milliseconds)
}
} else {
- clock.addToTime(numBatches * batchDuration.milliseconds)
+ clock.advance(numBatches * batchDuration.milliseconds)
}
- logInfo("Manual clock after advancing = " + clock.currentTime())
+ logInfo("Manual clock after advancing = " + clock.getTimeMillis())
// Wait until expected number of output items have been generated
val startTime = System.currentTimeMillis()
while (output.size < numExpectedOutput && System.currentTimeMillis() - startTime < maxWaitTimeMillis) {
logInfo("output.size = " + output.size + ", numExpectedOutput = " + numExpectedOutput)
- ssc.awaitTermination(50)
+ ssc.awaitTerminationOrTimeout(50)
}
val timeTaken = System.currentTimeMillis() - startTime
logInfo("Output generated in " + timeTaken + " milliseconds")
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
index 7ce9499dc614d..8335659667f22 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala
@@ -26,7 +26,7 @@ import scala.language.{implicitConversions, postfixOps}
import WriteAheadLogSuite._
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ManualClock, Utils}
import org.scalatest.{BeforeAndAfter, FunSuite}
import org.scalatest.concurrent.Eventually._
@@ -197,7 +197,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
val logFiles = getLogFilesInDirectory(testDir)
assert(logFiles.size > 1)
- manager.cleanupOldLogs(manualClock.currentTime() / 2, waitForCompletion)
+ manager.cleanupOldLogs(manualClock.getTimeMillis() / 2, waitForCompletion)
if (waitForCompletion) {
assert(getLogFilesInDirectory(testDir).size < logFiles.size)
@@ -219,7 +219,7 @@ class WriteAheadLogSuite extends FunSuite with BeforeAndAfter {
// Recover old files and generate a second set of log files
val dataToWrite2 = generateRandomData()
- manualClock.addToTime(100000)
+ manualClock.advance(100000)
writeDataUsingManager(testDir, dataToWrite2, manualClock)
val logFiles2 = getLogFilesInDirectory(testDir)
assert(logFiles2.size > logFiles1.size)
@@ -279,12 +279,12 @@ object WriteAheadLogSuite {
manualClock: ManualClock = new ManualClock,
stopManager: Boolean = true
): WriteAheadLogManager = {
- if (manualClock.currentTime < 100000) manualClock.setTime(10000)
+ if (manualClock.getTimeMillis() < 100000) manualClock.setTime(10000)
val manager = new WriteAheadLogManager(logDirectory, hadoopConf,
rollingIntervalSecs = 1, callerName = "WriteAheadLogSuite", clock = manualClock)
// Ensure that 500 does not get sorted after 2000, so put a high base value.
data.foreach { item =>
- manualClock.addToTime(500)
+ manualClock.advance(500)
manager.writeToLog(item)
}
if (stopManager) manager.stop()
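
The suite advances the manual clock by 500 ms per record against a one-second rolling interval, so consecutive writes fall into successive one-second buckets and a cleanup threshold of half the current time separates old log files from new ones. An illustrative, standalone view of that bucketing arithmetic only (the actual rolling is done by WriteAheadLogManager):

// Illustrative arithmetic behind the suite's clock stepping; rollingBucket is a
// hypothetical helper, not part of the write-ahead log code.
object WalRollingExample {
  def rollingBucket(timestampMs: Long, rollingIntervalSecs: Int): Long =
    timestampMs / (rollingIntervalSecs * 1000L)

  def main(args: Array[String]): Unit = {
    val base = 10000L                         // the suite seeds the clock at 10000 ms
    val stamps = (1 to 6).map(i => base + i * 500L)
    stamps.foreach(t => println(s"$t ms -> bucket ${rollingBucket(t, 1)}"))
  }
}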
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index 4cc320c5d59b5..e966bfba7bb7d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -19,9 +19,9 @@ package org.apache.spark.deploy.yarn
import scala.util.control.NonFatal
-import java.io.IOException
+import java.io.{File, IOException}
import java.lang.reflect.InvocationTargetException
-import java.net.Socket
+import java.net.{Socket, URL}
import java.util.concurrent.atomic.AtomicReference
import akka.actor._
@@ -38,7 +38,8 @@ import org.apache.spark.deploy.{PythonRunner, SparkHadoopUtil}
import org.apache.spark.deploy.history.HistoryServer
import org.apache.spark.scheduler.cluster.YarnSchedulerBackend
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
-import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils}
+import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader,
+ SignalLogger, Utils}
/**
* Common application master functionality for Spark on Yarn.
@@ -67,8 +68,8 @@ private[spark] class ApplicationMaster(
@volatile private var finalMsg: String = ""
@volatile private var userClassThread: Thread = _
- private var reporterThread: Thread = _
- private var allocator: YarnAllocator = _
+ @volatile private var reporterThread: Thread = _
+ @volatile private var allocator: YarnAllocator = _
// Fields used in client mode.
private var actorSystem: ActorSystem = null
@@ -244,7 +245,6 @@ private[spark] class ApplicationMaster(
host: String,
port: String,
isClusterMode: Boolean): Unit = {
-
val driverUrl = AkkaUtils.address(
AkkaUtils.protocol(actorSystem),
SparkEnv.driverActorSystemName,
@@ -453,12 +453,24 @@ private[spark] class ApplicationMaster(
private def startUserApplication(): Thread = {
logInfo("Starting the user application in a separate Thread")
System.setProperty("spark.executor.instances", args.numExecutors.toString)
+
+ val classpath = Client.getUserClasspath(sparkConf)
+ val urls = classpath.map { entry =>
+ new URL("file:" + new File(entry.getPath()).getAbsolutePath())
+ }
+ val userClassLoader =
+ if (Client.isUserClassPathFirst(sparkConf, isDriver = true)) {
+ new ChildFirstURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ } else {
+ new MutableURLClassLoader(urls, Utils.getContextOrSparkClassLoader)
+ }
+
if (args.primaryPyFile != null && args.primaryPyFile.endsWith(".py")) {
System.setProperty("spark.submit.pyFiles",
PythonRunner.formatPaths(args.pyFiles).mkString(","))
}
- val mainMethod = Class.forName(args.userClass, false,
- Thread.currentThread.getContextClassLoader).getMethod("main", classOf[Array[String]])
+ val mainMethod = userClassLoader.loadClass(args.userClass)
+ .getMethod("main", classOf[Array[String]])
val userThread = new Thread {
override def run() {
@@ -473,16 +485,16 @@ private[spark] class ApplicationMaster(
e.getCause match {
case _: InterruptedException =>
// Reporter thread can interrupt to stop user class
- case e: Exception =>
+ case cause: Throwable =>
+ logError("User class threw exception: " + cause.getMessage, cause)
finish(FinalApplicationStatus.FAILED,
ApplicationMaster.EXIT_EXCEPTION_USER_CLASS,
- "User class threw exception: " + e.getMessage)
- // re-throw to get it logged
- throw e
+ "User class threw exception: " + cause.getMessage)
}
}
}
}
+ userThread.setContextClassLoader(userClassLoader)
userThread.setName("Driver")
userThread.start()
userThread
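
startUserApplication() now loads the user class through its own class loader, preferring ChildFirstURLClassLoader when user classes should shadow Spark's. A sketch mirroring that selection; urls would come from Client.getUserClasspath(sparkConf), and the helper name is illustrative:

// Illustrative selection of the user class loader, mirroring the logic added above.
import java.net.URL
import org.apache.spark.SparkConf
import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils}

object UserClassLoaderSketch {
  def chooseLoader(conf: SparkConf, urls: Array[URL]): ClassLoader = {
    val parent = Utils.getContextOrSparkClassLoader
    if (conf.getBoolean("spark.driver.userClassPathFirst", false)) {
      new ChildFirstURLClassLoader(urls, parent)   // user jars shadow Spark's classes
    } else {
      new MutableURLClassLoader(urls, parent)      // parent-first delegation
    }
  }
}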
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 8afc1ccdad732..46d9df93488cb 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -183,8 +183,7 @@ private[spark] class Client(
private[yarn] def copyFileToRemote(
destDir: Path,
srcPath: Path,
- replication: Short,
- setPerms: Boolean = false): Path = {
+ replication: Short): Path = {
val destFs = destDir.getFileSystem(hadoopConf)
val srcFs = srcPath.getFileSystem(hadoopConf)
var destPath = srcPath
@@ -193,9 +192,7 @@ private[spark] class Client(
logInfo(s"Uploading resource $srcPath -> $destPath")
FileUtil.copy(srcFs, srcPath, destFs, destPath, false, hadoopConf)
destFs.setReplication(destPath, replication)
- if (setPerms) {
- destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
- }
+ destFs.setPermission(destPath, new FsPermission(APP_FILE_PERMISSION))
} else {
logInfo(s"Source and destination file systems are the same. Not copying $srcPath")
}
@@ -239,23 +236,22 @@ private[spark] class Client(
/**
* Copy the given main resource to the distributed cache if the scheme is not "local".
* Otherwise, set the corresponding key in our SparkConf to handle it downstream.
- * Each resource is represented by a 4-tuple of:
+ * Each resource is represented by a 3-tuple of:
* (1) destination resource name,
* (2) local path to the resource,
- * (3) Spark property key to set if the scheme is not local, and
- * (4) whether to set permissions for this resource
+ * (3) Spark property key to set if the scheme is not local
*/
List(
- (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR, false),
- (APP_JAR, args.userJar, CONF_SPARK_USER_JAR, true),
- ("log4j.properties", oldLog4jConf.orNull, null, false)
- ).foreach { case (destName, _localPath, confKey, setPermissions) =>
+ (SPARK_JAR, sparkJar(sparkConf), CONF_SPARK_JAR),
+ (APP_JAR, args.userJar, CONF_SPARK_USER_JAR),
+ ("log4j.properties", oldLog4jConf.orNull, null)
+ ).foreach { case (destName, _localPath, confKey) =>
val localPath: String = if (_localPath != null) _localPath.trim() else ""
if (!localPath.isEmpty()) {
val localURI = new URI(localPath)
if (localURI.getScheme != LOCAL_SCHEME) {
val src = getQualifiedLocalPath(localURI, hadoopConf)
- val destPath = copyFileToRemote(dst, src, replication, setPermissions)
+ val destPath = copyFileToRemote(dst, src, replication)
val destFs = FileSystem.get(destPath.toUri(), hadoopConf)
distCacheMgr.addResource(destFs, hadoopConf, destPath,
localResources, LocalResourceType.FILE, destName, statCache)
@@ -707,7 +703,7 @@ object Client extends Logging {
* Return the path to the given application's staging directory.
*/
private def getAppStagingDir(appId: ApplicationId): String = {
- SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR
+ buildPath(SPARK_STAGING, appId.toString())
}
/**
@@ -783,7 +779,13 @@ object Client extends Logging {
/**
* Populate the classpath entry in the given environment map.
- * This includes the user jar, Spark jar, and any extra application jars.
+ *
+ * User jars are generally not added to the JVM's system classpath; those are handled by the AM
+ * and executor backend. When the deprecated `spark.yarn.user.classpath.first` is used, user jars
+ * are included in the system classpath, though. The extra class path and other uploaded files are
+ * always made available through the system class path.
+ *
+ * @param args Client arguments (when starting the AM) or null (when starting executors).
*/
private[yarn] def populateClasspath(
args: ClientArguments,
@@ -795,48 +797,38 @@ object Client extends Logging {
addClasspathEntry(
YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env
)
-
- // Normally the users app.jar is last in case conflicts with spark jars
if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) {
- addUserClasspath(args, sparkConf, env)
- addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
- populateHadoopClasspath(conf, env)
- } else {
- addFileToClasspath(sparkJar(sparkConf), SPARK_JAR, env)
- populateHadoopClasspath(conf, env)
- addUserClasspath(args, sparkConf, env)
+ val userClassPath =
+ if (args != null) {
+ getUserClasspath(Option(args.userJar), Option(args.addJars))
+ } else {
+ getUserClasspath(sparkConf)
+ }
+ userClassPath.foreach { x =>
+ addFileToClasspath(x, null, env)
+ }
}
-
- // Append all jar files under the working directory to the classpath.
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + "*", env
- )
+ addFileToClasspath(new URI(sparkJar(sparkConf)), SPARK_JAR, env)
+ populateHadoopClasspath(conf, env)
+ sys.env.get(ENV_DIST_CLASSPATH).foreach(addClasspathEntry(_, env))
}
/**
- * Adds the user jars which have local: URIs (or alternate names, such as APP_JAR) explicitly
- * to the classpath.
+ * Returns a list of URIs representing the user classpath.
+ *
+ * @param conf Spark configuration.
*/
- private def addUserClasspath(
- args: ClientArguments,
- conf: SparkConf,
- env: HashMap[String, String]): Unit = {
-
- // If `args` is not null, we are launching an AM container.
- // Otherwise, we are launching executor containers.
- val (mainJar, secondaryJars) =
- if (args != null) {
- (args.userJar, args.addJars)
- } else {
- (conf.get(CONF_SPARK_USER_JAR, null), conf.get(CONF_SPARK_YARN_SECONDARY_JARS, null))
- }
+ def getUserClasspath(conf: SparkConf): Array[URI] = {
+ getUserClasspath(conf.getOption(CONF_SPARK_USER_JAR),
+ conf.getOption(CONF_SPARK_YARN_SECONDARY_JARS))
+ }
- addFileToClasspath(mainJar, APP_JAR, env)
- if (secondaryJars != null) {
- secondaryJars.split(",").filter(_.nonEmpty).foreach { jar =>
- addFileToClasspath(jar, null, env)
- }
- }
+ private def getUserClasspath(
+ mainJar: Option[String],
+ secondaryJars: Option[String]): Array[URI] = {
+ val mainUri = mainJar.orElse(Some(APP_JAR)).map(new URI(_))
+ val secondaryUris = secondaryJars.map(_.split(",")).toSeq.flatten.map(new URI(_))
+ (mainUri ++ secondaryUris).toArray
}
/**
@@ -847,27 +839,19 @@ object Client extends Logging {
*
* If not a "local:" file and no alternate name, the environment is not modified.
*
- * @param path Path to add to classpath (optional).
+ * @param uri URI to add to classpath (optional).
* @param fileName Alternate name for the file (optional).
* @param env Map holding the environment variables.
*/
private def addFileToClasspath(
- path: String,
+ uri: URI,
fileName: String,
env: HashMap[String, String]): Unit = {
- if (path != null) {
- scala.util.control.Exception.ignoring(classOf[URISyntaxException]) {
- val uri = new URI(path)
- if (uri.getScheme == LOCAL_SCHEME) {
- addClasspathEntry(uri.getPath, env)
- return
- }
- }
- }
- if (fileName != null) {
- addClasspathEntry(
- YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + fileName, env
- )
+ if (uri != null && uri.getScheme == LOCAL_SCHEME) {
+ addClasspathEntry(uri.getPath, env)
+ } else if (fileName != null) {
+ addClasspathEntry(buildPath(
+ YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), fileName), env)
}
}
@@ -963,4 +947,23 @@ object Client extends Logging {
new Path(qualifiedURI)
}
+ /**
+ * Whether to consider jars provided by the user to have precedence over the Spark jars when
+ * loading user classes.
+ */
+ def isUserClassPathFirst(conf: SparkConf, isDriver: Boolean): Boolean = {
+ if (isDriver) {
+ conf.getBoolean("spark.driver.userClassPathFirst", false)
+ } else {
+ conf.getBoolean("spark.executor.userClassPathFirst", false)
+ }
+ }
+
+ /**
+ * Joins all the path components using Path.SEPARATOR.
+ */
+ def buildPath(components: String*): String = {
+ components.mkString(Path.SEPARATOR)
+ }
+
}
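
Client gains getUserClasspath and buildPath helpers used by both the AM and the executor launcher. A usage sketch with made-up jar names; buildPath is reproduced locally so the snippet stands alone:

// Usage sketch for the helpers added to Client; the jar names are made-up examples.
// buildPath simply joins components with Path.SEPARATOR.
import java.net.URI
import org.apache.hadoop.fs.Path

object ClientHelpersExample {
  def buildPath(components: String*): String = components.mkString(Path.SEPARATOR)

  def main(args: Array[String]): Unit = {
    println(buildPath(".sparkStaging", "application_1234_0001"))
    // -> .sparkStaging/application_1234_0001

    // getUserClasspath resolves to the main jar (falling back to the APP_JAR
    // placeholder) plus any comma-separated secondary jars; reproduced on plain strings:
    val mainJar = Some("my-app.jar")
    val secondaryJars = Some("lib1.jar,lib2.jar")
    val userClasspath: Array[URI] =
      (mainJar.map(new URI(_)) ++
        secondaryJars.toSeq.flatMap(_.split(",")).map(new URI(_))).toArray
    userClasspath.foreach(println)
  }
}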
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 7cd8c5f0f9204..c1d3f7320f53c 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy.yarn
+import java.io.File
import java.net.URI
import java.nio.ByteBuffer
@@ -37,7 +38,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
-import org.apache.spark.{SecurityManager, SparkConf, Logging}
+import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
import org.apache.spark.network.util.JavaUtils
class ExecutorRunnable(
@@ -57,7 +58,7 @@ class ExecutorRunnable(
var nmClient: NMClient = _
val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
lazy val env = prepareEnvironment(container)
-
+
def run = {
logInfo("Starting Executor Container")
nmClient = NMClient.createNMClient()
@@ -108,7 +109,13 @@ class ExecutorRunnable(
}
// Send the start request to the ContainerManager
- nmClient.startContainer(container, ctx)
+ try {
+ nmClient.startContainer(container, ctx)
+ } catch {
+ case ex: Exception =>
+ throw new SparkException(s"Exception while starting container ${container.getId}" +
+ s" on host $hostname", ex)
+ }
}
private def prepareCommand(
@@ -185,6 +192,16 @@ class ExecutorRunnable(
// For log4j configuration to reference
javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR)
+ val userClassPath = Client.getUserClasspath(sparkConf).flatMap { uri =>
+ val absPath =
+ if (new File(uri.getPath()).isAbsolute()) {
+ uri.getPath()
+ } else {
+ Client.buildPath(Environment.PWD.$(), uri.getPath())
+ }
+ Seq("--user-class-path", "file:" + absPath)
+ }.toSeq
+
val commands = prefixEnv ++ Seq(
YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java",
"-server",
@@ -196,11 +213,13 @@ class ExecutorRunnable(
"-XX:OnOutOfMemoryError='kill %p'") ++
javaOpts ++
Seq("org.apache.spark.executor.CoarseGrainedExecutorBackend",
- masterAddress.toString,
- slaveId.toString,
- hostname.toString,
- executorCores.toString,
- appId,
+ "--driver-url", masterAddress.toString,
+ "--executor-id", slaveId.toString,
+ "--hostname", hostname.toString,
+ "--cores", executorCores.toString,
+ "--app-id", appId) ++
+ userClassPath ++
+ Seq(
"1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout",
"2>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr")
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index f1b5aafac4066..8abdc26b43806 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -79,18 +79,12 @@ private[spark] class YarnClientSchedulerBackend(
)
// Warn against the following deprecated environment variables: env var -> suggestion
val deprecatedEnvVars = Map(
- "SPARK_MASTER_MEMORY" -> "SPARK_DRIVER_MEMORY or --driver-memory through spark-submit",
"SPARK_WORKER_INSTANCES" -> "SPARK_WORKER_INSTANCES or --num-executors through spark-submit",
"SPARK_WORKER_MEMORY" -> "SPARK_EXECUTOR_MEMORY or --executor-memory through spark-submit",
"SPARK_WORKER_CORES" -> "SPARK_EXECUTOR_CORES or --executor-cores through spark-submit")
- // Do the same for deprecated properties: property -> suggestion
- val deprecatedProps = Map("spark.master.memory" -> "--driver-memory through spark-submit")
optionTuples.foreach { case (optionName, envVar, sparkProp) =>
if (sc.getConf.contains(sparkProp)) {
extraArgs += (optionName, sc.getConf.get(sparkProp))
- if (deprecatedProps.contains(sparkProp)) {
- logWarning(s"NOTE: $sparkProp is deprecated. Use ${deprecatedProps(sparkProp)} instead.")
- }
} else if (System.getenv(envVar) != null) {
extraArgs += (optionName, System.getenv(envVar))
if (deprecatedEnvVars.contains(envVar)) {
diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties
index 287c8e3563503..aab41fa49430f 100644
--- a/yarn/src/test/resources/log4j.properties
+++ b/yarn/src/test/resources/log4j.properties
@@ -16,7 +16,7 @@
#
# Set everything to be logged to the file target/unit-tests.log
-log4j.rootCategory=INFO, file
+log4j.rootCategory=DEBUG, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=true
log4j.appender.file.file=target/unit-tests.log
@@ -25,4 +25,4 @@ log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{
# Ignore messages below warning level from Jetty, because it's a bit verbose
log4j.logger.org.eclipse.jetty=WARN
-org.eclipse.jetty.LEVEL=WARN
+log4j.logger.org.apache.hadoop=WARN
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index 2bb3dcffd61d9..92f04b4b859b3 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -28,8 +28,7 @@ import org.apache.hadoop.yarn.api.records._
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.mockito.Matchers._
import org.mockito.Mockito._
-import org.scalatest.FunSuite
-import org.scalatest.Matchers
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
import scala.collection.JavaConversions._
import scala.collection.mutable.{ HashMap => MutableHashMap }
@@ -39,7 +38,15 @@ import scala.util.Try
import org.apache.spark.{SparkException, SparkConf}
import org.apache.spark.util.Utils
-class ClientSuite extends FunSuite with Matchers {
+class ClientSuite extends FunSuite with Matchers with BeforeAndAfterAll {
+
+ override def beforeAll(): Unit = {
+ System.setProperty("SPARK_YARN_MODE", "true")
+ }
+
+ override def afterAll(): Unit = {
+ System.clearProperty("SPARK_YARN_MODE")
+ }
test("default Yarn application classpath") {
Client.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
@@ -82,6 +89,7 @@ class ClientSuite extends FunSuite with Matchers {
test("Local jar URIs") {
val conf = new Configuration()
val sparkConf = new SparkConf().set(Client.CONF_SPARK_JAR, SPARK)
+ .set("spark.yarn.user.classpath.first", "true")
val env = new MutableHashMap[String, String]()
val args = new ClientArguments(Array("--jar", USER, "--addJars", ADDED), sparkConf)
@@ -98,13 +106,10 @@ class ClientSuite extends FunSuite with Matchers {
})
if (classOf[Environment].getMethods().exists(_.getName == "$$")) {
cp should contain("{{PWD}}")
- cp should contain(s"{{PWD}}${Path.SEPARATOR}*")
} else if (Utils.isWindows) {
cp should contain("%PWD%")
- cp should contain(s"%PWD%${Path.SEPARATOR}*")
} else {
cp should contain(Environment.PWD.$())
- cp should contain(s"${Environment.PWD.$()}${File.separator}*")
}
cp should not contain (Client.SPARK_JAR)
cp should not contain (Client.APP_JAR)
@@ -117,7 +122,7 @@ class ClientSuite extends FunSuite with Matchers {
val client = spy(new Client(args, conf, sparkConf))
doReturn(new Path("/")).when(client).copyFileToRemote(any(classOf[Path]),
- any(classOf[Path]), anyShort(), anyBoolean())
+ any(classOf[Path]), anyShort())
val tempDir = Utils.createTempDir()
try {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index e39de82740b1d..0e37276ba724b 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -17,27 +17,34 @@
package org.apache.spark.deploy.yarn
-import java.io.File
+import java.io.{File, FileOutputStream, OutputStreamWriter}
+import java.util.Properties
import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
import scala.collection.mutable
-import com.google.common.base.Charsets
+import com.google.common.base.Charsets.UTF_8
+import com.google.common.io.ByteStreams
import com.google.common.io.Files
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.server.MiniYARNCluster
import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
-import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException}
+import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException, TestUtils}
import org.apache.spark.scheduler.cluster.ExecutorInfo
import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded}
import org.apache.spark.util.Utils
+/**
+ * Integration tests for YARN; these tests use a mini YARN cluster to run Spark-on-YARN
+ * applications, and require the Spark assembly to be built before they can be successfully
+ * run.
+ */
class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers with Logging {
- // log4j configuration for the Yarn containers, so that their output is collected
- // by Yarn instead of trying to overwrite unit-tests.log.
+ // log4j configuration for the YARN containers, so that their output is collected
+ // by YARN instead of trying to overwrite unit-tests.log.
private val LOG4J_CONF = """
|log4j.rootCategory=DEBUG, console
|log4j.appender.console=org.apache.log4j.ConsoleAppender
@@ -52,13 +59,11 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
|
|from pyspark import SparkConf , SparkContext
|if __name__ == "__main__":
- | if len(sys.argv) != 3:
- | print >> sys.stderr, "Usage: test.py [master] [result file]"
+ | if len(sys.argv) != 2:
+ | print >> sys.stderr, "Usage: test.py [result file]"
| exit(-1)
- | conf = SparkConf()
- | conf.setMaster(sys.argv[1]).setAppName("python test in yarn cluster mode")
- | sc = SparkContext(conf=conf)
- | status = open(sys.argv[2],'w')
+ | sc = SparkContext(conf=SparkConf())
+ | status = open(sys.argv[1],'w')
| result = "failure"
| rdd = sc.parallelize(range(10))
| cnt = rdd.count()
@@ -72,23 +77,17 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
private var yarnCluster: MiniYARNCluster = _
private var tempDir: File = _
private var fakeSparkJar: File = _
- private var oldConf: Map[String, String] = _
+ private var logConfDir: File = _
override def beforeAll() {
super.beforeAll()
tempDir = Utils.createTempDir()
-
- val logConfDir = new File(tempDir, "log4j")
+ logConfDir = new File(tempDir, "log4j")
logConfDir.mkdir()
val logConfFile = new File(logConfDir, "log4j.properties")
- Files.write(LOG4J_CONF, logConfFile, Charsets.UTF_8)
-
- val childClasspath = logConfDir.getAbsolutePath() + File.pathSeparator +
- sys.props("java.class.path")
-
- oldConf = sys.props.filter { case (k, v) => k.startsWith("spark.") }.toMap
+ Files.write(LOG4J_CONF, logConfFile, UTF_8)
yarnCluster = new MiniYARNCluster(getClass().getName(), 1, 1, 1)
yarnCluster.init(new YarnConfiguration())
@@ -119,99 +118,165 @@ class YarnClusterSuite extends FunSuite with BeforeAndAfterAll with Matchers wit
}
logInfo(s"RM address in configuration is ${config.get(YarnConfiguration.RM_ADDRESS)}")
- config.foreach { e =>
- sys.props += ("spark.hadoop." + e.getKey() -> e.getValue())
- }
fakeSparkJar = File.createTempFile("sparkJar", null, tempDir)
- val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!"))
- sys.props += ("spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
- sys.props += ("spark.executorEnv.SPARK_HOME" -> sparkHome)
- sys.props += ("spark.yarn.jar" -> ("local:" + fakeSparkJar.getAbsolutePath()))
- sys.props += ("spark.executor.instances" -> "1")
- sys.props += ("spark.driver.extraClassPath" -> childClasspath)
- sys.props += ("spark.executor.extraClassPath" -> childClasspath)
- sys.props += ("spark.executor.extraJavaOptions" -> "-Dfoo=\"one two three\"")
- sys.props += ("spark.driver.extraJavaOptions" -> "-Dfoo=\"one two three\"")
}
override def afterAll() {
yarnCluster.stop()
- sys.props.retain { case (k, v) => !k.startsWith("spark.") }
- sys.props ++= oldConf
super.afterAll()
}
test("run Spark in yarn-client mode") {
- var result = File.createTempFile("result", null, tempDir)
- YarnClusterDriver.main(Array("yarn-client", result.getAbsolutePath()))
- checkResult(result)
-
- // verify log urls are present
- YarnClusterDriver.listener.addedExecutorInfos.values.foreach { info =>
- assert(info.logUrlMap.nonEmpty)
- }
+ testBasicYarnApp(true)
}
test("run Spark in yarn-cluster mode") {
- val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
- var result = File.createTempFile("result", null, tempDir)
-
- val args = Array("--class", main,
- "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--arg", result.getAbsolutePath(),
- "--num-executors", "1")
- Client.main(args)
- checkResult(result)
-
- // verify log urls are present.
- YarnClusterDriver.listener.addedExecutorInfos.values.foreach { info =>
- assert(info.logUrlMap.nonEmpty)
- }
+ testBasicYarnApp(false)
}
test("run Spark in yarn-cluster mode unsuccessfully") {
- val main = YarnClusterDriver.getClass.getName().stripSuffix("$")
-
- // Use only one argument so the driver will fail
- val args = Array("--class", main,
- "--jar", "file:" + fakeSparkJar.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--num-executors", "1")
+ // Don't provide arguments so the driver will fail.
val exception = intercept[SparkException] {
- Client.main(args)
+ runSpark(false, mainClassName(YarnClusterDriver.getClass))
+ fail("Spark application should have failed.")
}
- assert(Utils.exceptionString(exception).contains("Application finished with failed status"))
}
test("run Python application in yarn-cluster mode") {
val primaryPyFile = new File(tempDir, "test.py")
- Files.write(TEST_PYFILE, primaryPyFile, Charsets.UTF_8)
+ Files.write(TEST_PYFILE, primaryPyFile, UTF_8)
val pyFile = new File(tempDir, "test2.py")
- Files.write(TEST_PYFILE, pyFile, Charsets.UTF_8)
+ Files.write(TEST_PYFILE, pyFile, UTF_8)
var result = File.createTempFile("result", null, tempDir)
- val args = Array("--class", "org.apache.spark.deploy.PythonRunner",
- "--primary-py-file", primaryPyFile.getAbsolutePath(),
- "--py-files", pyFile.getAbsolutePath(),
- "--arg", "yarn-cluster",
- "--arg", result.getAbsolutePath(),
- "--name", "python test in yarn-cluster mode",
- "--num-executors", "1")
- Client.main(args)
+ // The sbt assembly does not include pyspark / py4j python dependencies, so we need to
+ // propagate SPARK_HOME so that those are added to PYTHONPATH. See PythonUtils.scala.
+ val sparkHome = sys.props("spark.test.home")
+ val extraConf = Map(
+ "spark.executorEnv.SPARK_HOME" -> sparkHome,
+ "spark.yarn.appMasterEnv.SPARK_HOME" -> sparkHome)
+
+ runSpark(false, primaryPyFile.getAbsolutePath(),
+ sparkArgs = Seq("--py-files", pyFile.getAbsolutePath()),
+ appArgs = Seq(result.getAbsolutePath()),
+ extraConf = extraConf)
checkResult(result)
}
+ test("user class path first in client mode") {
+ testUseClassPathFirst(true)
+ }
+
+ test("user class path first in cluster mode") {
+ testUseClassPathFirst(false)
+ }
+
+ private def testBasicYarnApp(clientMode: Boolean): Unit = {
+ var result = File.createTempFile("result", null, tempDir)
+ runSpark(clientMode, mainClassName(YarnClusterDriver.getClass),
+ appArgs = Seq(result.getAbsolutePath()))
+ checkResult(result)
+ }
+
+ private def testUseClassPathFirst(clientMode: Boolean): Unit = {
+ // Create a jar file that contains a different version of "test.resource".
+ val originalJar = TestUtils.createJarWithFiles(Map("test.resource" -> "ORIGINAL"), tempDir)
+ val userJar = TestUtils.createJarWithFiles(Map("test.resource" -> "OVERRIDDEN"), tempDir)
+ val driverResult = File.createTempFile("driver", null, tempDir)
+ val executorResult = File.createTempFile("executor", null, tempDir)
+ runSpark(clientMode, mainClassName(YarnClasspathTest.getClass),
+ appArgs = Seq(driverResult.getAbsolutePath(), executorResult.getAbsolutePath()),
+ extraClassPath = Seq(originalJar.getPath()),
+ extraJars = Seq("local:" + userJar.getPath()),
+ extraConf = Map(
+ "spark.driver.userClassPathFirst" -> "true",
+ "spark.executor.userClassPathFirst" -> "true"))
+ checkResult(driverResult, "OVERRIDDEN")
+ checkResult(executorResult, "OVERRIDDEN")
+ }
+
+ private def runSpark(
+ clientMode: Boolean,
+ klass: String,
+ appArgs: Seq[String] = Nil,
+ sparkArgs: Seq[String] = Nil,
+ extraClassPath: Seq[String] = Nil,
+ extraJars: Seq[String] = Nil,
+ extraConf: Map[String, String] = Map()): Unit = {
+ val master = if (clientMode) "yarn-client" else "yarn-cluster"
+ val props = new Properties()
+
+ props.setProperty("spark.yarn.jar", "local:" + fakeSparkJar.getAbsolutePath())
+
+ val childClasspath = logConfDir.getAbsolutePath() +
+ File.pathSeparator +
+ sys.props("java.class.path") +
+ File.pathSeparator +
+ extraClassPath.mkString(File.pathSeparator)
+ props.setProperty("spark.driver.extraClassPath", childClasspath)
+ props.setProperty("spark.executor.extraClassPath", childClasspath)
+
+ // SPARK-4267: make sure java options are propagated correctly.
+ props.setProperty("spark.driver.extraJavaOptions", "-Dfoo=\"one two three\"")
+ props.setProperty("spark.executor.extraJavaOptions", "-Dfoo=\"one two three\"")
+
+ yarnCluster.getConfig().foreach { e =>
+ props.setProperty("spark.hadoop." + e.getKey(), e.getValue())
+ }
+
+ sys.props.foreach { case (k, v) =>
+ if (k.startsWith("spark.")) {
+ props.setProperty(k, v)
+ }
+ }
+
+ extraConf.foreach { case (k, v) => props.setProperty(k, v) }
+
+ val propsFile = File.createTempFile("spark", ".properties", tempDir)
+ val writer = new OutputStreamWriter(new FileOutputStream(propsFile), UTF_8)
+ props.store(writer, "Spark properties.")
+ writer.close()
+
+ val extraJarArgs = if (!extraJars.isEmpty()) Seq("--jars", extraJars.mkString(",")) else Nil
+ val mainArgs =
+ if (klass.endsWith(".py")) {
+ Seq(klass)
+ } else {
+ Seq("--class", klass, fakeSparkJar.getAbsolutePath())
+ }
+ val argv =
+ Seq(
+ new File(sys.props("spark.test.home"), "bin/spark-submit").getAbsolutePath(),
+ "--master", master,
+ "--num-executors", "1",
+ "--properties-file", propsFile.getAbsolutePath()) ++
+ extraJarArgs ++
+ sparkArgs ++
+ mainArgs ++
+ appArgs
+
+ Utils.executeAndGetOutput(argv,
+ extraEnvironment = Map("YARN_CONF_DIR" -> tempDir.getAbsolutePath()))
+ }
+
/**
* This is a workaround for an issue with yarn-cluster mode: the Client class does not report
* any error when the job process exits successfully even though the job itself has failed. So
* the tests require that something be written to a file once everything is OK, to indicate
* that the job succeeded.
*/
- private def checkResult(result: File) = {
- var resultString = Files.toString(result, Charsets.UTF_8)
- resultString should be ("success")
+ private def checkResult(result: File): Unit = {
+ checkResult(result, "success")
+ }
+
+ private def checkResult(result: File, expected: String): Unit = {
+ var resultString = Files.toString(result, UTF_8)
+ resultString should be (expected)
+ }
+
+ private def mainClassName(klass: Class[_]): String = {
+ klass.getName().stripSuffix("$")
}
}
@@ -229,22 +294,22 @@ private object YarnClusterDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
var listener: SaveExecutorInfo = null
- def main(args: Array[String]) = {
- if (args.length != 2) {
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
System.err.println(
s"""
|Invalid command line: ${args.mkString(" ")}
|
- |Usage: YarnClusterDriver [master] [result file]
+ |Usage: YarnClusterDriver [result file]
""".stripMargin)
System.exit(1)
}
listener = new SaveExecutorInfo
- val sc = new SparkContext(new SparkConf().setMaster(args(0))
+ val sc = new SparkContext(new SparkConf()
.setAppName("yarn \"test app\" 'with quotes' and \\back\\slashes and $dollarSigns"))
sc.addSparkListener(listener)
- val status = new File(args(1))
+ val status = new File(args(0))
var result = "failure"
try {
val data = sc.parallelize(1 to 4, 4).collect().toSet
@@ -253,7 +318,48 @@ private object YarnClusterDriver extends Logging with Matchers {
result = "success"
} finally {
sc.stop()
- Files.write(result, status, Charsets.UTF_8)
+ Files.write(result, status, UTF_8)
+ }
+
+ // verify log urls are present
+ listener.addedExecutorInfos.values.foreach { info =>
+ assert(info.logUrlMap.nonEmpty)
+ }
+ }
+
+}
+
+private object YarnClasspathTest {
+
+ def main(args: Array[String]): Unit = {
+ if (args.length != 2) {
+ System.err.println(
+ s"""
+ |Invalid command line: ${args.mkString(" ")}
+ |
+ |Usage: YarnClasspathTest [driver result file] [executor result file]
+ """.stripMargin)
+ System.exit(1)
+ }
+
+ readResource(args(0))
+ val sc = new SparkContext(new SparkConf())
+ try {
+ sc.parallelize(Seq(1)).foreach { x => readResource(args(1)) }
+ } finally {
+ sc.stop()
+ }
+ }
+
+ private def readResource(resultPath: String): Unit = {
+ var result = "failure"
+ try {
+ val ccl = Thread.currentThread().getContextClassLoader()
+ val resource = ccl.getResourceAsStream("test.resource")
+ val bytes = ByteStreams.toByteArray(resource)
+ result = new String(bytes, 0, bytes.length, UTF_8)
+ } finally {
+ Files.write(result, new File(resultPath), UTF_8)
}
}
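
The classpath-first integration test ships two jars with different versions of test.resource and checks which copy the driver and executors actually see. A standalone sketch of that resource lookup (readResource above is the real helper; this version merely shows the context-class-loader check and assumes Guava is on the classpath):

// Standalone sketch of the resource check the classpath-first test performs: load
// "test.resource" through the context class loader and report its contents. A missing
// resource is reported as "failure".
import com.google.common.io.ByteStreams
import java.nio.charset.StandardCharsets.UTF_8

object ResourceCheckSketch {
  def readTestResource(): String = {
    val ccl = Thread.currentThread().getContextClassLoader
    Option(ccl.getResourceAsStream("test.resource")) match {
      case Some(stream) =>
        try new String(ByteStreams.toByteArray(stream), UTF_8)
        finally stream.close()
      case None => "failure"
    }
  }

  def main(args: Array[String]): Unit = {
    println(readTestResource()) // "OVERRIDDEN" when the user jar takes precedence
  }
}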