diff --git a/assembly/pom.xml b/assembly/pom.xml index 594fa0c779e1b..1bb5a671f5390 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -43,12 +43,6 @@ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -133,22 +127,6 @@ shade - - - com.google - org.spark-project.guava - - com.google.common.** - - - com/google/common/base/Absent* - com/google/common/base/Function - com/google/common/base/Optional* - com/google/common/base/Present* - com/google/common/base/Supplier - - - diff --git a/build/mvn b/build/mvn index f91e2b4bdcc02..a87c5a26230c8 100755 --- a/build/mvn +++ b/build/mvn @@ -48,11 +48,11 @@ install_app() { # check if we already have the tarball # check if we have curl installed # download application - [ ! -f "${local_tarball}" ] && [ -n "`which curl 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v curl) ] && \ echo "exec: curl ${curl_opts} ${remote_tarball}" && \ curl ${curl_opts} "${remote_tarball}" > "${local_tarball}" # if the file still doesn't exist, lets try `wget` and cross our fingers - [ ! -f "${local_tarball}" ] && [ -n "`which wget 2>/dev/null`" ] && \ + [ ! -f "${local_tarball}" ] && [ $(command -v wget) ] && \ echo "exec: wget ${wget_opts} ${remote_tarball}" && \ wget ${wget_opts} -O "${local_tarball}" "${remote_tarball}" # if both were unsuccessful, exit diff --git a/build/sbt-launch-lib.bash b/build/sbt-launch-lib.bash index f5df439effb01..5e0c640fa5919 100755 --- a/build/sbt-launch-lib.bash +++ b/build/sbt-launch-lib.bash @@ -50,9 +50,9 @@ acquire_sbt_jar () { # Download printf "Attempting to fetch sbt\n" JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then + if [ $(command -v curl) ]; then (curl --silent ${URL1} > "${JAR_DL}" || curl --silent ${URL2} > "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" - elif hash wget 2>/dev/null; then + elif [ $(command -v wget) ]; then (wget --quiet ${URL1} -O "${JAR_DL}" || wget --quiet ${URL2} -O "${JAR_DL}") && mv "${JAR_DL}" "${JAR}" else printf "You do not have curl or wget installed, please install sbt manually from http://www.scala-sbt.org/\n" diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index 96b6844f0aabb..464c14457e53f 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -87,6 +87,7 @@ # period 10 Poll period # unit seconds Units of poll period # prefix EMPTY STRING Prefix to prepend to metric name +# protocol tcp Protocol ("tcp" or "udp") to use ## Examples # Enable JmxSink for all instances by class name diff --git a/core/pom.xml b/core/pom.xml index 1984682b9c099..6fce10a0aea4c 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -34,6 +34,10 @@ Spark Project Core http://spark.apache.org/ + + com.google.guava + guava + com.twitter chill_${scala.binary.version} @@ -106,16 +110,6 @@ org.eclipse.jetty jetty-server - - - com.google.guava - guava - compile - org.apache.commons commons-lang3 @@ -204,19 +198,19 @@ stream - com.codahale.metrics + io.dropwizard.metrics metrics-core - com.codahale.metrics + io.dropwizard.metrics metrics-jvm - com.codahale.metrics + io.dropwizard.metrics metrics-json - com.codahale.metrics + io.dropwizard.metrics metrics-graphite @@ -350,44 +344,6 @@ true - - org.apache.maven.plugins - maven-shade-plugin - - - package - - shade - - - false - - - com.google.guava:guava - - - - - - com.google.guava:guava - - com/google/common/base/Absent* - com/google/common/base/Function - com/google/common/base/Optional* - com/google/common/base/Present* - com/google/common/base/Supplier - - - - - - - - org.apache.maven.plugins maven-dependency-plugin diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index c99a61f63ea2b..89eec7d4b7f61 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -10,4 +10,3 @@ log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.eclipse.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO -log4j.logger.org.apache.hadoop.yarn.util.RackResolver=WARN diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4c4ee04cc515e..3c61c10820ba9 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1986,7 +1986,7 @@ object SparkContext extends Logging { case "yarn-client" => val scheduler = try { val clazz = - Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + Class.forName("org.apache.spark.scheduler.cluster.YarnScheduler") val cons = clazz.getConstructor(classOf[SparkContext]) cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 62bf18d82d9b0..0f91c942ecd50 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -348,6 +348,19 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f) + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#reduce]] + */ + def treeReduce(f: JFunction2[T, T, T], depth: Int): T = rdd.treeReduce(f, depth) + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeReduce]] with suggested depth 2. + */ + def treeReduce(f: JFunction2[T, T, T]): T = treeReduce(f, 2) + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -369,6 +382,30 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { combOp: JFunction2[U, U, U]): U = rdd.aggregate(zeroValue)(seqOp, combOp)(fakeClassTag[U]) + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree + * @see [[org.apache.spark.api.java.JavaRDDLike#aggregate]] + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U], + depth: Int): U = { + rdd.treeAggregate(zeroValue)(seqOp, combOp, depth)(fakeClassTag[U]) + } + + /** + * [[org.apache.spark.api.java.JavaRDDLike#treeAggregate]] with suggested depth 2. + */ + def treeAggregate[U]( + zeroValue: U, + seqOp: JFunction2[U, T, U], + combOp: JFunction2[U, U, U]): U = { + treeAggregate(zeroValue, seqOp, combOp, 2) + } + /** * Return the number of elements in the RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 5ba66178e2b78..c9181a29d4756 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -138,6 +138,11 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable + case array: Array[Any] => { + val arrayWriteable = new ArrayWritable(classOf[Writable]) + arrayWriteable.set(array.map(convertToWritable(_))) + arrayWriteable + } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4ac666c54fbcd..119e0459c5d1b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -67,17 +67,16 @@ private[spark] class PythonRDD( envVars += ("SPARK_REUSE_WORKER" -> "1") } val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap) + // Whether is the worker released into idle pool + @volatile var released = false // Start a thread to feed the process input from our parent's iterator val writerThread = new WriterThread(env, worker, split, context) - var complete_cleanly = false context.addTaskCompletionListener { context => writerThread.shutdownOnTaskCompletion() writerThread.join() - if (reuse_worker && complete_cleanly) { - env.releasePythonWorker(pythonExec, envVars.toMap, worker) - } else { + if (!reuse_worker || !released) { try { worker.close() } catch { @@ -145,8 +144,12 @@ private[spark] class PythonRDD( stream.readFully(update) accumulator += Collections.singletonList(update) } + // Check whether the worker is ready to be re-used. if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - complete_cleanly = true + if (reuse_worker) { + env.releasePythonWorker(pythonExec, envVars.toMap, worker) + released = true + } } null } diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index a4153aaa926f8..fb52a960e0765 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -153,7 +153,10 @@ private[spark] object SerDeUtil extends Logging { iter.flatMap { row => val obj = unpickle.loads(row) if (batched) { - obj.asInstanceOf[JArrayList[_]].asScala + obj match { + case array: Array[Any] => array.toSeq + case _ => obj.asInstanceOf[JArrayList[_]].asScala + } } else { Seq(obj) } @@ -199,7 +202,10 @@ private[spark] object SerDeUtil extends Logging { * representation is serialized */ def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = { - val (keyFailed, valueFailed) = checkPickle(rdd.first()) + val (keyFailed, valueFailed) = rdd.take(1) match { + case Array() => (false, false) + case Array(first) => checkPickle(first) + } rdd.mapPartitions { iter => val cleaned = iter.map { case (k, v) => @@ -226,10 +232,12 @@ private[spark] object SerDeUtil extends Logging { } val rdd = pythonToJava(pyRDD, batched).rdd - rdd.first match { - case obj if isPair(obj) => + rdd.take(1) match { + case Array(obj) if isPair(obj) => // we only accept (K, V) - case other => throw new SparkException( + case Array() => + // we also accept empty collections + case Array(other) => throw new SparkException( s"RDD element of type ${other.getClass.getName} cannot be used") } rdd.map { obj => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 57f9faf5ddd1d..211e3ede53d9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -133,10 +133,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesReadOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead") val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesRead = f() @@ -156,10 +155,9 @@ class SparkHadoopUtil extends Logging { * statistics are only available as of Hadoop 2.5 (see HADOOP-10688). * Returns None if the required method can't be found. */ - private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration) - : Option[() => Long] = { + private[spark] def getFSBytesWrittenOnThreadCallback(): Option[() => Long] = { try { - val threadStats = getFileSystemThreadStatistics(path, conf) + val threadStats = getFileSystemThreadStatistics() val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten") val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum val baselineBytesWritten = f() @@ -172,10 +170,8 @@ class SparkHadoopUtil extends Logging { } } - private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = { - val qualifiedPath = path.getFileSystem(conf).makeQualified(path) - val scheme = qualifiedPath.toUri().getScheme() - val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme)) + private def getFileSystemThreadStatistics(): Seq[AnyRef] = { + val stats = FileSystem.getAllStatistics() stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics")) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 050ba91eb2bc3..c240bcd705d93 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -144,6 +144,8 @@ object SparkSubmit { printErrorAndExit("Cluster deploy mode is not applicable to Spark shells.") case (_, CLUSTER) if isSqlShell(args.mainClass) => printErrorAndExit("Cluster deploy mode is not applicable to Spark SQL shell.") + case (_, CLUSTER) if isThriftServer(args.mainClass) => + printErrorAndExit("Cluster deploy mode is not applicable to Spark Thrift server.") case _ => } @@ -408,6 +410,13 @@ object SparkSubmit { mainClass == "org.apache.spark.sql.hive.thriftserver.SparkSQLCLIDriver" } + /** + * Return whether the given main class represents a thrift server. + */ + private[spark] def isThriftServer(mainClass: String): Boolean = { + mainClass == "org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" + } + /** * Return whether the given primary resource requires running python. */ diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index d8c2e41a7c715..312bb3a1daaa3 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -76,7 +76,6 @@ private[spark] class Executor( } val executorSource = new ExecutorSource(this, executorId) - conf.set("spark.executor.id", executorId) if (!isLocal) { env.metricsSystem.registerSource(executorSource) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index ddb5903bf6875..97912c68c5982 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -19,7 +19,6 @@ package org.apache.spark.executor import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.DataReadMethod import org.apache.spark.executor.DataReadMethod.DataReadMethod import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index d7b5f5c40efae..2d25ebd66159f 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem @@ -38,6 +38,7 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric val GRAPHITE_KEY_PERIOD = "period" val GRAPHITE_KEY_UNIT = "unit" val GRAPHITE_KEY_PREFIX = "prefix" + val GRAPHITE_KEY_PROTOCOL = "protocol" def propertyToOption(prop: String): Option[String] = Option(property.getProperty(prop)) @@ -66,7 +67,11 @@ private[spark] class GraphiteSink(val property: Properties, val registry: Metric MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) - val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + val graphite = propertyToOption(GRAPHITE_KEY_PROTOCOL).map(_.toLowerCase) match { + case Some("udp") => new GraphiteUDP(new InetSocketAddress(host, port)) + case Some("tcp") | None => new Graphite(new InetSocketAddress(host, port)) + case Some(p) => throw new Exception(s"Invalid Graphite protocol: $p") + } val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) .convertDurationsTo(TimeUnit.MILLISECONDS) diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala index 03c4137ca0a81..ee22c6656e69e 100644 --- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala @@ -184,14 +184,16 @@ private[nio] class ConnectionManager( // to be able to track asynchronous messages private val idCount: AtomicInteger = new AtomicInteger(1) + private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() + private val selectorThread = new Thread("connection-manager-thread") { override def run() = ConnectionManager.this.run() } selectorThread.setDaemon(true) + // start this thread last, since it invokes run(), which accesses members above selectorThread.start() - private val writeRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() - private def triggerWrite(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) if (conn == null) return @@ -232,7 +234,6 @@ private[nio] class ConnectionManager( } ) } - private val readRunnableStarted: HashSet[SelectionKey] = new HashSet[SelectionKey]() private def triggerRead(key: SelectionKey) { val conn = connectionsByKey.getOrElse(key, null) diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 056aef0bc210a..c3e3931042de2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,6 +35,7 @@ import org.apache.hadoop.mapred.Reporter import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID +import org.apache.hadoop.mapred.lib.CombineFileSplit import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -218,13 +219,13 @@ class HadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.inputSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, jobConf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) var reader: RecordReader[K, V] = null @@ -254,7 +255,8 @@ class HadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.inputSplit.value.isInstanceOf[FileSplit]) { + } else if (split.inputSplit.value.isInstanceOf[FileSplit] || + split.inputSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 7b0e3c87ccff4..d86f95ac3e485 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.FileSplit +import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.input.WholeTextFileInputFormat @@ -34,7 +34,7 @@ import org.apache.spark.Logging import org.apache.spark.Partition import org.apache.spark.SerializableWritable import org.apache.spark.{SparkContext, TaskContext} -import org.apache.spark.executor.{DataReadMethod, InputMetrics} +import org.apache.spark.executor.DataReadMethod import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.util.Utils @@ -114,13 +114,13 @@ class NewHadoopRDD[K, V]( // Find a function that will return the FileSystem bytes read by this thread. Do this before // creating RecordReader, because RecordReader's constructor might read some bytes - val bytesReadCallback = inputMetrics.bytesReadCallback.orElse( + val bytesReadCallback = inputMetrics.bytesReadCallback.orElse { split.serializableHadoopSplit.value match { - case split: FileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(split.getPath, conf) + case _: FileSplit | _: CombineFileSplit => + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() case _ => None } - ) + } inputMetrics.setBytesReadCallback(bytesReadCallback) val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) @@ -163,7 +163,8 @@ class NewHadoopRDD[K, V]( reader.close() if (bytesReadCallback.isDefined) { inputMetrics.updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) { + } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || + split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { // If we can't get the bytes read from the FS stats, fall back to the split size, // which may be inaccurate. try { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 0f37d830ef34f..49b88a90ab5af 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -990,7 +990,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val committer = format.getOutputCommitter(hadoopContext) committer.setupTask(hadoopContext) - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]] try { @@ -1061,7 +1061,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) // around by taking a mod. We expect that no task will be attempted 2 billion times. val taskAttemptId = (context.taskAttemptId % Int.MaxValue).toInt - val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config) + val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context) writer.setup(context.stageId, context.partitionId, taskAttemptId) writer.open() @@ -1086,11 +1086,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) writer.commitJob() } - private def initHadoopOutputMetrics(context: TaskContext, config: Configuration) - : (OutputMetrics, Option[() => Long]) = { - val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir")) - .map(new Path(_)) - .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config)) + private def initHadoopOutputMetrics(context: TaskContext): (OutputMetrics, Option[() => Long]) = { + val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop) if (bytesWrittenCallback.isDefined) { context.taskMetrics.outputMetrics = Some(outputMetrics) diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index ab7410a1f7f99..97aee58bddbf1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -604,8 +604,8 @@ abstract class RDD[T: ClassTag]( * print line function (like out.println()) as the 2nd parameter. * An example of pipe the RDD data of groupBy() in a streaming way, * instead of constructing a huge String to concat all the elements: - * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = - * for (e <- record._2){f(e)} + * def printRDDElement(record:(String, Seq[String]), f:String=>Unit) = + * for (e <- record._2){f(e)} * @param separateWorkingDir Use separate working directories for each task. * @return the result RDD */ @@ -841,7 +841,7 @@ abstract class RDD[T: ClassTag]( * Return an RDD with the elements from `this` that are not in `other`. * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting - * RDD will be <= us. + * RDD will be <= us. */ def subtract(other: RDD[T]): RDD[T] = subtract(other, partitioner.getOrElse(new HashPartitioner(partitions.size))) @@ -900,6 +900,38 @@ abstract class RDD[T: ClassTag]( jobResult.getOrElse(throw new UnsupportedOperationException("empty collection")) } + /** + * Reduces the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#reduce]] + */ + def treeReduce(f: (T, T) => T, depth: Int = 2): T = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + val cleanF = context.clean(f) + val reducePartition: Iterator[T] => Option[T] = iter => { + if (iter.hasNext) { + Some(iter.reduceLeft(cleanF)) + } else { + None + } + } + val partiallyReduced = mapPartitions(it => Iterator(reducePartition(it))) + val op: (Option[T], Option[T]) => Option[T] = (c, x) => { + if (c.isDefined && x.isDefined) { + Some(cleanF(c.get, x.get)) + } else if (c.isDefined) { + c + } else if (x.isDefined) { + x + } else { + None + } + } + partiallyReduced.treeAggregate(Option.empty[T])(op, op, depth) + .getOrElse(throw new UnsupportedOperationException("empty collection")) + } + /** * Aggregate the elements of each partition, and then the results for all the partitions, using a * given associative function and a neutral "zero value". The function op(t1, t2) is allowed to @@ -935,6 +967,37 @@ abstract class RDD[T: ClassTag]( jobResult } + /** + * Aggregates the elements of this RDD in a multi-level tree pattern. + * + * @param depth suggested depth of the tree (default: 2) + * @see [[org.apache.spark.rdd.RDD#aggregate]] + */ + def treeAggregate[U: ClassTag](zeroValue: U)( + seqOp: (U, T) => U, + combOp: (U, U) => U, + depth: Int = 2): U = { + require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") + if (partitions.size == 0) { + return Utils.clone(zeroValue, context.env.closureSerializer.newInstance()) + } + val cleanSeqOp = context.clean(seqOp) + val cleanCombOp = context.clean(combOp) + val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) + var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var numPartitions = partiallyAggregated.partitions.size + val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) + // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. + while (numPartitions > scale + numPartitions / scale) { + numPartitions /= scale + val curNumPartitions = numPartitions + partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => + iter.map((i % curNumPartitions, _)) + }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + } + partiallyAggregated.reduce(cleanCombOp) + } + /** * Return the number of elements in the RDD. */ @@ -964,7 +1027,7 @@ abstract class RDD[T: ClassTag]( * * Note that this method should only be used if the resulting map is expected to be small, as * the whole thing is loaded into the driver's memory. - * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which + * To handle very large results, consider using rdd.map(x => (x, 1L)).reduceByKey(_ + _), which * returns an RDD[T, Long] instead of a map. */ def countByValue()(implicit ord: Ordering[T] = null): Map[T, Long] = { @@ -1002,7 +1065,7 @@ abstract class RDD[T: ClassTag]( * Algorithmic Engineering of a State of The Art Cardinality Estimation Algorithm", available * here. * - * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` + * The relative accuracy is approximately `1.054 / sqrt(2^p)`. Setting a nonzero `sp > p` * would trigger sparse representation of registers, which may reduce the memory consumption * and increase accuracy when the cardinality is small. * @@ -1320,7 +1383,7 @@ abstract class RDD[T: ClassTag]( /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. + * Used for internal Java-Scala API compatibility. */ private[spark] def retag(cls: Class[T]): RDD[T] = { val classTag: ClassTag[T] = ClassTag.apply(cls) @@ -1329,7 +1392,7 @@ abstract class RDD[T: ClassTag]( /** * Private API for changing an RDD's ClassTag. - * Used for internal Java <-> Scala API compatibility. + * Used for internal Java-Scala API compatibility. */ private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = { this.mapPartitions(identity, preservesPartitioning = true)(classTag) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 36a6e6338faa6..be23056e7d423 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -17,10 +17,9 @@ package org.apache.spark.scheduler -import java.util.concurrent.{LinkedBlockingQueue, Semaphore} +import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.AsynchronousListenerBus /** * Asynchronously passes SparkListenerEvents to registered SparkListeners. @@ -29,113 +28,19 @@ import org.apache.spark.util.Utils * has started will events be actually propagated to all attached listeners. This listener bus * is stopped when it receives a SparkListenerShutdown event, which is posted using stop(). */ -private[spark] class LiveListenerBus extends SparkListenerBus with Logging { - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[SparkListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - private var started = false - - // A counter that represents the number of events produced and consumed in the queue - private val eventLock = new Semaphore(0) - - private val listenerThread = new Thread("SparkListenerBus") { - setDaemon(true) - override def run(): Unit = Utils.logUncaughtExceptions { - while (true) { - eventLock.acquire() - // Atomically remove and process this event - LiveListenerBus.this.synchronized { - val event = eventQueue.poll - if (event == SparkListenerShutdown) { - // Get out of the while loop and shutdown the daemon thread - return - } - Option(event).foreach(postToAll) - } - } - } - } - - /** - * Start sending events to attached listeners. - * - * This first sends out all buffered events posted before this listener bus has started, then - * listens for any additional events asynchronously while the listener bus is still running. - * This should only be called once. - */ - def start() { - if (started) { - throw new IllegalStateException("Listener bus already started!") +private[spark] class LiveListenerBus + extends AsynchronousListenerBus[SparkListener, SparkListenerEvent]("SparkListenerBus") + with SparkListenerBus { + + private val logDroppedEvent = new AtomicBoolean(false) + + override def onDropEvent(event: SparkListenerEvent): Unit = { + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. + logError("Dropping SparkListenerEvent because no remaining room in event queue. " + + "This likely means one of the SparkListeners is too slow and cannot keep up with " + + "the rate at which tasks are being started by the scheduler.") } - listenerThread.start() - started = true } - def post(event: SparkListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (eventAdded) { - eventLock.release() - } else { - logQueueFullErrorMessage() - } - } - - /** - * For testing only. Wait until there are no more events in the queue, or until the specified - * time has elapsed. Return true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!queueIsEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and - * wait/notify add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - /** - * For testing only. Return whether the listener daemon thread is still alive. - */ - def listenerThreadIsAlive: Boolean = synchronized { listenerThread.isAlive } - - /** - * Return whether the event queue is empty. - * - * The use of synchronized here guarantees that all events that once belonged to this queue - * have already been processed by all attached listeners, if this returns true. - */ - def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty } - - /** - * Log an error message to indicate that the event queue is full. Do this only once. - */ - private def logQueueFullErrorMessage(): Unit = { - if (!queueFullErrorMessageLogged) { - if (listenerThread.isAlive) { - logError("Dropping SparkListenerEvent because no remaining room in event queue. " + - "This likely means one of the SparkListeners is too slow and cannot keep up with" + - "the rate at which tasks are being started by the scheduler.") - } else { - logError("SparkListenerBus thread is dead! This means SparkListenerEvents have not" + - "been (and will no longer be) propagated to listeners for some time.") - } - queueFullErrorMessageLogged = true - } - } - - def stop() { - if (!started) { - throw new IllegalStateException("Attempted to stop a listener bus that has not yet started!") - } - post(SparkListenerShutdown) - listenerThread.join() - } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index e5d1eb767e109..dd28ddb31de1f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -91,11 +91,11 @@ case class SparkListenerBlockManagerRemoved(time: Long, blockManagerId: BlockMan case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorAdded(executorId: String, executorInfo: ExecutorInfo) +case class SparkListenerExecutorAdded(time: Long, executorId: String, executorInfo: ExecutorInfo) extends SparkListenerEvent @DeveloperApi -case class SparkListenerExecutorRemoved(executorId: String) +case class SparkListenerExecutorRemoved(time: Long, executorId: String, reason: String) extends SparkListenerEvent /** @@ -116,9 +116,6 @@ case class SparkListenerApplicationStart(appName: String, appId: Option[String], @DeveloperApi case class SparkListenerApplicationEnd(time: Long) extends SparkListenerEvent -/** An event used in the listener to shutdown the listener daemon thread. */ -private[spark] case object SparkListenerShutdown extends SparkListenerEvent - /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index e700c6af542f4..fe8a19a2c0cb9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -17,78 +17,47 @@ package org.apache.spark.scheduler -import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.ListenerBus /** - * A SparkListenerEvent bus that relays events to its listeners + * A [[SparkListenerEvent]] bus that relays [[SparkListenerEvent]]s to its listeners */ -private[spark] trait SparkListenerBus extends Logging { - - // SparkListeners attached to this event bus - protected val sparkListeners = new ArrayBuffer[SparkListener] - with mutable.SynchronizedBuffer[SparkListener] - - def addListener(listener: SparkListener) { - sparkListeners += listener - } +private[spark] trait SparkListenerBus extends ListenerBus[SparkListener, SparkListenerEvent] { - /** - * Post an event to all attached listeners. - * This does nothing if the event is SparkListenerShutdown. - */ - def postToAll(event: SparkListenerEvent) { + override def onPostEvent(listener: SparkListener, event: SparkListenerEvent): Unit = { event match { case stageSubmitted: SparkListenerStageSubmitted => - foreachListener(_.onStageSubmitted(stageSubmitted)) + listener.onStageSubmitted(stageSubmitted) case stageCompleted: SparkListenerStageCompleted => - foreachListener(_.onStageCompleted(stageCompleted)) + listener.onStageCompleted(stageCompleted) case jobStart: SparkListenerJobStart => - foreachListener(_.onJobStart(jobStart)) + listener.onJobStart(jobStart) case jobEnd: SparkListenerJobEnd => - foreachListener(_.onJobEnd(jobEnd)) + listener.onJobEnd(jobEnd) case taskStart: SparkListenerTaskStart => - foreachListener(_.onTaskStart(taskStart)) + listener.onTaskStart(taskStart) case taskGettingResult: SparkListenerTaskGettingResult => - foreachListener(_.onTaskGettingResult(taskGettingResult)) + listener.onTaskGettingResult(taskGettingResult) case taskEnd: SparkListenerTaskEnd => - foreachListener(_.onTaskEnd(taskEnd)) + listener.onTaskEnd(taskEnd) case environmentUpdate: SparkListenerEnvironmentUpdate => - foreachListener(_.onEnvironmentUpdate(environmentUpdate)) + listener.onEnvironmentUpdate(environmentUpdate) case blockManagerAdded: SparkListenerBlockManagerAdded => - foreachListener(_.onBlockManagerAdded(blockManagerAdded)) + listener.onBlockManagerAdded(blockManagerAdded) case blockManagerRemoved: SparkListenerBlockManagerRemoved => - foreachListener(_.onBlockManagerRemoved(blockManagerRemoved)) + listener.onBlockManagerRemoved(blockManagerRemoved) case unpersistRDD: SparkListenerUnpersistRDD => - foreachListener(_.onUnpersistRDD(unpersistRDD)) + listener.onUnpersistRDD(unpersistRDD) case applicationStart: SparkListenerApplicationStart => - foreachListener(_.onApplicationStart(applicationStart)) + listener.onApplicationStart(applicationStart) case applicationEnd: SparkListenerApplicationEnd => - foreachListener(_.onApplicationEnd(applicationEnd)) + listener.onApplicationEnd(applicationEnd) case metricsUpdate: SparkListenerExecutorMetricsUpdate => - foreachListener(_.onExecutorMetricsUpdate(metricsUpdate)) + listener.onExecutorMetricsUpdate(metricsUpdate) case executorAdded: SparkListenerExecutorAdded => - foreachListener(_.onExecutorAdded(executorAdded)) + listener.onExecutorAdded(executorAdded) case executorRemoved: SparkListenerExecutorRemoved => - foreachListener(_.onExecutorRemoved(executorRemoved)) - case SparkListenerShutdown => - } - } - - /** - * Apply the given function to all attached listeners, catching and logging any exception. - */ - private def foreachListener(f: SparkListener => Unit): Unit = { - sparkListeners.foreach { listener => - try { - f(listener) - } catch { - case e: Exception => - logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) - } + listener.onExecutorRemoved(executorRemoved) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5786d367464f4..103a5c053c289 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -108,7 +108,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste logDebug(s"Decremented number of pending executors ($numPendingExecutors left)") } } - listenerBus.post(SparkListenerExecutorAdded(executorId, data)) + listenerBus.post( + SparkListenerExecutorAdded(System.currentTimeMillis(), executorId, data)) makeOffers() } @@ -216,7 +217,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste totalCoreCount.addAndGet(-executorInfo.totalCores) totalRegisteredExecutors.addAndGet(-1) scheduler.executorLost(executorId, SlaveLost(reason)) - listenerBus.post(SparkListenerExecutorRemoved(executorId)) + listenerBus.post( + SparkListenerExecutorRemoved(System.currentTimeMillis(), executorId, reason)) case None => logError(s"Asked to remove non-existent executor $executorId") } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 79c9051e88691..c3c546be6da15 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -269,7 +269,7 @@ private[spark] class MesosSchedulerBackend( mesosTasks.foreach { case (slaveId, tasks) => slaveIdToWorkerOffer.get(slaveId).foreach(o => - listenerBus.post(SparkListenerExecutorAdded(slaveId, + listenerBus.post(SparkListenerExecutorAdded(System.currentTimeMillis(), slaveId, new ExecutorInfo(o.host, o.cores))) ) d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters) @@ -327,7 +327,7 @@ private[spark] class MesosSchedulerBackend( synchronized { if (status.getState == MesosTaskState.TASK_LOST && taskIdToSlaveId.contains(tid)) { // We lost the executor on this slave, so remember that it's gone - removeExecutor(taskIdToSlaveId(tid)) + removeExecutor(taskIdToSlaveId(tid), "Lost executor") } if (isFinished(status.getState)) { taskIdToSlaveId.remove(tid) @@ -359,9 +359,9 @@ private[spark] class MesosSchedulerBackend( /** * Remove executor associated with slaveId in a thread safe manner. */ - private def removeExecutor(slaveId: String) = { + private def removeExecutor(slaveId: String, reason: String) = { synchronized { - listenerBus.post(SparkListenerExecutorRemoved(slaveId)) + listenerBus.post(SparkListenerExecutorRemoved(System.currentTimeMillis(), slaveId, reason)) slaveIdsWithExecutors -= slaveId } } @@ -369,7 +369,7 @@ private[spark] class MesosSchedulerBackend( private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) { inClassLoader() { logInfo("Mesos slave lost: " + slaveId.getValue) - removeExecutor(slaveId.getValue) + removeExecutor(slaveId.getValue, reason.toString) scheduler.executorLost(slaveId.getValue, reason) } } diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index fa8a337ad63a8..1baa0e009f3ae 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -27,7 +27,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.ByteBufferInputStream import org.apache.spark.util.Utils -private[spark] class JavaSerializationStream(out: OutputStream, counterReset: Int) +private[spark] class JavaSerializationStream( + out: OutputStream, counterReset: Int, extraDebugInfo: Boolean) extends SerializationStream { private val objOut = new ObjectOutputStream(out) private var counter = 0 @@ -39,7 +40,12 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In * the stream 'resets' object class descriptions have to be re-written) */ def writeObject[T: ClassTag](t: T): SerializationStream = { - objOut.writeObject(t) + try { + objOut.writeObject(t) + } catch { + case e: NotSerializableException if extraDebugInfo => + throw SerializationDebugger.improveException(t, e) + } counter += 1 if (counterReset > 0 && counter >= counterReset) { objOut.reset() @@ -64,7 +70,8 @@ extends DeserializationStream { } -private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoader: ClassLoader) +private[spark] class JavaSerializerInstance( + counterReset: Int, extraDebugInfo: Boolean, defaultClassLoader: ClassLoader) extends SerializerInstance { override def serialize[T: ClassTag](t: T): ByteBuffer = { @@ -88,7 +95,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade } override def serializeStream(s: OutputStream): SerializationStream = { - new JavaSerializationStream(s, counterReset) + new JavaSerializationStream(s, counterReset, extraDebugInfo) } override def deserializeStream(s: InputStream): DeserializationStream = { @@ -111,17 +118,20 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade @DeveloperApi class JavaSerializer(conf: SparkConf) extends Serializer with Externalizable { private var counterReset = conf.getInt("spark.serializer.objectStreamReset", 100) + private var extraDebugInfo = conf.getBoolean("spark.serializer.extraDebugInfo", true) override def newInstance(): SerializerInstance = { val classLoader = defaultClassLoader.getOrElse(Thread.currentThread.getContextClassLoader) - new JavaSerializerInstance(counterReset, classLoader) + new JavaSerializerInstance(counterReset, extraDebugInfo, classLoader) } override def writeExternal(out: ObjectOutput): Unit = Utils.tryOrIOException { out.writeInt(counterReset) + out.writeBoolean(extraDebugInfo) } override def readExternal(in: ObjectInput): Unit = Utils.tryOrIOException { counterReset = in.readInt() + extraDebugInfo = in.readBoolean() } } diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala new file mode 100644 index 0000000000000..cecb992579655 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/serializer/SerializationDebugger.scala @@ -0,0 +1,307 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{NotSerializableException, ObjectOutput, ObjectStreamClass, ObjectStreamField} +import java.lang.reflect.{Field, Method} +import java.security.AccessController + +import scala.annotation.tailrec +import scala.collection.mutable + +import org.apache.spark.Logging + +private[serializer] object SerializationDebugger extends Logging { + + /** + * Improve the given NotSerializableException with the serialization path leading from the given + * object to the problematic object. This is turned off automatically if + * `sun.io.serialization.extendedDebugInfo` flag is turned on for the JVM. + */ + def improveException(obj: Any, e: NotSerializableException): NotSerializableException = { + if (enableDebugging && reflect != null) { + new NotSerializableException( + e.getMessage + "\nSerialization stack:\n" + find(obj).map("\t- " + _).mkString("\n")) + } else { + e + } + } + + /** + * Find the path leading to a not serializable object. This method is modeled after OpenJDK's + * serialization mechanism, and handles the following cases: + * - primitives + * - arrays of primitives + * - arrays of non-primitive objects + * - Serializable objects + * - Externalizable objects + * - writeReplace + * + * It does not yet handle writeObject override, but that shouldn't be too hard to do either. + */ + def find(obj: Any): List[String] = { + new SerializationDebugger().visit(obj, List.empty) + } + + private[serializer] var enableDebugging: Boolean = { + !AccessController.doPrivileged(new sun.security.action.GetBooleanAction( + "sun.io.serialization.extendedDebugInfo")).booleanValue() + } + + private class SerializationDebugger { + + /** A set to track the list of objects we have visited, to avoid cycles in the graph. */ + private val visited = new mutable.HashSet[Any] + + /** + * Visit the object and its fields and stop when we find an object that is not serializable. + * Return the path as a list. If everything can be serialized, return an empty list. + */ + def visit(o: Any, stack: List[String]): List[String] = { + if (o == null) { + List.empty + } else if (visited.contains(o)) { + List.empty + } else { + visited += o + o match { + // Primitive value, string, and primitive arrays are always serializable + case _ if o.getClass.isPrimitive => List.empty + case _: String => List.empty + case _ if o.getClass.isArray && o.getClass.getComponentType.isPrimitive => List.empty + + // Traverse non primitive array. + case a: Array[_] if o.getClass.isArray && !o.getClass.getComponentType.isPrimitive => + val elem = s"array (class ${a.getClass.getName}, size ${a.length})" + visitArray(o.asInstanceOf[Array[_]], elem :: stack) + + case e: java.io.Externalizable => + val elem = s"externalizable object (class ${e.getClass.getName}, $e)" + visitExternalizable(e, elem :: stack) + + case s: Object with java.io.Serializable => + val elem = s"object (class ${s.getClass.getName}, $s)" + visitSerializable(s, elem :: stack) + + case _ => + // Found an object that is not serializable! + s"object not serializable (class: ${o.getClass.getName}, value: $o)" :: stack + } + } + } + + private def visitArray(o: Array[_], stack: List[String]): List[String] = { + var i = 0 + while (i < o.length) { + val childStack = visit(o(i), s"element of array (index: $i)" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitExternalizable(o: java.io.Externalizable, stack: List[String]): List[String] = + { + val fieldList = new ListObjectOutput + o.writeExternal(fieldList) + val childObjects = fieldList.outputArray + var i = 0 + while (i < childObjects.length) { + val childStack = visit(childObjects(i), "writeExternal data" :: stack) + if (childStack.nonEmpty) { + return childStack + } + i += 1 + } + return List.empty + } + + private def visitSerializable(o: Object, stack: List[String]): List[String] = { + // An object contains multiple slots in serialization. + // Get the slots and visit fields in all of them. + val (finalObj, desc) = findObjectAndDescriptor(o) + val slotDescs = desc.getSlotDescs + var i = 0 + while (i < slotDescs.length) { + val slotDesc = slotDescs(i) + if (slotDesc.hasWriteObjectMethod) { + // TODO: Handle classes that specify writeObject method. + } else { + val fields: Array[ObjectStreamField] = slotDesc.getFields + val objFieldValues: Array[Object] = new Array[Object](slotDesc.getNumObjFields) + val numPrims = fields.length - objFieldValues.length + desc.getObjFieldValues(finalObj, objFieldValues) + + var j = 0 + while (j < objFieldValues.length) { + val fieldDesc = fields(numPrims + j) + val elem = s"field (class: ${slotDesc.getName}" + + s", name: ${fieldDesc.getName}" + + s", type: ${fieldDesc.getType})" + val childStack = visit(objFieldValues(j), elem :: stack) + if (childStack.nonEmpty) { + return childStack + } + j += 1 + } + + } + i += 1 + } + return List.empty + } + } + + /** + * Find the object to serialize and the associated [[ObjectStreamClass]]. This method handles + * writeReplace in Serializable. It starts with the object itself, and keeps calling the + * writeReplace method until there is no more + */ + @tailrec + private def findObjectAndDescriptor(o: Object): (Object, ObjectStreamClass) = { + val cl = o.getClass + val desc = ObjectStreamClass.lookupAny(cl) + if (!desc.hasWriteReplaceMethod) { + (o, desc) + } else { + // write place + findObjectAndDescriptor(desc.invokeWriteReplace(o)) + } + } + + /** + * A dummy [[ObjectOutput]] that simply saves the list of objects written by a writeExternal + * call, and returns them through `outputArray`. + */ + private class ListObjectOutput extends ObjectOutput { + private val output = new mutable.ArrayBuffer[Any] + def outputArray: Array[Any] = output.toArray + override def writeObject(o: Any): Unit = output += o + override def flush(): Unit = {} + override def write(i: Int): Unit = {} + override def write(bytes: Array[Byte]): Unit = {} + override def write(bytes: Array[Byte], i: Int, i1: Int): Unit = {} + override def close(): Unit = {} + override def writeFloat(v: Float): Unit = {} + override def writeChars(s: String): Unit = {} + override def writeDouble(v: Double): Unit = {} + override def writeUTF(s: String): Unit = {} + override def writeShort(i: Int): Unit = {} + override def writeInt(i: Int): Unit = {} + override def writeBoolean(b: Boolean): Unit = {} + override def writeBytes(s: String): Unit = {} + override def writeChar(i: Int): Unit = {} + override def writeLong(l: Long): Unit = {} + override def writeByte(i: Int): Unit = {} + } + + /** An implicit class that allows us to call private methods of ObjectStreamClass. */ + implicit class ObjectStreamClassMethods(val desc: ObjectStreamClass) extends AnyVal { + def getSlotDescs: Array[ObjectStreamClass] = { + reflect.GetClassDataLayout.invoke(desc).asInstanceOf[Array[Object]].map { + classDataSlot => reflect.DescField.get(classDataSlot).asInstanceOf[ObjectStreamClass] + } + } + + def hasWriteObjectMethod: Boolean = { + reflect.HasWriteObjectMethod.invoke(desc).asInstanceOf[Boolean] + } + + def hasWriteReplaceMethod: Boolean = { + reflect.HasWriteReplaceMethod.invoke(desc).asInstanceOf[Boolean] + } + + def invokeWriteReplace(obj: Object): Object = { + reflect.InvokeWriteReplace.invoke(desc, obj) + } + + def getNumObjFields: Int = { + reflect.GetNumObjFields.invoke(desc).asInstanceOf[Int] + } + + def getObjFieldValues(obj: Object, out: Array[Object]): Unit = { + reflect.GetObjFieldValues.invoke(desc, obj, out) + } + } + + /** + * Object to hold all the reflection objects. If we run on a JVM that we cannot understand, + * this field will be null and this the debug helper should be disabled. + */ + private val reflect: ObjectStreamClassReflection = try { + new ObjectStreamClassReflection + } catch { + case e: Exception => + logWarning("Cannot find private methods using reflection", e) + null + } + + private class ObjectStreamClassReflection { + /** ObjectStreamClass.getClassDataLayout */ + val GetClassDataLayout: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getClassDataLayout") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteObjectMethod */ + val HasWriteObjectMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteObjectMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.hasWriteReplaceMethod */ + val HasWriteReplaceMethod: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("hasWriteReplaceMethod") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.invokeWriteReplace */ + val InvokeWriteReplace: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("invokeWriteReplace", classOf[Object]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getNumObjFields */ + val GetNumObjFields: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod("getNumObjFields") + f.setAccessible(true) + f + } + + /** ObjectStreamClass.getObjFieldValues */ + val GetObjFieldValues: Method = { + val f = classOf[ObjectStreamClass].getDeclaredMethod( + "getObjFieldValues", classOf[Object], classOf[Array[Object]]) + f.setAccessible(true) + f + } + + /** ObjectStreamClass$ClassDataSlot.desc field */ + val DescField: Field = { + val f = Class.forName("java.io.ObjectStreamClass$ClassDataSlot").getDeclaredField("desc") + f.setAccessible(true) + f + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala new file mode 100644 index 0000000000000..18c627e8c7a15 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent._ +import java.util.concurrent.atomic.AtomicBoolean + +import com.google.common.annotations.VisibleForTesting + +/** + * Asynchronously passes events to registered listeners. + * + * Until `start()` is called, all posted events are only buffered. Only after this listener bus + * has started will events be actually propagated to all attached listeners. This listener bus + * is stopped when `stop()` is called, and it will drop further events after stopping. + * + * @param name name of the listener bus, will be the name of the listener thread. + * @tparam L type of listener + * @tparam E type of event + */ +private[spark] abstract class AsynchronousListenerBus[L <: AnyRef, E](name: String) + extends ListenerBus[L, E] { + + self => + + /* Cap the capacity of the event queue so we get an explicit error (rather than + * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ + private val EVENT_QUEUE_CAPACITY = 10000 + private val eventQueue = new LinkedBlockingQueue[E](EVENT_QUEUE_CAPACITY) + + // Indicate if `start()` is called + private val started = new AtomicBoolean(false) + // Indicate if `stop()` is called + private val stopped = new AtomicBoolean(false) + + // Indicate if we are processing some event + // Guarded by `self` + private var processingEvent = false + + // A counter that represents the number of events produced and consumed in the queue + private val eventLock = new Semaphore(0) + + private val listenerThread = new Thread(name) { + setDaemon(true) + override def run(): Unit = Utils.logUncaughtExceptions { + while (true) { + eventLock.acquire() + self.synchronized { + processingEvent = true + } + try { + val event = eventQueue.poll + if (event == null) { + // Get out of the while loop and shutdown the daemon thread + if (!stopped.get) { + throw new IllegalStateException("Polling `null` from eventQueue means" + + " the listener bus has been stopped. So `stopped` must be true") + } + return + } + postToAll(event) + } finally { + self.synchronized { + processingEvent = false + } + } + } + } + } + + /** + * Start sending events to attached listeners. + * + * This first sends out all buffered events posted before this listener bus has started, then + * listens for any additional events asynchronously while the listener bus is still running. + * This should only be called once. + */ + def start() { + if (started.compareAndSet(false, true)) { + listenerThread.start() + } else { + throw new IllegalStateException(s"$name already started!") + } + } + + def post(event: E) { + if (stopped.get) { + // Drop further events to make `listenerThread` exit ASAP + logError(s"$name has already stopped! Dropping event $event") + return + } + val eventAdded = eventQueue.offer(event) + if (eventAdded) { + eventLock.release() + } else { + onDropEvent(event) + } + } + + /** + * For testing only. Wait until there are no more events in the queue, or until the specified + * time has elapsed. Return true if the queue has emptied and false is the specified time + * elapsed before the queue emptied. + */ + @VisibleForTesting + def waitUntilEmpty(timeoutMillis: Int): Boolean = { + val finishTime = System.currentTimeMillis + timeoutMillis + while (!queueIsEmpty) { + if (System.currentTimeMillis > finishTime) { + return false + } + /* Sleep rather than using wait/notify, because this is used only for testing and + * wait/notify add overhead in the general case. */ + Thread.sleep(10) + } + true + } + + /** + * For testing only. Return whether the listener daemon thread is still alive. + */ + @VisibleForTesting + def listenerThreadIsAlive: Boolean = listenerThread.isAlive + + /** + * Return whether the event queue is empty. + * + * The use of synchronized here guarantees that all events that once belonged to this queue + * have already been processed by all attached listeners, if this returns true. + */ + private def queueIsEmpty: Boolean = synchronized { eventQueue.isEmpty && !processingEvent } + + /** + * Stop the listener bus. It will wait until the queued events have been processed, but drop the + * new events after stopping. + */ + def stop() { + if (!started.get()) { + throw new IllegalStateException(s"Attempted to stop $name that has not yet started!") + } + if (stopped.compareAndSet(false, true)) { + // Call eventLock.release() so that listenerThread will poll `null` from `eventQueue` and know + // `stop` is called. + eventLock.release() + listenerThread.join() + } else { + // Keep quiet + } + } + + /** + * If the event queue exceeds its capacity, the new events will be dropped. The subclasses will be + * notified with the dropped events. + * + * Note: `onDropEvent` can be called in any thread. + */ + def onDropEvent(event: E): Unit +} diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index f896b5072e4fa..414bc49a57f8a 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -91,7 +91,6 @@ private[spark] object JsonProtocol { case executorRemoved: SparkListenerExecutorRemoved => executorRemovedToJson(executorRemoved) // These aren't used, but keeps compiler happy - case SparkListenerShutdown => JNothing case SparkListenerExecutorMetricsUpdate(_, _) => JNothing } } @@ -204,13 +203,16 @@ private[spark] object JsonProtocol { def executorAddedToJson(executorAdded: SparkListenerExecutorAdded): JValue = { ("Event" -> Utils.getFormattedClassName(executorAdded)) ~ + ("Timestamp" -> executorAdded.time) ~ ("Executor ID" -> executorAdded.executorId) ~ ("Executor Info" -> executorInfoToJson(executorAdded.executorInfo)) } def executorRemovedToJson(executorRemoved: SparkListenerExecutorRemoved): JValue = { ("Event" -> Utils.getFormattedClassName(executorRemoved)) ~ - ("Executor ID" -> executorRemoved.executorId) + ("Timestamp" -> executorRemoved.time) ~ + ("Executor ID" -> executorRemoved.executorId) ~ + ("Removed Reason" -> executorRemoved.reason) } /** ------------------------------------------------------------------- * @@ -554,14 +556,17 @@ private[spark] object JsonProtocol { } def executorAddedFromJson(json: JValue): SparkListenerExecutorAdded = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] val executorInfo = executorInfoFromJson(json \ "Executor Info") - SparkListenerExecutorAdded(executorId, executorInfo) + SparkListenerExecutorAdded(time, executorId, executorInfo) } def executorRemovedFromJson(json: JValue): SparkListenerExecutorRemoved = { + val time = (json \ "Timestamp").extract[Long] val executorId = (json \ "Executor ID").extract[String] - SparkListenerExecutorRemoved(executorId) + val reason = (json \ "Removed Reason").extract[String] + SparkListenerExecutorRemoved(time, executorId, reason) } /** --------------------------------------------------------------------- * diff --git a/core/src/main/scala/org/apache/spark/util/ListenerBus.scala b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala new file mode 100644 index 0000000000000..bd0aa4dc4650f --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/ListenerBus.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.concurrent.CopyOnWriteArrayList + +import scala.util.control.NonFatal + +import org.apache.spark.Logging + +/** + * An event bus which posts events to its listeners. + */ +private[spark] trait ListenerBus[L <: AnyRef, E] extends Logging { + + private val listeners = new CopyOnWriteArrayList[L] + + /** + * Add a listener to listen events. This method is thread-safe and can be called in any thread. + */ + final def addListener(listener: L) { + listeners.add(listener) + } + + /** + * Post the event to all registered listeners. The `postToAll` caller should guarantee calling + * `postToAll` in the same thread for all events. + */ + final def postToAll(event: E): Unit = { + // JavaConversions will create a JIterableWrapper if we use some Scala collection functions. + // However, this method will be called frequently. To avoid the wrapper cost, here ewe use + // Java Iterator directly. + val iter = listeners.iterator + while (iter.hasNext) { + val listener = iter.next() + try { + onPostEvent(listener, event) + } catch { + case NonFatal(e) => + logError(s"Listener ${Utils.getFormattedClassName(listener)} threw an exception", e) + } + } + } + + /** + * Post an event to the specified listener. `onPostEvent` is guaranteed to be called in the same + * thread. + */ + def onPostEvent(listener: L, event: E): Unit + +} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2c04e4ddfbcb7..86ac307fc84ba 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -410,10 +410,10 @@ private[spark] object Utils extends Logging { // Decompress the file if it's a .tar or .tar.gz if (fileName.endsWith(".tar.gz") || fileName.endsWith(".tgz")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xzf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xzf", fileName), targetDir) } else if (fileName.endsWith(".tar")) { logInfo("Untarring " + fileName) - Utils.execute(Seq("tar", "-xf", fileName), targetDir) + executeAndGetOutput(Seq("tar", "-xf", fileName), targetDir) } // Make the file executable - That's necessary for scripts FileUtil.chmod(targetFile.getAbsolutePath, "a+x") @@ -956,25 +956,25 @@ private[spark] object Utils extends Logging { } /** - * Execute a command in the given working directory, throwing an exception if it completes - * with an exit code other than 0. + * Execute a command and return the process running the command. */ - def execute(command: Seq[String], workingDir: File) { - val process = new ProcessBuilder(command: _*) - .directory(workingDir) - .redirectErrorStream(true) - .start() - new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - System.err.println(line) - } - } - }.start() - val exitCode = process.waitFor() - if (exitCode != 0) { - throw new SparkException("Process " + command + " exited with code " + exitCode) + def executeCommand( + command: Seq[String], + workingDir: File = new File("."), + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): Process = { + val builder = new ProcessBuilder(command: _*).directory(workingDir) + val environment = builder.environment() + for ((key, value) <- extraEnvironment) { + environment.put(key, value) + } + val process = builder.start() + if (redirectStderr) { + val threadName = "redirect stderr for command " + command(0) + def log(s: String): Unit = logInfo(s) + processStreamByLine(threadName, process.getErrorStream, log) } + process } /** @@ -983,31 +983,13 @@ private[spark] object Utils extends Logging { def executeAndGetOutput( command: Seq[String], workingDir: File = new File("."), - extraEnvironment: Map[String, String] = Map.empty): String = { - val builder = new ProcessBuilder(command: _*) - .directory(workingDir) - val environment = builder.environment() - for ((key, value) <- extraEnvironment) { - environment.put(key, value) - } - - val process = builder.start() - new Thread("read stderr for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getErrorStream).getLines()) { - logInfo(line) - } - } - }.start() + extraEnvironment: Map[String, String] = Map.empty, + redirectStderr: Boolean = true): String = { + val process = executeCommand(command, workingDir, extraEnvironment, redirectStderr) val output = new StringBuffer - val stdoutThread = new Thread("read stdout for " + command(0)) { - override def run() { - for (line <- Source.fromInputStream(process.getInputStream).getLines()) { - output.append(line) - } - } - } - stdoutThread.start() + val threadName = "read stdout for " + command(0) + def appendToOutput(s: String): Unit = output.append(s) + val stdoutThread = processStreamByLine(threadName, process.getInputStream, appendToOutput) val exitCode = process.waitFor() stdoutThread.join() // Wait for it to finish reading output if (exitCode != 0) { @@ -1017,6 +999,25 @@ private[spark] object Utils extends Logging { output.toString } + /** + * Return and start a daemon thread that processes the content of the input stream line by line. + */ + def processStreamByLine( + threadName: String, + inputStream: InputStream, + processLine: String => Unit): Thread = { + val t = new Thread(threadName) { + override def run() { + for (line <- Source.fromInputStream(inputStream).getLines()) { + processLine(line) + } + } + } + t.setDaemon(true) + t.start() + t + } + /** * Execute a block of code that evaluates to Unit, forwarding any uncaught exceptions to the * default UncaughtExceptionHandler diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 004de05c10ee1..b16a1e9460286 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -492,6 +492,36 @@ public Integer call(Integer a, Integer b) { Assert.assertEquals(33, sum); } + @Test + public void treeReduce() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeReduce(add, depth); + Assert.assertEquals(-5, sum); + } + } + + @Test + public void treeAggregate() { + JavaRDD rdd = sc.parallelize(Arrays.asList(-5, -4, -3, -2, -1, 1, 2, 3, 4), 10); + Function2 add = new Function2() { + @Override + public Integer call(Integer a, Integer b) { + return a + b; + } + }; + for (int depth = 1; depth <= 10; depth++) { + int sum = rdd.treeAggregate(0, add, add, depth); + Assert.assertEquals(-5, sum); + } + } + @SuppressWarnings("unchecked") @Test public void aggregateByKey() { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 8a54360e81795..9bd5dfec8703a 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -28,31 +28,29 @@ import org.apache.spark.util.Utils class DriverSuite extends FunSuite with Timeouts { - test("driver should exit after finishing") { + test("driver should exit after finishing without cleanup (SPARK-530)") { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" - val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) + val masters = Table("master", "local", "local-cluster[2,1,512]") forAll(masters) { (master: String) => - failAfter(60 seconds) { - Utils.executeAndGetOutput( - Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - } + val process = Utils.executeCommand( + Seq(s"$sparkHome/bin/spark-class", "org.apache.spark.DriverWithoutCleanup", master), + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } } } /** - * Program that creates a Spark driver but doesn't call SparkContext.stop() or - * Sys.exit() after finishing. + * Program that creates a Spark driver but doesn't call SparkContext#stop() or + * sys.exit() after finishing. */ object DriverWithoutCleanup { def main(args: Array[String]) { Utils.configTestLog4j("INFO") - // Bind the web UI to an ephemeral port in order to avoid conflicts with other tests running on - // the same machine (we shouldn't just disable the UI here, since that might mask bugs): - val conf = new SparkConf().set("spark.ui.port", "0") + val conf = new SparkConf val sc = new SparkContext(args(0), "DriverWithoutCleanup", conf) sc.parallelize(1 to 100, 4).count() } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index 8ae4f243ec1ae..bbed8ddc6bafc 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -149,7 +149,7 @@ class SparkContextSchedulerCreationSuite } test("yarn-client") { - testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnScheduler") } def testMesos(master: String, expectedClass: Class[_], coarse: Boolean) { diff --git a/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala new file mode 100644 index 0000000000000..f8c39326145e1 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/api/python/SerDeUtilSuite.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import org.scalatest.FunSuite + +import org.apache.spark.SharedSparkContext + +class SerDeUtilSuite extends FunSuite with SharedSparkContext { + + test("Converting an empty pair RDD to python does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + SerDeUtil.pairRDDToPython(emptyRdd, 10) + } + + test("Converting an empty python RDD to pair RDD does not throw an exception (SPARK-5441)") { + val emptyRdd = sc.makeRDD(Seq[(Any, Any)]()) + val javaRdd = emptyRdd.toJavaRDD() + val pythonRdd = SerDeUtil.javaToPython(javaRdd) + SerDeUtil.pythonToPairRDD(pythonRdd, false) + } +} + diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 065b7534cece6..82628ad3abd99 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,25 +21,28 @@ import java.io._ import scala.collection.mutable.ArrayBuffer +import org.scalatest.FunSuite +import org.scalatest.Matchers +import org.scalatest.concurrent.Timeouts +import org.scalatest.time.SpanSugar._ + import org.apache.spark._ import org.apache.spark.deploy.SparkSubmit._ import org.apache.spark.util.{ResetSystemProperties, Utils} -import org.scalatest.FunSuite -import org.scalatest.Matchers // Note: this suite mixes in ResetSystemProperties because SparkSubmit.main() sets a bunch // of properties that neeed to be cleared after tests. -class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties { +class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties with Timeouts { def beforeAll() { System.setProperty("spark.testing", "true") } - val noOpOutputStream = new OutputStream { + private val noOpOutputStream = new OutputStream { def write(b: Int) = {} } /** Simple PrintStream that reads data into a buffer */ - class BufferPrintStream extends PrintStream(noOpOutputStream) { + private class BufferPrintStream extends PrintStream(noOpOutputStream) { var lineBuffer = ArrayBuffer[String]() override def println(line: String) { lineBuffer += line @@ -47,7 +50,7 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } /** Returns true if the script exits and the given search string is printed. */ - def testPrematureExit(input: Array[String], searchString: String) = { + private def testPrematureExit(input: Array[String], searchString: String) = { val printStream = new BufferPrintStream() SparkSubmit.printStream = printStream @@ -290,7 +293,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), "--name", "testApp", "--master", "local", - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -305,7 +307,6 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties "--name", "testApp", "--master", "local-cluster[2,1,512]", "--jars", jarsString, - "--conf", "spark.ui.enabled=false", unusedJar.toString) runSparkSubmit(args) } @@ -430,15 +431,18 @@ class SparkSubmitSuite extends FunSuite with Matchers with ResetSystemProperties } // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - def runSparkSubmit(args: Seq[String]): String = { + private def runSparkSubmit(args: Seq[String]): Unit = { val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - Utils.executeAndGetOutput( + val process = Utils.executeCommand( Seq("./bin/spark-submit") ++ args, new File(sparkHome), Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + failAfter(60 seconds) { process.waitFor() } + // Ensure we still kill the process in case it timed out + process.destroy() } - def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { + private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() val defaultsConf = new File(tmpDir.getAbsolutePath, "spark-defaults.conf") diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 10a39990f80ce..81db66ae17464 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -26,7 +26,16 @@ import org.scalatest.FunSuite import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, + LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, + TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, + CombineFileSplit => OldCombineFileSplit, CombineFileRecordReader => OldCombineFileRecordReader} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, + CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, + FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.spark.SharedSparkContext import org.apache.spark.deploy.SparkHadoopUtil @@ -202,7 +211,7 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { val fs = FileSystem.getLocal(new Configuration()) val outPath = new Path(fs.getWorkingDirectory, "outdir") - if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outPath, fs.getConf).isDefined) { + if (SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback().isDefined) { val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { @@ -225,4 +234,88 @@ class InputOutputMetricsSuite extends FunSuite with SharedSparkContext { } } } + + test("input metrics with old CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.hadoopFile(tmpFilePath, classOf[OldCombineTextInputFormat], classOf[LongWritable], + classOf[Text], 2).count() + } + assert(bytesRead >= tmpFile.length()) + } + + test("input metrics with new CombineFileInputFormat") { + val bytesRead = runAndReturnBytesRead { + sc.newAPIHadoopFile(tmpFilePath, classOf[NewCombineTextInputFormat], classOf[LongWritable], + classOf[Text], new Configuration()).count() + } + assert(bytesRead >= tmpFile.length()) + } +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class OldCombineTextInputFormat extends OldCombineFileInputFormat[LongWritable, Text] { + override def getRecordReader(split: OldInputSplit, conf: JobConf, reporter: Reporter) + : OldRecordReader[LongWritable, Text] = { + new OldCombineFileRecordReader[LongWritable, Text](conf, + split.asInstanceOf[OldCombineFileSplit], reporter, classOf[OldCombineTextRecordReaderWrapper] + .asInstanceOf[Class[OldRecordReader[LongWritable, Text]]]) + } +} + +class OldCombineTextRecordReaderWrapper( + split: OldCombineFileSplit, + conf: Configuration, + reporter: Reporter, + idx: Integer) extends OldRecordReader[LongWritable, Text] { + + val fileSplit = new OldFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate: OldLineRecordReader = new OldTextInputFormat().getRecordReader(fileSplit, + conf.asInstanceOf[JobConf], reporter).asInstanceOf[OldLineRecordReader] + + override def next(key: LongWritable, value: Text): Boolean = delegate.next(key, value) + override def createKey(): LongWritable = delegate.createKey() + override def createValue(): Text = delegate.createValue() + override def getPos(): Long = delegate.getPos + override def close(): Unit = delegate.close() + override def getProgress(): Float = delegate.getProgress +} + +/** + * Hadoop 2 has a version of this, but we can't use it for backwards compatibility + */ +class NewCombineTextInputFormat extends NewCombineFileInputFormat[LongWritable,Text] { + def createRecordReader(split: NewInputSplit, context: TaskAttemptContext) + : NewRecordReader[LongWritable, Text] = { + new NewCombineFileRecordReader[LongWritable,Text](split.asInstanceOf[NewCombineFileSplit], + context, classOf[NewCombineTextRecordReaderWrapper]) + } } + +class NewCombineTextRecordReaderWrapper( + split: NewCombineFileSplit, + context: TaskAttemptContext, + idx: Integer) extends NewRecordReader[LongWritable, Text] { + + val fileSplit = new NewFileSplit(split.getPath(idx), + split.getOffset(idx), + split.getLength(idx), + split.getLocations()) + + val delegate = new NewTextInputFormat().createRecordReader(fileSplit, context) + + override def initialize(split: NewInputSplit, context: TaskAttemptContext): Unit = { + delegate.initialize(fileSplit, context) + } + + override def nextKeyValue(): Boolean = delegate.nextKeyValue() + override def getCurrentKey(): LongWritable = delegate.getCurrentKey + override def getCurrentValue(): Text = delegate.getCurrentValue + override def getProgress(): Float = delegate.getProgress + override def close(): Unit = delegate.close() +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index e33b4bbbb8e4c..bede1ffb3e2d0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -157,6 +157,24 @@ class RDDSuite extends FunSuite with SharedSparkContext { assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) } + test("treeAggregate") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def seqOp = (c: Long, x: Int) => c + x + def combOp = (c1: Long, c2: Long) => c1 + c2 + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) + assert(sum === -1000L) + } + } + + test("treeReduce") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + for (depth <- 1 until 10) { + val sum = rdd.treeReduce(_ + _, depth) + assert(sum === -1000) + } + } + test("basic caching") { val rdd = sc.makeRDD(Array(1, 2, 3, 4), 2).cache() assert(rdd.collect().toList === List(1, 2, 3, 4)) @@ -967,4 +985,5 @@ class RDDSuite extends FunSuite with SharedSparkContext { assertFails { sc.parallelize(1 to 100) } assertFails { sc.textFile("/nonexistent-path") } } + } diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala index 073814c127edc..f2ff98eb72daf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosSchedulerBackendSuite.scala @@ -43,7 +43,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea conf.set("spark.mesos.executor.home" , "/mesos-home") val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) EasyMock.replay(listenerBus) val sc = EasyMock.createMock(classOf[SparkContext]) @@ -88,7 +88,7 @@ class MesosSchedulerBackendSuite extends FunSuite with LocalSparkContext with Ea val taskScheduler = EasyMock.createMock(classOf[TaskSchedulerImpl]) val listenerBus = EasyMock.createMock(classOf[LiveListenerBus]) - listenerBus.post(SparkListenerExecutorAdded("s1", new ExecutorInfo("host1", 2))) + listenerBus.post(SparkListenerExecutorAdded(EasyMock.anyLong, "s1", new ExecutorInfo("host1", 2))) EasyMock.replay(listenerBus) val sc = EasyMock.createMock(classOf[SparkContext]) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala new file mode 100644 index 0000000000000..e62828c4fbac6 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -0,0 +1,139 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.serializer + +import java.io.{ObjectOutput, ObjectInput} + +import org.scalatest.{BeforeAndAfterEach, FunSuite} + + +class SerializationDebuggerSuite extends FunSuite with BeforeAndAfterEach { + + import SerializationDebugger.find + + override def beforeEach(): Unit = { + SerializationDebugger.enableDebugging = true + } + + test("primitives, strings, and nulls") { + assert(find(1) === List.empty) + assert(find(1L) === List.empty) + assert(find(1.toShort) === List.empty) + assert(find(1.0) === List.empty) + assert(find("1") === List.empty) + assert(find(null) === List.empty) + } + + test("primitive arrays") { + assert(find(Array[Int](1, 2)) === List.empty) + assert(find(Array[Long](1, 2)) === List.empty) + } + + test("non-primitive arrays") { + assert(find(Array("aa", "bb")) === List.empty) + assert(find(Array(new SerializableClass1)) === List.empty) + } + + test("serializable object") { + assert(find(new Foo(1, "b", 'c', 'd', null, null, null)) === List.empty) + } + + test("nested arrays") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, Array(foo1), null) + assert(find(new Foo(1, "b", 'c', 'd', null, Array(foo2), null)) === List.empty) + } + + test("nested objects") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + val foo2 = new Foo(1, "b", 'c', 'd', null, null, foo1) + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo2)) === List.empty) + } + + test("cycles (should not loop forever)") { + val foo1 = new Foo(1, "b", 'c', 'd', null, null, null) + foo1.g = foo1 + assert(find(new Foo(1, "b", 'c', 'd', null, null, foo1)) === List.empty) + } + + test("root object not serializable") { + val s = find(new NotSerializable) + assert(s.size === 1) + assert(s.head.contains("NotSerializable")) + } + + test("array containing not serializable element") { + val s = find(new SerializableArray(Array(new NotSerializable))) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("element of array")) + assert(s(2).contains("array")) + assert(s(3).contains("arrayField")) + assert(s(4).contains("SerializableArray")) + } + + test("object containing not serializable field") { + val s = find(new SerializableClass2(new NotSerializable)) + assert(s.size === 3) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + } + + test("externalizable class writing out not serializable object") { + val s = find(new ExternalizableClass) + assert(s.size === 5) + assert(s(0).contains("NotSerializable")) + assert(s(1).contains("objectField")) + assert(s(2).contains("SerializableClass2")) + assert(s(3).contains("writeExternal")) + assert(s(4).contains("ExternalizableClass")) + } +} + + +class SerializableClass1 extends Serializable + + +class SerializableClass2(val objectField: Object) extends Serializable + + +class SerializableArray(val arrayField: Array[Object]) extends Serializable + + +class ExternalizableClass extends java.io.Externalizable { + override def writeExternal(out: ObjectOutput): Unit = { + out.writeInt(1) + out.writeObject(new SerializableClass2(new NotSerializable)) + } + + override def readExternal(in: ObjectInput): Unit = {} +} + + +class Foo( + a: Int, + b: String, + c: Char, + d: Byte, + e: Array[Int], + f: Array[Object], + var g: Foo) extends Serializable + + +class NotSerializable diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 0357fc6ce2780..6577ebaa2e9a8 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -37,6 +37,9 @@ class JsonProtocolSuite extends FunSuite { val jobSubmissionTime = 1421191042750L val jobCompletionTime = 1421191296660L + val executorAddedTime = 1421458410000L + val executorRemovedTime = 1421458922000L + test("SparkListenerEvent") { val stageSubmitted = SparkListenerStageSubmitted(makeStageInfo(100, 200, 300, 400L, 500L), properties) @@ -73,9 +76,9 @@ class JsonProtocolSuite extends FunSuite { val unpersistRdd = SparkListenerUnpersistRDD(12345) val applicationStart = SparkListenerApplicationStart("The winner of all", None, 42L, "Garfield") val applicationEnd = SparkListenerApplicationEnd(42L) - val executorAdded = SparkListenerExecutorAdded("exec1", + val executorAdded = SparkListenerExecutorAdded(executorAddedTime, "exec1", new ExecutorInfo("Hostee.awesome.com", 11)) - val executorRemoved = SparkListenerExecutorRemoved("exec2") + val executorRemoved = SparkListenerExecutorRemoved(executorRemovedTime, "exec2", "test reason") testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -1453,9 +1456,10 @@ class JsonProtocolSuite extends FunSuite { """ private val executorAddedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorAdded", + | "Timestamp": ${executorAddedTime}, | "Executor ID": "exec1", | "Executor Info": { | "Host": "Hostee.awesome.com", @@ -1465,10 +1469,12 @@ class JsonProtocolSuite extends FunSuite { """ private val executorRemovedJsonString = - """ + s""" |{ | "Event": "SparkListenerExecutorRemoved", - | "Executor ID": "exec2" + | "Timestamp": ${executorRemovedTime}, + | "Executor ID": "exec2", + | "Removed Reason": "test reason" |} """ } diff --git a/dev/check-license b/dev/check-license index 72b1013479964..a006f65710d6d 100755 --- a/dev/check-license +++ b/dev/check-license @@ -27,17 +27,17 @@ acquire_rat_jar () { if [[ ! -f "$rat_jar" ]]; then # Download rat launch jar if it hasn't been downloaded yet if [ ! -f "$JAR" ]; then - # Download - printf "Attempting to fetch rat\n" - JAR_DL="${JAR}.part" - if hash curl 2>/dev/null; then - curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" - elif hash wget 2>/dev/null; then - wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" - else - printf "You do not have curl or wget installed, please install rat manually.\n" - exit -1 - fi + # Download + printf "Attempting to fetch rat\n" + JAR_DL="${JAR}.part" + if [ $(command -v curl) ]; then + curl --silent "${URL}" > "$JAR_DL" && mv "$JAR_DL" "$JAR" + elif [ $(command -v wget) ]; then + wget --quiet ${URL} -O "$JAR_DL" && mv "$JAR_DL" "$JAR" + else + printf "You do not have curl or wget installed, please install rat manually.\n" + exit -1 + fi fi unzip -tq $JAR &> /dev/null diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh index b1b8cb44e098b..b2a7e092a0291 100755 --- a/dev/create-release/create-release.sh +++ b/dev/create-release/create-release.sh @@ -122,8 +122,14 @@ if [[ ! "$@" =~ --package-only ]]; then for file in $(find . -type f) do echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file; - gpg --print-md MD5 $file > $file.md5; - gpg --print-md SHA1 $file > $file.sha1 + if [ $(command -v md5) ]; then + # Available on OS X; -q to keep only hash + md5 -q $file > $file.md5 + else + # Available on Linux; cut to keep only hash + md5sum $file | cut -f1 -d' ' > $file.md5 + fi + shasum -a 1 $file | cut -f1 -d' ' > $file.sha1 done nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id diff --git a/docs/configuration.md b/docs/configuration.md index 7c5b6d011cfd3..e4e4b8d516b75 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -311,6 +311,9 @@ Apart from these, the following properties are also available, and may be useful or it will be displayed before the driver exiting. It also can be dumped into disk by `sc.dump_profiles(path)`. If some of the profile results had been displayed maually, they will not be displayed automatically before driver exiting. + + By default the `pyspark.profiler.BasicProfiler` will be used, but this can be overridden by + passing a profiler class in as a parameter to the `SparkContext` constructor. diff --git a/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png new file mode 100644 index 0000000000000..ed9adad11d03a Binary files /dev/null and b/docs/img/PIClusteringFiveCirclesInputsAndOutputs.png differ diff --git a/docs/mllib-clustering.md b/docs/mllib-clustering.md index c696ae9c8e8c8..413b824e369da 100644 --- a/docs/mllib-clustering.md +++ b/docs/mllib-clustering.md @@ -34,6 +34,26 @@ a given dataset, the algorithm returns the best clustering result). * *initializationSteps* determines the number of steps in the k-means\|\| algorithm. * *epsilon* determines the distance threshold within which we consider k-means to have converged. +### Power Iteration Clustering + +Power iteration clustering is a scalable and efficient algorithm for clustering points given pointwise mutual affinity values. Internally the algorithm: + +* accepts a [Graph](https://spark.apache.org/docs/0.9.2/api/graphx/index.html#org.apache.spark.graphx.Graph) that represents a normalized pairwise affinity between all input points. +* calculates the principal eigenvalue and eigenvector +* Clusters each of the input points according to their principal eigenvector component value + +Details of this algorithm are found within [Power Iteration Clustering, Lin and Cohen]{www.icml2010.org/papers/387.pdf} + +Example outputs for a dataset inspired by the paper - but with five clusters instead of three- have he following output from our implementation: + +

+ The Property Graph + +

+ ### Examples
diff --git a/docs/mllib-collaborative-filtering.md b/docs/mllib-collaborative-filtering.md index 2094963392295..ef18cec9371d6 100644 --- a/docs/mllib-collaborative-filtering.md +++ b/docs/mllib-collaborative-filtering.md @@ -192,12 +192,11 @@ We use the default ALS.train() method which assumes ratings are explicit. We eva recommendation by measuring the Mean Squared Error of rating prediction. {% highlight python %} -from pyspark.mllib.recommendation import ALS -from numpy import array +from pyspark.mllib.recommendation import ALS, Rating # Load and parse the data data = sc.textFile("data/mllib/als/test.data") -ratings = data.map(lambda line: array([float(x) for x in line.split(',')])) +ratings = data.map(lambda l: l.split(',')).map(lambda l: Rating(int(l[0]), int(l[1]), float(l[2]))) # Build the recommendation model using Alternating Least Squares rank = 10 @@ -205,10 +204,10 @@ numIterations = 20 model = ALS.train(ratings, rank, numIterations) # Evaluate the model on training data -testdata = ratings.map(lambda p: (int(p[0]), int(p[1]))) +testdata = ratings.map(lambda p: (p[0], p[1])) predictions = model.predictAll(testdata).map(lambda r: ((r[0], r[1]), r[2])) ratesAndPreds = ratings.map(lambda r: ((r[0], r[1]), r[2])).join(predictions) -MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y)/ratesAndPreds.count() +MSE = ratesAndPreds.map(lambda r: (r[1][0] - r[1][1])**2).reduce(lambda x, y: x + y) / ratesAndPreds.count() print("Mean Squared Error = " + str(MSE)) {% endhighlight %} @@ -217,7 +216,7 @@ signals), you can use the trainImplicit method to get better results. {% highlight python %} # Build the recommendation model using Alternating Least Squares based on implicit ratings -model = ALS.trainImplicit(ratings, rank, numIterations, alpha = 0.01) +model = ALS.trainImplicit(ratings, rank, numIterations, alpha=0.01) {% endhighlight %}
diff --git a/docs/mllib-feature-extraction.md b/docs/mllib-feature-extraction.md index 197bc77d506c6..d4a61a7fbf3d7 100644 --- a/docs/mllib-feature-extraction.md +++ b/docs/mllib-feature-extraction.md @@ -240,11 +240,11 @@ following parameters in the constructor: * `withMean` False by default. Centers the data with mean before scaling. It will build a dense output, so this does not work on sparse input and will raise an exception. -* `withStd` True by default. Scales the data to unit variance. +* `withStd` True by default. Scales the data to unit standard deviation. We provide a [`fit`](api/scala/index.html#org.apache.spark.mllib.feature.StandardScaler) method in `StandardScaler` which can take an input of `RDD[Vector]`, learn the summary statistics, and then -return a model which can transform the input dataset into unit variance and/or zero mean features +return a model which can transform the input dataset into unit standard deviation and/or zero mean features depending how we configure the `StandardScaler`. This model implements [`VectorTransformer`](api/scala/index.html#org.apache.spark.mllib.feature.VectorTransformer) @@ -257,7 +257,7 @@ for that feature. ### Example The example below demonstrates how to load a dataset in libsvm format, and standardize the features -so that the new features have unit variance and/or zero mean. +so that the new features have unit standard deviation and/or zero mean.
@@ -271,6 +271,8 @@ val data = MLUtils.loadLibSVMFile(sc, "data/mllib/sample_libsvm_data.txt") val scaler1 = new StandardScaler().fit(data.map(x => x.features)) val scaler2 = new StandardScaler(withMean = true, withStd = true).fit(data.map(x => x.features)) +// scaler3 is an identical model to scaler2, and will produce identical transformations +val scaler3 = new StandardScalerModel(scaler2.std, scaler2.mean) // data1 will be unit variance. val data1 = data.map(x => (x.label, scaler1.transform(x.features))) @@ -294,6 +296,9 @@ features = data.map(lambda x: x.features) scaler1 = StandardScaler().fit(features) scaler2 = StandardScaler(withMean=True, withStd=True).fit(features) +# scaler3 is an identical model to scaler2, and will produce identical transformations +scaler3 = StandardScalerModel(scaler2.std, scaler2.mean) + # data1 will be unit variance. data1 = label.zip(scaler1.transform(features)) diff --git a/docs/programming-guide.md b/docs/programming-guide.md index 2443fc29b4706..6486614e71354 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -886,7 +886,7 @@ for details. groupByKey([numTasks]) When called on a dataset of (K, V) pairs, returns a dataset of (K, Iterable<V>) pairs.
Note: If you are grouping in order to perform an aggregation (such as a sum or - average) over each key, using reduceByKey or combineByKey will yield much better + average) over each key, using reduceByKey or aggregateByKey will yield much better performance.
Note: By default, the level of parallelism in the output depends on the number of partitions of the parent RDD. diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 index 3abd3f396f605..26e7d22655694 100755 --- a/ec2/spark-ec2 +++ b/ec2/spark-ec2 @@ -20,6 +20,6 @@ # Preserve the user's CWD so that relative paths are passed correctly to #+ the underlying Python script. -SPARK_EC2_DIR="$(dirname $0)" +SPARK_EC2_DIR="$(dirname "$0")" python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/examples/pom.xml b/examples/pom.xml index 4b92147725f6b..8caad2bc2e27a 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -35,12 +35,6 @@ http://spark.apache.org/ - - - com.google.guava - guava - compile - org.apache.spark spark-core_${scala.binary.version} @@ -310,69 +304,40 @@ org.apache.maven.plugins maven-shade-plugin - - - package - - shade - - - false - ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar - - - *:* - - - - - com.google.guava:guava - - - ** - - - - *:* - - META-INF/*.SF - META-INF/*.DSA - META-INF/*.RSA - - - - - - com.google - org.spark-project.guava - - com.google.common.** - - - com.google.common.base.Optional** - - - - org.apache.commons.math3 - org.spark-project.commons.math3 - - - - - - reference.conf - - - log4j.properties - - - - - + + false + ${project.build.directory}/scala-${scala.binary.version}/spark-examples-${project.version}-hadoop${hadoop.version}.jar + + + *:* + + + + + *:* + + META-INF/*.SF + META-INF/*.DSA + META-INF/*.RSA + + + + + + org.apache.commons.math3 + org.spark-project.commons.math3 + + + + + + reference.conf + + + log4j.properties + + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java index 247d2a5e31a8c..0fbee6e433608 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java @@ -33,7 +33,7 @@ import org.apache.spark.ml.tuning.CrossValidator; import org.apache.spark.ml.tuning.CrossValidatorModel; import org.apache.spark.ml.tuning.ParamGridBuilder; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -71,7 +71,7 @@ public static void main(String[] args) { new LabeledDocument(9L, "a e c l", 0.0), new LabeledDocument(10L, "spark compile", 1.0), new LabeledDocument(11L, "hadoop software", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -112,11 +112,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. cvModel uses the best model found (lrModel). - cvModel.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + cvModel.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java index 5b92655e2e838..eaaa344be49c8 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java @@ -28,7 +28,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -48,13 +48,13 @@ public static void main(String[] args) { // Prepare training data. // We use LabeledPoint, which is a JavaBean. Spark SQL can convert RDDs of JavaBeans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. List localTraining = Lists.newArrayList( new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class); // Create a LogisticRegression instance. This instance is an Estimator. LogisticRegression lr = new LogisticRegression(); @@ -94,14 +94,14 @@ public static void main(String[] args) { new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)), new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)), new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class); // Make predictions on test documents using the Transformer.transform() method. // LogisticRegression.transform will only use the 'features' column. // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. - model2.transform(test).registerAsTable("results"); - SchemaRDD results = + model2.transform(test).registerTempTable("results"); + DataFrame results = jsql.sql("SELECT features, label, probability, prediction FROM results"); for (Row r: results.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java index 74db449fada7d..82d665a3e1386 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java @@ -29,7 +29,7 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.HashingTF; import org.apache.spark.ml.feature.Tokenizer; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.Row; @@ -54,7 +54,7 @@ public static void main(String[] args) { new LabeledDocument(1L, "b d", 0.0), new LabeledDocument(2L, "spark f g h", 1.0), new LabeledDocument(3L, "hadoop mapreduce", 0.0)); - SchemaRDD training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); + DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledDocument.class); // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. Tokenizer tokenizer = new Tokenizer() @@ -79,11 +79,11 @@ public static void main(String[] args) { new Document(5L, "l m n"), new Document(6L, "mapreduce spark"), new Document(7L, "apache hadoop")); - SchemaRDD test = jsql.applySchema(jsc.parallelize(localTest), Document.class); + DataFrame test = jsql.applySchema(jsc.parallelize(localTest), Document.class); // Make predictions on test documents. - model.transform(test).registerAsTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); + model.transform(test).registerTempTable("prediction"); + DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction"); for (Row r: predictions.collect()) { System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2) + ", prediction=" + r.get(3)); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java index b70804635d5c9..8defb769ffaaf 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java @@ -26,9 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaSparkSQL { public static class Person implements Serializable { @@ -74,13 +74,13 @@ public Person call(String line) { }); // Apply a schema to an RDD of Java Beans and register it as a table. - SchemaRDD schemaPeople = sqlCtx.applySchema(people, Person.class); + DataFrame schemaPeople = sqlCtx.applySchema(people, Person.class); schemaPeople.registerTempTable("people"); // SQL can be run over RDDs that have been registered as tables. - SchemaRDD teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are SchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrames and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. List teenagerNames = teenagers.toJavaRDD().map(new Function() { @Override @@ -93,17 +93,17 @@ public String call(Row row) { } System.out.println("=== Data source: Parquet File ==="); - // JavaSchemaRDDs can be saved as parquet files, maintaining the schema information. + // DataFrames can be saved as parquet files, maintaining the schema information. schemaPeople.saveAsParquetFile("people.parquet"); // Read in the parquet file created above. // Parquet files are self-describing so the schema is preserved. - // The result of loading a parquet file is also a JavaSchemaRDD. - SchemaRDD parquetFile = sqlCtx.parquetFile("people.parquet"); + // The result of loading a parquet file is also a DataFrame. + DataFrame parquetFile = sqlCtx.parquetFile("people.parquet"); //Parquet files can also be registered as tables and then used in SQL statements. parquetFile.registerTempTable("parquetFile"); - SchemaRDD teenagers2 = + DataFrame teenagers2 = sqlCtx.sql("SELECT name FROM parquetFile WHERE age >= 13 AND age <= 19"); teenagerNames = teenagers2.toJavaRDD().map(new Function() { @Override @@ -119,8 +119,8 @@ public String call(Row row) { // A JSON dataset is pointed by path. // The path can be either a single text file or a directory storing text files. String path = "examples/src/main/resources/people.json"; - // Create a JavaSchemaRDD from the file(s) pointed by path - SchemaRDD peopleFromJsonFile = sqlCtx.jsonFile(path); + // Create a DataFrame from the file(s) pointed by path + DataFrame peopleFromJsonFile = sqlCtx.jsonFile(path); // Because the schema of a JSON dataset is automatically inferred, to write queries, // it is better to take a look at what is the schema. @@ -130,13 +130,13 @@ public String call(Row row) { // |-- age: IntegerType // |-- name: StringType - // Register this JavaSchemaRDD as a table. + // Register this DataFrame as a table. peopleFromJsonFile.registerTempTable("people"); // SQL statements can be run by using the sql methods provided by sqlCtx. - SchemaRDD teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); + DataFrame teenagers3 = sqlCtx.sql("SELECT name FROM people WHERE age >= 13 AND age <= 19"); - // The results of SQL queries are JavaSchemaRDDs and support all the normal RDD operations. + // The results of SQL queries are DataFrame and support all the normal RDD operations. // The columns of a row in the result can be accessed by ordinal. teenagerNames = teenagers3.toJavaRDD().map(new Function() { @Override @@ -146,14 +146,14 @@ public String call(Row row) { System.out.println(name); } - // Alternatively, a JavaSchemaRDD can be created for a JSON dataset represented by + // Alternatively, a DataFrame can be created for a JSON dataset represented by // a RDD[String] storing one JSON object per string. List jsonData = Arrays.asList( "{\"name\":\"Yin\",\"address\":{\"city\":\"Columbus\",\"state\":\"Ohio\"}}"); JavaRDD anotherPeopleRDD = ctx.parallelize(jsonData); - SchemaRDD peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); + DataFrame peopleFromJsonRDD = sqlCtx.jsonRDD(anotherPeopleRDD.rdd()); - // Take a look at the schema of this new JavaSchemaRDD. + // Take a look at the schema of this new DataFrame. peopleFromJsonRDD.printSchema(); // The schema of anotherPeople is ... // root @@ -164,7 +164,7 @@ public String call(Row row) { peopleFromJsonRDD.registerTempTable("people2"); - SchemaRDD peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); + DataFrame peopleWithCity = sqlCtx.sql("SELECT name, address.city FROM people2"); List nameAndCity = peopleWithCity.toJavaRDD().map(new Function() { @Override public String call(Row row) { diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py new file mode 100644 index 0000000000000..c7df3d7b74767 --- /dev/null +++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +from pyspark.sql import SQLContext, Row +from pyspark.ml import Pipeline +from pyspark.ml.feature import HashingTF, Tokenizer +from pyspark.ml.classification import LogisticRegression + + +""" +A simple text classification pipeline that recognizes "spark" from +input text. This is to show how to create and configure a Spark ML +pipeline in Python. Run with: + + bin/spark-submit examples/src/main/python/ml/simple_text_classification_pipeline.py +""" + + +if __name__ == "__main__": + sc = SparkContext(appName="SimpleTextClassificationPipeline") + sqlCtx = SQLContext(sc) + + # Prepare training documents, which are labeled. + LabeledDocument = Row('id', 'text', 'label') + training = sqlCtx.inferSchema( + sc.parallelize([(0L, "a b c d e spark", 1.0), + (1L, "b d", 0.0), + (2L, "spark f g h", 1.0), + (3L, "hadoop mapreduce", 0.0)]) + .map(lambda x: LabeledDocument(*x))) + + # Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr. + tokenizer = Tokenizer() \ + .setInputCol("text") \ + .setOutputCol("words") + hashingTF = HashingTF() \ + .setInputCol(tokenizer.getOutputCol()) \ + .setOutputCol("features") + lr = LogisticRegression() \ + .setMaxIter(10) \ + .setRegParam(0.01) + pipeline = Pipeline() \ + .setStages([tokenizer, hashingTF, lr]) + + # Fit the pipeline to training documents. + model = pipeline.fit(training) + + # Prepare test documents, which are unlabeled. + Document = Row('id', 'text') + test = sqlCtx.inferSchema( + sc.parallelize([(4L, "spark i j k"), + (5L, "l m n"), + (6L, "mapreduce spark"), + (7L, "apache hadoop")]) + .map(lambda x: Document(*x))) + + # Make predictions on test documents and print columns of interest. + prediction = model.transform(test) + prediction.registerTempTable("prediction") + selected = sqlCtx.sql("SELECT id, text, prediction from prediction") + for row in selected.collect(): + print row + + sc.stop() diff --git a/examples/src/main/python/mllib/dataset_example.py b/examples/src/main/python/mllib/dataset_example.py index 540dae785f6ea..b5a70db2b9a3c 100644 --- a/examples/src/main/python/mllib/dataset_example.py +++ b/examples/src/main/python/mllib/dataset_example.py @@ -16,7 +16,7 @@ # """ -An example of how to use SchemaRDD as a dataset for ML. Run with:: +An example of how to use DataFrame as a dataset for ML. Run with:: bin/spark-submit examples/src/main/python/mllib/dataset_example.py """ diff --git a/examples/src/main/python/mllib/gradient_boosted_trees.py b/examples/src/main/python/mllib/gradient_boosted_trees.py new file mode 100644 index 0000000000000..e647773ad9060 --- /dev/null +++ b/examples/src/main/python/mllib/gradient_boosted_trees.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Gradient boosted Trees classification and regression using MLlib. +""" + +import sys + +from pyspark.context import SparkContext +from pyspark.mllib.tree import GradientBoostedTrees +from pyspark.mllib.util import MLUtils + + +def testClassification(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainClassifier(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testErr = labelsAndPredictions.filter(lambda (v, p): v != p).count() \ + / float(testData.count()) + print('Test Error = ' + str(testErr)) + print('Learned classification ensemble model:') + print(model.toDebugString()) + + +def testRegression(trainingData, testData): + # Train a GradientBoostedTrees model. + # Empty categoricalFeaturesInfo indicates all features are continuous. + model = GradientBoostedTrees.trainRegressor(trainingData, categoricalFeaturesInfo={}, + numIterations=30, maxDepth=4) + # Evaluate model on test instances and compute test error + predictions = model.predict(testData.map(lambda x: x.features)) + labelsAndPredictions = testData.map(lambda lp: lp.label).zip(predictions) + testMSE = labelsAndPredictions.map(lambda (v, p): (v - p) * (v - p)).sum() \ + / float(testData.count()) + print('Test Mean Squared Error = ' + str(testMSE)) + print('Learned regression ensemble model:') + print(model.toDebugString()) + + +if __name__ == "__main__": + if len(sys.argv) > 1: + print >> sys.stderr, "Usage: gradient_boosted_trees" + exit(1) + sc = SparkContext(appName="PythonGradientBoostedTrees") + + # Load and parse the data file into an RDD of LabeledPoint. + data = MLUtils.loadLibSVMFile(sc, 'data/mllib/sample_libsvm_data.txt') + # Split the data into training and test sets (30% held out for testing) + (trainingData, testData) = data.randomSplit([0.7, 0.3]) + + print('\nRunning example of classification using GradientBoostedTrees\n') + testClassification(trainingData, testData) + + print('\nRunning example of regression using GradientBoostedTrees\n') + testRegression(trainingData, testData) + + sc.stop() diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py index d2c5ca48c6cb8..7f5c68e3d0fe2 100644 --- a/examples/src/main/python/sql.py +++ b/examples/src/main/python/sql.py @@ -30,18 +30,18 @@ some_rdd = sc.parallelize([Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]) - # Infer schema from the first row, create a SchemaRDD and print the schema - some_schemardd = sqlContext.inferSchema(some_rdd) - some_schemardd.printSchema() + # Infer schema from the first row, create a DataFrame and print the schema + some_df = sqlContext.inferSchema(some_rdd) + some_df.printSchema() # Another RDD is created from a list of tuples another_rdd = sc.parallelize([("John", 19), ("Smith", 23), ("Sarah", 18)]) # Schema with two fields - person_name and person_age schema = StructType([StructField("person_name", StringType(), False), StructField("person_age", IntegerType(), False)]) - # Create a SchemaRDD by applying the schema to the RDD and print the schema - another_schemardd = sqlContext.applySchema(another_rdd, schema) - another_schemardd.printSchema() + # Create a DataFrame by applying the schema to the RDD and print the schema + another_df = sqlContext.applySchema(another_rdd, schema) + another_df.printSchema() # root # |-- age: integer (nullable = true) # |-- name: string (nullable = true) @@ -49,7 +49,7 @@ # A JSON dataset is pointed to by path. # The path can be either a single text file or a directory storing text files. path = os.path.join(os.environ['SPARK_HOME'], "examples/src/main/resources/people.json") - # Create a SchemaRDD from the file(s) pointed to by path + # Create a DataFrame from the file(s) pointed to by path people = sqlContext.jsonFile(path) # root # |-- person_name: string (nullable = false) @@ -61,7 +61,7 @@ # |-- age: IntegerType # |-- name: StringType - # Register this SchemaRDD as a table. + # Register this DataFrame as a table. people.registerAsTable("people") # SQL statements can be run by using the sql methods provided by sqlContext diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index d8c7ef38ee46d..283bb80f1c788 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator @@ -101,7 +100,7 @@ object CrossValidatorExample { // Make predictions on test documents. cvModel uses the best model found (lrModel). cvModel.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index cf62772b92651..b7885829459a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -143,7 +143,7 @@ object MovieLensALS { // Evaluate the model. // TODO: Create an evaluator to compute RMSE. - val mse = predictions.select('rating, 'prediction) + val mse = predictions.select("rating", "prediction").rdd .flatMap { case Row(rating: Float, prediction: Float) => val err = rating.toDouble - prediction val err2 = err * err diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala index e8a2adff929cb..95cc9801eaeb9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala @@ -18,7 +18,6 @@ package org.apache.spark.examples.ml import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.param.ParamMap import org.apache.spark.mllib.linalg.{Vector, Vectors} @@ -42,7 +41,7 @@ object SimpleParamsExample { // Prepare training data. // We use LabeledPoint, which is a case class. Spark SQL can convert RDDs of Java Beans - // into SchemaRDDs, where it uses the bean metadata to infer the schema. + // into DataFrames, where it uses the bean metadata to infer the schema. val training = sparkContext.parallelize(Seq( LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)), LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)), @@ -92,7 +91,7 @@ object SimpleParamsExample { // Note that model2.transform() outputs a 'probability' column instead of the usual 'score' // column since we renamed the lr.scoreCol parameter previously. model2.transform(test) - .select('features, 'label, 'probability, 'prediction) + .select("features", "label", "probability", "prediction") .collect() .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) => println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala index b9a6ef0229def..065db62b0f5ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala @@ -20,7 +20,6 @@ package org.apache.spark.examples.ml import scala.beans.BeanInfo import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.SparkContext._ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.feature.{HashingTF, Tokenizer} @@ -80,7 +79,7 @@ object SimpleTextClassificationPipeline { // Make predictions on test documents. model.transform(test) - .select('id, 'text, 'score, 'prediction) + .select("id", "text", "score", "prediction") .collect() .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) => println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala index f8d83f4ec7327..ab58375649d25 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DatasetExample.scala @@ -28,10 +28,10 @@ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, SQLContext, SchemaRDD} +import org.apache.spark.sql.{Row, SQLContext, DataFrame} /** - * An example of how to use [[org.apache.spark.sql.SchemaRDD]] as a Dataset for ML. Run with + * An example of how to use [[org.apache.spark.sql.DataFrame]] as a Dataset for ML. Run with * {{{ * ./bin/run-example org.apache.spark.examples.mllib.DatasetExample [options] * }}} @@ -47,7 +47,7 @@ object DatasetExample { val defaultParams = Params() val parser = new OptionParser[Params]("DatasetExample") { - head("Dataset: an example app using SchemaRDD as a Dataset for ML.") + head("Dataset: an example app using DataFrame as a Dataset for ML.") opt[String]("input") .text(s"input path to dataset") .action((x, c) => c.copy(input = x)) @@ -80,20 +80,20 @@ object DatasetExample { } println(s"Loaded ${origData.count()} instances from file: ${params.input}") - // Convert input data to SchemaRDD explicitly. - val schemaRDD: SchemaRDD = origData - println(s"Inferred schema:\n${schemaRDD.schema.prettyJson}") - println(s"Converted to SchemaRDD with ${schemaRDD.count()} records") + // Convert input data to DataFrame explicitly. + val df: DataFrame = origData.toDataFrame + println(s"Inferred schema:\n${df.schema.prettyJson}") + println(s"Converted to DataFrame with ${df.count()} records") - // Select columns, using implicit conversion to SchemaRDD. - val labelsSchemaRDD: SchemaRDD = origData.select('label) - val labels: RDD[Double] = labelsSchemaRDD.map { case Row(v: Double) => v } + // Select columns, using implicit conversion to DataFrames. + val labelsDf: DataFrame = origData.select("label") + val labels: RDD[Double] = labelsDf.map { case Row(v: Double) => v } val numLabels = labels.count() val meanLabel = labels.fold(0.0)(_ + _) / numLabels println(s"Selected label column with average value $meanLabel") - val featuresSchemaRDD: SchemaRDD = origData.select('features) - val features: RDD[Vector] = featuresSchemaRDD.map { case Row(v: Vector) => v } + val featuresDf: DataFrame = origData.select("features") + val features: RDD[Vector] = featuresDf.map { case Row(v: Vector) => v } val featureSummary = features.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) @@ -103,13 +103,13 @@ object DatasetExample { tmpDir.deleteOnExit() val outputDir = new File(tmpDir, "dataset").toString println(s"Saving to $outputDir as Parquet file.") - schemaRDD.saveAsParquetFile(outputDir) + df.saveAsParquetFile(outputDir) println(s"Loading Parquet file with UDT from $outputDir.") val newDataset = sqlContext.parquetFile(outputDir) println(s"Schema from Parquet: ${newDataset.schema.prettyJson}") - val newFeatures = newDataset.select('features).map { case Row(v: Vector) => v } + val newFeatures = newDataset.select("features").map { case Row(v: Vector) => v } val newFeaturesSummary = newFeatures.aggregate(new MultivariateOnlineSummarizer())( (summary, feat) => summary.add(feat), (sum1, sum2) => sum1.merge(sum2)) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala similarity index 91% rename from examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala rename to examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala index de58be38c7bfb..df76b45e50810 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGaussianMixture.scala @@ -18,17 +18,17 @@ package org.apache.spark.examples.mllib import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.mllib.clustering.GaussianMixtureEM +import org.apache.spark.mllib.clustering.GaussianMixture import org.apache.spark.mllib.linalg.Vectors /** * An example Gaussian Mixture Model EM app. Run with * {{{ - * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM + * ./bin/run-example mllib.DenseGaussianMixture * }}} * If you use it as a template to create your own app, please use `spark-submit` to submit your app. */ -object DenseGmmEM { +object DenseGaussianMixture { def main(args: Array[String]): Unit = { if (args.length < 3) { println("usage: DenseGmmEM [maxIterations]") @@ -46,7 +46,7 @@ object DenseGmmEM { Vectors.dense(line.trim.split(' ').map(_.toDouble)) }.cache() - val clusters = new GaussianMixtureEM() + val clusters = new GaussianMixture() .setK(k) .setConvergenceTol(convergenceTol) .setMaxIterations(maxIterations) diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala index 2e98b2dc30b80..82a0b637b3cff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala @@ -19,6 +19,7 @@ package org.apache.spark.examples.sql import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.Dsl._ // One method for defining the schema of an RDD is to make a case class with the desired column // names and types. @@ -54,7 +55,7 @@ object RDDRelation { rddFromSql.map(row => s"Key: ${row(0)}, Value: ${row(1)}").collect().foreach(println) // Queries can also be written using a LINQ-like Scala DSL. - rdd.where('key === 1).orderBy('value.asc).select('key).collect().foreach(println) + rdd.where($"key" === 1).orderBy($"value".asc).select($"key").collect().foreach(println) // Write out an RDD as a parquet file. rdd.saveAsParquetFile("pair.parquet") @@ -63,7 +64,7 @@ object RDDRelation { val parquetFile = sqlContext.parquetFile("pair.parquet") // Queries can be run using the DSL on parequet files just like the original RDD. - parquetFile.where('key === 1).select('value as 'a).collect().foreach(println) + parquetFile.where($"key" === 1).select($"value".as("a")).collect().foreach(println) // These files can also be registered as tables. parquetFile.registerTempTable("parquetFile") diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index d1427f6a0c6e9..f2f0aa78b0a4b 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -42,7 +42,7 @@ - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia diff --git a/graphx/pom.xml b/graphx/pom.xml index 72374aae6da9b..8fac24b6ed86d 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + com.google.guava + guava + org.jblas jblas diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala index 84b72b390ca35..ab56580a3abc8 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Graph.scala @@ -55,7 +55,7 @@ abstract class Graph[VD: ClassTag, ED: ClassTag] protected () extends Serializab * @return an RDD containing the edges in this graph * * @see [[Edge]] for the edge type. - * @see [[triplets]] to get an RDD which contains all the edges + * @see [[Graph#triplets]] to get an RDD which contains all the edges * along with their vertex data. * */ diff --git a/make-distribution.sh b/make-distribution.sh index 0adca7851819b..051c87c0894ae 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -32,6 +32,10 @@ SPARK_HOME="$(cd "`dirname "$0"`"; pwd)" DISTDIR="$SPARK_HOME/dist" SPARK_TACHYON=false +TACHYON_VERSION="0.5.0" +TACHYON_TGZ="tachyon-${TACHYON_VERSION}-bin.tar.gz" +TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/${TACHYON_TGZ}" + MAKE_TGZ=false NAME=none MVN="$SPARK_HOME/build/mvn" @@ -93,7 +97,7 @@ done if [ -z "$JAVA_HOME" ]; then # Fall back on JAVA_HOME from rpm, if found - if which rpm &>/dev/null; then + if [ $(command -v rpm) ]; then RPM_JAVA_HOME=$(rpm -E %java_home 2>/dev/null) if [ "$RPM_JAVA_HOME" != "%java_home" ]; then JAVA_HOME=$RPM_JAVA_HOME @@ -107,7 +111,7 @@ if [ -z "$JAVA_HOME" ]; then exit -1 fi -if which git &>/dev/null; then +if [ $(command -v git) ]; then GITREV=$(git rev-parse --short HEAD 2>/dev/null || :) if [ ! -z $GITREV ]; then GITREVSTRING=" (git revision $GITREV)" @@ -115,14 +119,15 @@ if which git &>/dev/null; then unset GITREV fi -if ! which "$MVN" &>/dev/null; then + +if [ ! $(command -v $MVN) ] ; then echo -e "Could not locate Maven command: '$MVN'." echo -e "Specify the Maven command with the --mvn flag" exit -1; fi -VERSION=$(mvn help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) -SPARK_HADOOP_VERSION=$(mvn help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ +VERSION=$($MVN help:evaluate -Dexpression=project.version 2>/dev/null | grep -v "INFO" | tail -n 1) +SPARK_HADOOP_VERSION=$($MVN help:evaluate -Dexpression=hadoop.version $@ 2>/dev/null\ | grep -v "INFO"\ | tail -n 1) SPARK_HIVE=$($MVN help:evaluate -Dexpression=project.activeProfiles -pl sql/hive $@ 2>/dev/null\ @@ -225,16 +230,22 @@ cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Download and copy in tachyon, if requested if [ "$SPARK_TACHYON" == "true" ]; then - TACHYON_VERSION="0.5.0" - TACHYON_URL="https://github.com/amplab/tachyon/releases/download/v${TACHYON_VERSION}/tachyon-${TACHYON_VERSION}-bin.tar.gz" - TMPD=`mktemp -d 2>/dev/null || mktemp -d -t 'disttmp'` pushd $TMPD > /dev/null echo "Fetching tachyon tgz" - wget "$TACHYON_URL" - tar xf "tachyon-${TACHYON_VERSION}-bin.tar.gz" + TACHYON_DL="${TACHYON_TGZ}.part" + if [ $(command -v curl) ]; then + curl --silent -k -L "${TACHYON_URL}" > "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + elif [ $(command -v wget) ]; then + wget --quiet "${TACHYON_URL}" -O "${TACHYON_DL}" && mv "${TACHYON_DL}" "${TACHYON_TGZ}" + else + printf "You do not have curl or wget installed. please install Tachyon manually.\n" + exit -1 + fi + + tar xzf "${TACHYON_TGZ}" cp "tachyon-${TACHYON_VERSION}/core/target/tachyon-${TACHYON_VERSION}-jar-with-dependencies.jar" "$DISTDIR/lib" mkdir -p "$DISTDIR/tachyon/src/main/java/tachyon/web" cp -r "tachyon-${TACHYON_VERSION}"/{bin,conf,libexec} "$DISTDIR/tachyon" diff --git a/mllib/pom.xml b/mllib/pom.xml index a0bda89ccaa71..a8cee3d51a780 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -50,6 +50,11 @@ spark-sql_${scala.binary.version} ${project.version} + + org.apache.spark + spark-graphx_${scala.binary.version} + ${project.version} + org.jblas jblas @@ -125,6 +130,9 @@ ../python pyspark/mllib/*.py + pyspark/mllib/stat/*.py + pyspark/ml/*.py + pyspark/ml/param/*.py diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala index 77d230eb4a122..bc3defe968afd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala @@ -21,7 +21,7 @@ import scala.annotation.varargs import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{ParamMap, ParamPair, Params} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @return fitted model */ @varargs - def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = { + def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = { val map = new ParamMap().put(paramPairs: _*) fit(dataset, map) } @@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMap parameter map * @return fitted model */ - def fit(dataset: SchemaRDD, paramMap: ParamMap): M + def fit(dataset: DataFrame, paramMap: ParamMap): M /** * Fits multiple models to the input data with multiple sets of parameters. @@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params { * @param paramMaps an array of parameter maps * @return fitted models, matching the input parameter maps */ - def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = { + def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = { paramMaps.map(fit(dataset, _)) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala index db563dd550e56..d2ca2e6871e6b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame /** * :: AlphaComponent :: @@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable { * @param paramMap parameter map that specifies the input columns and output metrics * @return metric */ - def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double + def evaluate(dataset: DataFrame, paramMap: ParamMap): Double } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index ad6fed178fae9..bb291e6e1fd7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param.{Param, ParamMap} -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -58,11 +58,11 @@ abstract class PipelineStage extends Serializable with Logging { /** * :: AlphaComponent :: * A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each - * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline.fit]] is called, the - * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator.fit]] method will + * of which is either an [[Estimator]] or a [[Transformer]]. When [[Pipeline#fit]] is called, the + * stages are executed in order. If a stage is an [[Estimator]], its [[Estimator#fit]] method will * be called on the input dataset to fit a model. Then the model, which is a transformer, will be * used to transform the dataset as the input to the next stage. If a stage is a [[Transformer]], - * its [[Transformer.transform]] method will be called to produce the dataset for the next stage. + * its [[Transformer#transform]] method will be called to produce the dataset for the next stage. * The fitted model from a [[Pipeline]] is an [[PipelineModel]], which consists of fitted models and * transformers, corresponding to the pipeline stages. If there are no stages, the pipeline acts as * an identity transformer. @@ -77,9 +77,9 @@ class Pipeline extends Estimator[PipelineModel] { /** * Fits the pipeline to the input dataset with additional parameters. If a stage is an - * [[Estimator]], its [[Estimator.fit]] method will be called on the input dataset to fit a model. + * [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model. * Then the model, which is a transformer, will be used to transform the dataset as the input to - * the next stage. If a stage is a [[Transformer]], its [[Transformer.transform]] method will be + * the next stage. If a stage is a [[Transformer]], its [[Transformer#transform]] method will be * called to produce the dataset for the next stage. The fitted model from a [[Pipeline]] is an * [[PipelineModel]], which consists of fitted models and transformers, corresponding to the * pipeline stages. If there are no stages, the output model acts as an identity transformer. @@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] { * @param paramMap parameter map * @return fitted pipeline */ - override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = { transformSchema(dataset.schema, paramMap, logging = true) val map = this.paramMap ++ paramMap val theStages = map(stages) @@ -162,7 +162,7 @@ class PipelineModel private[ml] ( } } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap val map = (fittingParamMap ++ this.paramMap) ++ paramMap transformSchema(dataset.schema, map, logging = true) diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index af56f9c435351..cd95c16aa768d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -22,9 +22,8 @@ import scala.annotation.varargs import org.apache.spark.Logging import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml.param._ -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.expressions.ScalaUdf +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types._ /** @@ -41,7 +40,7 @@ abstract class Transformer extends PipelineStage with Params { * @return transformed dataset */ @varargs - def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = { + def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = { val map = new ParamMap() paramPairs.foreach(map.put(_)) transform(dataset, map) @@ -53,7 +52,7 @@ abstract class Transformer extends PipelineStage with Params { * @param paramMap additional parameters, overwrite embedded params * @return transformed dataset */ - def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD + def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame } /** @@ -95,11 +94,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O StructType(outputFields) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr)) - dataset.select(Star(None), udf as map(outputCol)) + dataset.select($"*", callUDF( + this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 8c570812f8316..18be35ad59452 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -24,8 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} import org.apache.spark.storage.StorageLevel @@ -87,11 +86,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti def setScoreCol(value: String): this.type = set(scoreCol, value) def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr) + val instances = dataset.select(map(labelCol), map(featuresCol)) .map { case Row(label: Double, features: Vector) => LabeledPoint(label, features) }.persist(StorageLevel.MEMORY_AND_DISK) @@ -131,19 +129,17 @@ class LogisticRegressionModel private[ml] ( validateAndTransformSchema(schema, paramMap, fitting = false) } - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val score: Vector => Double = (v) => { + val scoreFunction: Vector => Double = (v) => { val margin = BLAS.dot(v, weights) 1.0 / (1.0 + math.exp(-margin)) } val t = map(threshold) - val predict: Double => Double = (score) => { - if (score > t) 1.0 else 0.0 - } - dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol)) - .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol)) + val predictFunction: Double => Double = (score) => { if (score > t) 1.0 else 0.0 } + dataset + .select($"*", callUDF(scoreFunction, col(map(featuresCol))).as(map(scoreCol))) + .select($"*", callUDF(predictFunction, col(map(scoreCol))).as(map(predictionCol))) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 12473cb2b5719..1979ab9eb6516 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param._ import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics -import org.apache.spark.sql.{Row, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** @@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params def setScoreCol(value: String): this.type = set(scoreCol, value) def setLabelCol(value: String): this.type = set(labelCol, value) - override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = { + override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = { val map = this.paramMap ++ paramMap val schema = dataset.schema @@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params require(labelType == DoubleType, s"Label column ${map(labelCol)} must be double type but found $labelType") - import dataset.sqlContext._ - val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr) + val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol)) .map { case Row(score: Double, label: Double) => (score, label) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index 72825f6e02182..01a4f5eb205e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -23,8 +23,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.Star -import org.apache.spark.sql.catalyst.dsl._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{StructField, StructType} /** @@ -43,14 +42,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap - val input = dataset.select(map(inputCol).attr) - .map { case Row(v: Vector) => - v - } + val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v } val scaler = new feature.StandardScaler().fit(input) val model = new StandardScalerModel(this, map, scaler) Params.inheritValues(map, this, model) @@ -83,14 +78,13 @@ class StandardScalerModel private[ml] ( def setInputCol(value: String): this.type = set(inputCol, value) def setOutputCol(value: String): this.type = set(outputCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { transformSchema(dataset.schema, paramMap, logging = true) - import dataset.sqlContext._ val map = this.paramMap ++ paramMap val scale: (Vector) => Vector = (v) => { scaler.transform(v) } - dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol)) + dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol))) } private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 04f9cfb1bfc2f..5fb4379e23c2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -164,6 +164,13 @@ trait Params extends Identifiable with Serializable { this } + /** + * Sets a parameter (by name) in the embedded param map. + */ + private[ml] def set(param: String, value: Any): this.type = { + set(getParam(param), value) + } + /** * Gets the value of a parameter in the embedded param map. */ @@ -286,7 +293,6 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten new ParamMap(this.map ++ other.map) } - /** * Adds all parameters from the input param map into this param map. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 2d89e76a4c8b2..979a19d3b2057 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -20,19 +20,20 @@ package org.apache.spark.ml.recommendation import java.{util => ju} import scala.collection.mutable +import scala.reflect.ClassTag +import scala.util.Sorting import com.github.fommil.netlib.BLAS.{getInstance => blas} import com.github.fommil.netlib.LAPACK.{getInstance => lapack} import org.netlib.util.intW import org.apache.spark.{HashPartitioner, Logging, Partitioner} +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SchemaRDD -import org.apache.spark.sql.catalyst.dsl._ -import org.apache.spark.sql.catalyst.expressions.Cast -import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType} import org.apache.spark.util.Utils import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter} @@ -112,21 +113,11 @@ class ALSModel private[ml] ( def setPredictionCol(value: String): this.type = set(predictionCol, value) - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { - import dataset.sqlContext._ - import org.apache.spark.ml.recommendation.ALSModel.Factor + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { + import dataset.sqlContext.createDataFrame val map = this.paramMap ++ paramMap - // TODO: Add DSL to simplify the code here. - val instanceTable = s"instance_$uid" - val userTable = s"user_$uid" - val itemTable = s"item_$uid" - val instances = dataset.as(Symbol(instanceTable)) - val users = userFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(userTable)) - val items = itemFactors.map { case (id, features) => - Factor(id, features) - }.as(Symbol(itemTable)) + val users = userFactors.toDataFrame("id", "features") + val items = itemFactors.toDataFrame("id", "features") val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => { if (userFeatures != null && itemFeatures != null) { blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1) @@ -135,13 +126,14 @@ class ALSModel private[ml] ( } } val inputColumns = dataset.schema.fieldNames - val prediction = - predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol) - val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction - instances - .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr)) - .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr)) + val prediction = callUDF(predict, users("features"), items("features")).as(map(predictionCol)) + val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction + dataset + .join(users, dataset(map(userCol)) === users("id"), "left") + .join(items, dataset(map(itemCol)) === items("id"), "left") .select(outputColumns: _*) + // TODO: Just use a dataset("*") + // .select(dataset("*"), prediction) } override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = { @@ -149,10 +141,6 @@ class ALSModel private[ml] ( } } -private object ALSModel { - /** Case class to convert factors to SchemaRDDs */ - private case class Factor(id: Int, features: Seq[Float]) -} /** * Alternating Least Squares (ALS) matrix factorization. @@ -209,14 +197,13 @@ class ALS extends Estimator[ALSModel] with ALSParams { setMaxIter(20) setRegParam(1.0) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = { - import dataset.sqlContext._ + override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = { val map = this.paramMap ++ paramMap - val ratings = - dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType)) - .map { row => - new Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) - } + val ratings = dataset + .select(col(map(userCol)), col(map(itemCol)), col(map(ratingCol)).cast(FloatType)) + .map { row => + Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) + } val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank), numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks), maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs), @@ -231,10 +218,19 @@ class ALS extends Estimator[ALSModel] with ALSParams { } } -private[recommendation] object ALS extends Logging { +/** + * :: DeveloperApi :: + * An implementation of ALS that supports generic ID types, specialized for Int and Long. This is + * exposed as a developer API for users who do need other ID types. But it is not recommended + * because it increases the shuffle size and memory requirement during training. For simplicity, + * users and items must have the same type. The number of distinct users/items should be smaller + * than 2 billion. + */ +@DeveloperApi +object ALS extends Logging { /** Rating class for better code readability. */ - private[recommendation] case class Rating(user: Int, item: Int, rating: Float) + case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float) /** Cholesky solver for least square problems. */ private[recommendation] class CholeskySolver { @@ -301,7 +297,7 @@ private[recommendation] object ALS extends Logging { /** Adds an observation. */ def add(a: Array[Float], b: Float): this.type = { - require(a.size == k) + require(a.length == k) copyToDouble(a) blas.dspr(upper, k, 1.0, da, 1, ata) blas.daxpy(k, b.toDouble, da, 1, atb, 1) @@ -313,7 +309,7 @@ private[recommendation] object ALS extends Logging { * Adds an observation with implicit feedback. Note that this does not increment the counter. */ def addImplicit(a: Array[Float], b: Float, alpha: Double): this.type = { - require(a.size == k) + require(a.length == k) // Extension to the original paper to handle b < 0. confidence is a function of |b| instead // so that it is never negative. val confidence = 1.0 + alpha * math.abs(b) @@ -329,8 +325,8 @@ private[recommendation] object ALS extends Logging { /** Merges another normal equation object. */ def merge(other: NormalEquation): this.type = { require(other.k == k) - blas.daxpy(ata.size, 1.0, other.ata, 1, ata, 1) - blas.daxpy(atb.size, 1.0, other.atb, 1, atb, 1) + blas.daxpy(ata.length, 1.0, other.ata, 1, ata, 1) + blas.daxpy(atb.length, 1.0, other.atb, 1, atb, 1) n += other.n this } @@ -346,15 +342,16 @@ private[recommendation] object ALS extends Logging { /** * Implementation of the ALS algorithm. */ - private def train( - ratings: RDD[Rating], + def train[ID: ClassTag]( + ratings: RDD[Rating[ID]], rank: Int = 10, numUserBlocks: Int = 10, numItemBlocks: Int = 10, maxIter: Int = 10, regParam: Double = 1.0, implicitPrefs: Boolean = false, - alpha: Double = 1.0): (RDD[(Int, Array[Float])], RDD[(Int, Array[Float])]) = { + alpha: Double = 1.0)( + implicit ord: Ordering[ID]): (RDD[(ID, Array[Float])], RDD[(ID, Array[Float])]) = { val userPart = new HashPartitioner(numUserBlocks) val itemPart = new HashPartitioner(numItemBlocks) val userLocalIndexEncoder = new LocalIndexEncoder(userPart.numPartitions) @@ -457,16 +454,15 @@ private[recommendation] object ALS extends Logging { * * @see [[LocalIndexEncoder]] */ - private[recommendation] case class InBlock( - srcIds: Array[Int], + private[recommendation] case class InBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], dstPtrs: Array[Int], dstEncodedIndices: Array[Int], ratings: Array[Float]) { /** Size of the block. */ - val size: Int = ratings.size - - require(dstEncodedIndices.size == size) - require(dstPtrs.size == srcIds.size + 1) + def size: Int = ratings.length + require(dstEncodedIndices.length == size) + require(dstPtrs.length == srcIds.length + 1) } /** @@ -476,7 +472,9 @@ private[recommendation] object ALS extends Logging { * @param rank rank * @return initialized factor blocks */ - private def initialize(inBlocks: RDD[(Int, InBlock)], rank: Int): RDD[(Int, FactorBlock)] = { + private def initialize[ID]( + inBlocks: RDD[(Int, InBlock[ID])], + rank: Int): RDD[(Int, FactorBlock)] = { // Choose a unit vector uniformly at random from the unit sphere, but from the // "first quadrant" where all elements are nonnegative. This can be done by choosing // elements distributed as Normal(0,1) and taking the absolute value, and then normalizing. @@ -484,7 +482,7 @@ private[recommendation] object ALS extends Logging { // (<1%) compared picking elements uniformly at random in [0,1]. inBlocks.map { case (srcBlockId, inBlock) => val random = new XORShiftRandom(srcBlockId) - val factors = Array.fill(inBlock.srcIds.size) { + val factors = Array.fill(inBlock.srcIds.length) { val factor = Array.fill(rank)(random.nextGaussian().toFloat) val nrm = blas.snrm2(rank, factor, 1) blas.sscal(rank, 1.0f / nrm, factor, 1) @@ -497,26 +495,29 @@ private[recommendation] object ALS extends Logging { /** * A rating block that contains src IDs, dst IDs, and ratings, stored in primitive arrays. */ - private[recommendation] - case class RatingBlock(srcIds: Array[Int], dstIds: Array[Int], ratings: Array[Float]) { + private[recommendation] case class RatingBlock[@specialized(Int, Long) ID: ClassTag]( + srcIds: Array[ID], + dstIds: Array[ID], + ratings: Array[Float]) { /** Size of the block. */ - val size: Int = srcIds.size - require(dstIds.size == size) - require(ratings.size == size) + def size: Int = srcIds.length + require(dstIds.length == srcIds.length) + require(ratings.length == srcIds.length) } /** * Builder for [[RatingBlock]]. [[mutable.ArrayBuilder]] is used to avoid boxing/unboxing. */ - private[recommendation] class RatingBlockBuilder extends Serializable { + private[recommendation] class RatingBlockBuilder[@specialized(Int, Long) ID: ClassTag] + extends Serializable { - private val srcIds = mutable.ArrayBuilder.make[Int] - private val dstIds = mutable.ArrayBuilder.make[Int] + private val srcIds = mutable.ArrayBuilder.make[ID] + private val dstIds = mutable.ArrayBuilder.make[ID] private val ratings = mutable.ArrayBuilder.make[Float] var size = 0 /** Adds a rating. */ - def add(r: Rating): this.type = { + def add(r: Rating[ID]): this.type = { size += 1 srcIds += r.user dstIds += r.item @@ -525,8 +526,8 @@ private[recommendation] object ALS extends Logging { } /** Merges another [[RatingBlockBuilder]]. */ - def merge(other: RatingBlock): this.type = { - size += other.srcIds.size + def merge(other: RatingBlock[ID]): this.type = { + size += other.srcIds.length srcIds ++= other.srcIds dstIds ++= other.dstIds ratings ++= other.ratings @@ -534,8 +535,8 @@ private[recommendation] object ALS extends Logging { } /** Builds a [[RatingBlock]]. */ - def build(): RatingBlock = { - RatingBlock(srcIds.result(), dstIds.result(), ratings.result()) + def build(): RatingBlock[ID] = { + RatingBlock[ID](srcIds.result(), dstIds.result(), ratings.result()) } } @@ -548,10 +549,10 @@ private[recommendation] object ALS extends Logging { * * @return an RDD of rating blocks in the form of ((srcBlockId, dstBlockId), ratingBlock) */ - private def partitionRatings( - ratings: RDD[Rating], + private def partitionRatings[ID: ClassTag]( + ratings: RDD[Rating[ID]], srcPart: Partitioner, - dstPart: Partitioner): RDD[((Int, Int), RatingBlock)] = { + dstPart: Partitioner): RDD[((Int, Int), RatingBlock[ID])] = { /* The implementation produces the same result as the following but generates less objects. @@ -565,7 +566,7 @@ private[recommendation] object ALS extends Logging { val numPartitions = srcPart.numPartitions * dstPart.numPartitions ratings.mapPartitions { iter => - val builders = Array.fill(numPartitions)(new RatingBlockBuilder) + val builders = Array.fill(numPartitions)(new RatingBlockBuilder[ID]) iter.flatMap { r => val srcBlockId = srcPart.getPartition(r.user) val dstBlockId = dstPart.getPartition(r.item) @@ -586,7 +587,7 @@ private[recommendation] object ALS extends Logging { } } }.groupByKey().mapValues { blocks => - val builder = new RatingBlockBuilder + val builder = new RatingBlockBuilder[ID] blocks.foreach(builder.merge) builder.build() }.setName("ratingBlocks") @@ -596,9 +597,11 @@ private[recommendation] object ALS extends Logging { * Builder for uncompressed in-blocks of (srcId, dstEncodedIndex, rating) tuples. * @param encoder encoder for dst indices */ - private[recommendation] class UncompressedInBlockBuilder(encoder: LocalIndexEncoder) { + private[recommendation] class UncompressedInBlockBuilder[@specialized(Int, Long) ID: ClassTag]( + encoder: LocalIndexEncoder)( + implicit ord: Ordering[ID]) { - private val srcIds = mutable.ArrayBuilder.make[Int] + private val srcIds = mutable.ArrayBuilder.make[ID] private val dstEncodedIndices = mutable.ArrayBuilder.make[Int] private val ratings = mutable.ArrayBuilder.make[Float] @@ -612,12 +615,12 @@ private[recommendation] object ALS extends Logging { */ def add( dstBlockId: Int, - srcIds: Array[Int], + srcIds: Array[ID], dstLocalIndices: Array[Int], ratings: Array[Float]): this.type = { - val sz = srcIds.size - require(dstLocalIndices.size == sz) - require(ratings.size == sz) + val sz = srcIds.length + require(dstLocalIndices.length == sz) + require(ratings.length == sz) this.srcIds ++= srcIds this.ratings ++= ratings var j = 0 @@ -629,7 +632,7 @@ private[recommendation] object ALS extends Logging { } /** Builds a [[UncompressedInBlock]]. */ - def build(): UncompressedInBlock = { + def build(): UncompressedInBlock[ID] = { new UncompressedInBlock(srcIds.result(), dstEncodedIndices.result(), ratings.result()) } } @@ -637,24 +640,25 @@ private[recommendation] object ALS extends Logging { /** * A block of (srcId, dstEncodedIndex, rating) tuples stored in primitive arrays. */ - private[recommendation] class UncompressedInBlock( - val srcIds: Array[Int], + private[recommendation] class UncompressedInBlock[@specialized(Int, Long) ID: ClassTag]( + val srcIds: Array[ID], val dstEncodedIndices: Array[Int], - val ratings: Array[Float]) { + val ratings: Array[Float])( + implicit ord: Ordering[ID]) { /** Size the of block. */ - def size: Int = srcIds.size + def length: Int = srcIds.length /** * Compresses the block into an [[InBlock]]. The algorithm is the same as converting a * sparse matrix from coordinate list (COO) format into compressed sparse column (CSC) format. * Sorting is done using Spark's built-in Timsort to avoid generating too many objects. */ - def compress(): InBlock = { - val sz = size + def compress(): InBlock[ID] = { + val sz = length assert(sz > 0, "Empty in-link block should not exist.") sort() - val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[Int] + val uniqueSrcIdsBuilder = mutable.ArrayBuilder.make[ID] val dstCountsBuilder = mutable.ArrayBuilder.make[Int] var preSrcId = srcIds(0) uniqueSrcIdsBuilder += preSrcId @@ -675,7 +679,7 @@ private[recommendation] object ALS extends Logging { } dstCountsBuilder += curCount val uniqueSrcIds = uniqueSrcIdsBuilder.result() - val numUniqueSrdIds = uniqueSrcIds.size + val numUniqueSrdIds = uniqueSrcIds.length val dstCounts = dstCountsBuilder.result() val dstPtrs = new Array[Int](numUniqueSrdIds + 1) var sum = 0 @@ -689,51 +693,61 @@ private[recommendation] object ALS extends Logging { } private def sort(): Unit = { - val sz = size + val sz = length // Since there might be interleaved log messages, we insert a unique id for easy pairing. val sortId = Utils.random.nextInt() logDebug(s"Start sorting an uncompressed in-block of size $sz. (sortId = $sortId)") val start = System.nanoTime() - val sorter = new Sorter(new UncompressedInBlockSort) - sorter.sort(this, 0, size, Ordering[IntWrapper]) + val sorter = new Sorter(new UncompressedInBlockSort[ID]) + sorter.sort(this, 0, length, Ordering[KeyWrapper[ID]]) val duration = (System.nanoTime() - start) / 1e9 logDebug(s"Sorting took $duration seconds. (sortId = $sortId)") } } /** - * A wrapper that holds a primitive integer key. + * A wrapper that holds a primitive key. * * @see [[UncompressedInBlockSort]] */ - private class IntWrapper(var key: Int = 0) extends Ordered[IntWrapper] { - override def compare(that: IntWrapper): Int = { - key.compare(that.key) + private class KeyWrapper[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) extends Ordered[KeyWrapper[ID]] { + + var key: ID = _ + + override def compare(that: KeyWrapper[ID]): Int = { + ord.compare(key, that.key) + } + + def setKey(key: ID): this.type = { + this.key = key + this } } /** * [[SortDataFormat]] of [[UncompressedInBlock]] used by [[Sorter]]. */ - private class UncompressedInBlockSort extends SortDataFormat[IntWrapper, UncompressedInBlock] { + private class UncompressedInBlockSort[@specialized(Int, Long) ID: ClassTag]( + implicit ord: Ordering[ID]) + extends SortDataFormat[KeyWrapper[ID], UncompressedInBlock[ID]] { - override def newKey(): IntWrapper = new IntWrapper() + override def newKey(): KeyWrapper[ID] = new KeyWrapper() override def getKey( - data: UncompressedInBlock, + data: UncompressedInBlock[ID], pos: Int, - reuse: IntWrapper): IntWrapper = { + reuse: KeyWrapper[ID]): KeyWrapper[ID] = { if (reuse == null) { - new IntWrapper(data.srcIds(pos)) + new KeyWrapper().setKey(data.srcIds(pos)) } else { - reuse.key = data.srcIds(pos) - reuse + reuse.setKey(data.srcIds(pos)) } } override def getKey( - data: UncompressedInBlock, - pos: Int): IntWrapper = { + data: UncompressedInBlock[ID], + pos: Int): KeyWrapper[ID] = { getKey(data, pos, null) } @@ -746,16 +760,16 @@ private[recommendation] object ALS extends Logging { data(pos1) = tmp } - override def swap(data: UncompressedInBlock, pos0: Int, pos1: Int): Unit = { + override def swap(data: UncompressedInBlock[ID], pos0: Int, pos1: Int): Unit = { swapElements(data.srcIds, pos0, pos1) swapElements(data.dstEncodedIndices, pos0, pos1) swapElements(data.ratings, pos0, pos1) } override def copyRange( - src: UncompressedInBlock, + src: UncompressedInBlock[ID], srcPos: Int, - dst: UncompressedInBlock, + dst: UncompressedInBlock[ID], dstPos: Int, length: Int): Unit = { System.arraycopy(src.srcIds, srcPos, dst.srcIds, dstPos, length) @@ -763,15 +777,15 @@ private[recommendation] object ALS extends Logging { System.arraycopy(src.ratings, srcPos, dst.ratings, dstPos, length) } - override def allocate(length: Int): UncompressedInBlock = { + override def allocate(length: Int): UncompressedInBlock[ID] = { new UncompressedInBlock( - new Array[Int](length), new Array[Int](length), new Array[Float](length)) + new Array[ID](length), new Array[Int](length), new Array[Float](length)) } override def copyElement( - src: UncompressedInBlock, + src: UncompressedInBlock[ID], srcPos: Int, - dst: UncompressedInBlock, + dst: UncompressedInBlock[ID], dstPos: Int): Unit = { dst.srcIds(dstPos) = src.srcIds(srcPos) dst.dstEncodedIndices(dstPos) = src.dstEncodedIndices(srcPos) @@ -787,19 +801,20 @@ private[recommendation] object ALS extends Logging { * @param dstPart partitioner for dst IDs * @return (in-blocks, out-blocks) */ - private def makeBlocks( + private def makeBlocks[ID: ClassTag]( prefix: String, - ratingBlocks: RDD[((Int, Int), RatingBlock)], + ratingBlocks: RDD[((Int, Int), RatingBlock[ID])], srcPart: Partitioner, - dstPart: Partitioner): (RDD[(Int, InBlock)], RDD[(Int, OutBlock)]) = { + dstPart: Partitioner)( + implicit srcOrd: Ordering[ID]): (RDD[(Int, InBlock[ID])], RDD[(Int, OutBlock)]) = { val inBlocks = ratingBlocks.map { case ((srcBlockId, dstBlockId), RatingBlock(srcIds, dstIds, ratings)) => // The implementation is a faster version of // val dstIdToLocalIndex = dstIds.toSet.toSeq.sorted.zipWithIndex.toMap val start = System.nanoTime() - val dstIdSet = new OpenHashSet[Int](1 << 20) + val dstIdSet = new OpenHashSet[ID](1 << 20) dstIds.foreach(dstIdSet.add) - val sortedDstIds = new Array[Int](dstIdSet.size) + val sortedDstIds = new Array[ID](dstIdSet.size) var i = 0 var pos = dstIdSet.nextPos(0) while (pos != -1) { @@ -808,10 +823,10 @@ private[recommendation] object ALS extends Logging { i += 1 } assert(i == dstIdSet.size) - ju.Arrays.sort(sortedDstIds) - val dstIdToLocalIndex = new OpenHashMap[Int, Int](sortedDstIds.size) + Sorting.quickSort(sortedDstIds) + val dstIdToLocalIndex = new OpenHashMap[ID, Int](sortedDstIds.length) i = 0 - while (i < sortedDstIds.size) { + while (i < sortedDstIds.length) { dstIdToLocalIndex.update(sortedDstIds(i), i) i += 1 } @@ -822,7 +837,7 @@ private[recommendation] object ALS extends Logging { }.groupByKey(new HashPartitioner(srcPart.numPartitions)) .mapValues { iter => val builder = - new UncompressedInBlockBuilder(new LocalIndexEncoder(dstPart.numPartitions)) + new UncompressedInBlockBuilder[ID](new LocalIndexEncoder(dstPart.numPartitions)) iter.foreach { case (dstBlockId, srcIds, dstLocalIndices, ratings) => builder.add(dstBlockId, srcIds, dstLocalIndices, ratings) } @@ -833,7 +848,7 @@ private[recommendation] object ALS extends Logging { val activeIds = Array.fill(dstPart.numPartitions)(mutable.ArrayBuilder.make[Int]) var i = 0 val seen = new Array[Boolean](dstPart.numPartitions) - while (i < srcIds.size) { + while (i < srcIds.length) { var j = dstPtrs(i) ju.Arrays.fill(seen, false) while (j < dstPtrs(i + 1)) { @@ -867,16 +882,16 @@ private[recommendation] object ALS extends Logging { * * @return dst factors */ - private def computeFactors( + private def computeFactors[ID]( srcFactorBlocks: RDD[(Int, FactorBlock)], srcOutBlocks: RDD[(Int, OutBlock)], - dstInBlocks: RDD[(Int, InBlock)], + dstInBlocks: RDD[(Int, InBlock[ID])], rank: Int, regParam: Double, srcEncoder: LocalIndexEncoder, implicitPrefs: Boolean = false, alpha: Double = 1.0): RDD[(Int, FactorBlock)] = { - val numSrcBlocks = srcFactorBlocks.partitions.size + val numSrcBlocks = srcFactorBlocks.partitions.length val YtY = if (implicitPrefs) Some(computeYtY(srcFactorBlocks, rank)) else None val srcOut = srcOutBlocks.join(srcFactorBlocks).flatMap { case (srcBlockId, (srcOutBlock, srcFactors)) => @@ -884,18 +899,18 @@ private[recommendation] object ALS extends Logging { (dstBlockId, (srcBlockId, activeIndices.map(idx => srcFactors(idx)))) } } - val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.size)) + val merged = srcOut.groupByKey(new HashPartitioner(dstInBlocks.partitions.length)) dstInBlocks.join(merged).mapValues { case (InBlock(dstIds, srcPtrs, srcEncodedIndices, ratings), srcFactors) => val sortedSrcFactors = new Array[FactorBlock](numSrcBlocks) srcFactors.foreach { case (srcBlockId, factors) => sortedSrcFactors(srcBlockId) = factors } - val dstFactors = new Array[Array[Float]](dstIds.size) + val dstFactors = new Array[Array[Float]](dstIds.length) var j = 0 val ls = new NormalEquation(rank) val solver = new CholeskySolver // TODO: add NNLS solver - while (j < dstIds.size) { + while (j < dstIds.length) { ls.reset() if (implicitPrefs) { ls.merge(YtY.get) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 08fe99176424a..5d51c51346665 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent import org.apache.spark.ml._ import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType /** @@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP def setEvaluator(value: Evaluator): this.type = set(evaluator, value) def setNumFolds(value: Int): this.type = set(numFolds, value) - override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = { + override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = { val map = this.paramMap ++ paramMap val schema = dataset.schema transformSchema(dataset.schema, paramMap, logging = true) @@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP val epm = map(estimatorParamMaps) val numModels = epm.size val metrics = new Array[Double](epm.size) - val splits = MLUtils.kFold(dataset, map(numFolds), 0) + val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0) splits.zipWithIndex.foreach { case ((training, validation), splitIndex) => val trainingDataset = sqlCtx.applySchema(training, schema).cache() val validationDataset = sqlCtx.applySchema(validation, schema).cache() @@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] ( val bestModel: Model[_]) extends Model[CrossValidatorModel] with CrossValidatorParams { - override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = { + override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = { bestModel.transform(dataset, paramMap) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9c90e774562e7..f973ce1f2b738 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -42,10 +42,11 @@ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.ChiSqTestResult -import org.apache.spark.mllib.tree.{RandomForest, DecisionTree} -import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} +import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree} +import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy} import org.apache.spark.mllib.tree.impurity._ -import org.apache.spark.mllib.tree.model.{RandomForestModel, DecisionTreeModel} +import org.apache.spark.mllib.tree.loss.Losses +import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel, RandomForestModel, DecisionTreeModel} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -586,6 +587,35 @@ class PythonMLLibAPI extends Serializable { } } + /** + * Java stub for Python mllib GradientBoostedTrees.train(). + * This stub returns a handle to the Java object instead of the content of the Java object. + * Extra care needs to be taken in the Python code to ensure it gets freed on exit; + * see the Py4J documentation. + */ + def trainGradientBoostedTreesModel( + data: JavaRDD[LabeledPoint], + algoStr: String, + categoricalFeaturesInfo: JMap[Int, Int], + lossStr: String, + numIterations: Int, + learningRate: Double, + maxDepth: Int): GradientBoostedTreesModel = { + val boostingStrategy = BoostingStrategy.defaultParams(algoStr) + boostingStrategy.setLoss(Losses.fromString(lossStr)) + boostingStrategy.setNumIterations(numIterations) + boostingStrategy.setLearningRate(learningRate) + boostingStrategy.treeStrategy.setMaxDepth(maxDepth) + boostingStrategy.treeStrategy.categoricalFeaturesInfo = categoricalFeaturesInfo.asScala.toMap + + val cached = data.rdd.persist(StorageLevel.MEMORY_AND_DISK) + try { + GradientBoostedTrees.train(cached, boostingStrategy) + } finally { + cached.unpersist(blocking = false) + } + } + /** * Java stub for mllib Statistics.colStats(X: RDD[Vector]). * TODO figure out return type. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala similarity index 99% rename from mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala rename to mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 899fe5e9e9cf2..5c626fde4e657 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureEM.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -44,7 +44,7 @@ import org.apache.spark.util.Utils * is considered to have occurred. * @param maxIterations The maximum number of iterations to perform */ -class GaussianMixtureEM private ( +class GaussianMixture private ( private var k: Int, private var convergenceTol: Double, private var maxIterations: Int, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala new file mode 100644 index 0000000000000..fcb9a3643cc48 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -0,0 +1,206 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import org.apache.spark.{Logging, SparkException} +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.rdd.RDD +import org.apache.spark.util.random.XORShiftRandom + +/** + * Model produced by [[PowerIterationClustering]]. + * + * @param k number of clusters + * @param assignments an RDD of (vertexID, clusterID) pairs + */ +class PowerIterationClusteringModel( + val k: Int, + val assignments: RDD[(Long, Int)]) extends Serializable + +/** + * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and + * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very + * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise + * similarity matrix of the data. + * + * @param k Number of clusters. + * @param maxIterations Maximum number of iterations of the PIC algorithm. + */ +class PowerIterationClustering private[clustering] ( + private var k: Int, + private var maxIterations: Int) extends Serializable { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */ + def this() = this(k = 2, maxIterations = 100) + + /** + * Set the number of clusters. + */ + def setK(k: Int): this.type = { + this.k = k + this + } + + /** + * Set maximum number of iterations of the power iteration loop + */ + def setMaxIterations(maxIterations: Int): this.type = { + this.maxIterations = maxIterations + this + } + + /** + * Run the PIC algorithm. + * + * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is + * the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative. + * This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with + * nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_) + * in the input. Tuples with i = j are ignored, because we assume s_ij_ = 0.0. + * + * @return a [[PowerIterationClusteringModel]] that contains the clustering result + */ + def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = { + val w = normalize(similarities) + val w0 = randomInit(w) + pic(w0) + } + + /** + * Runs the PIC algorithm. + * + * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with + * w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power + * iteration as its vertex properties. + */ + private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = { + val v = powerIter(w, maxIterations) + val assignments = kMeans(v, k) + new PowerIterationClusteringModel(k, assignments) + } +} + +private[clustering] object PowerIterationClustering extends Logging { + /** + * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W). + */ + def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = { + val edges = similarities.flatMap { case (i, j, s) => + if (s < 0.0) { + throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.") + } + if (i != j) { + Seq(Edge(i, j, s), Edge(j, i, s)) + } else { + None + } + } + val gA = Graph.fromEdges(edges, 0.0) + val vD = gA.aggregateMessages[Double]( + sendMsg = ctx => { + ctx.sendToSrc(ctx.attr) + }, + mergeMsg = _ + _, + TripletFields.EdgeOnly) + GraphImpl.fromExistingRDDs(vD, gA.edges) + .mapTriplets( + e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON), + TripletFields.Src) + } + + /** + * Generates random vertex properties (v0) to start power iteration. + * + * @param g a graph representing the normalized affinity matrix (W) + * @return a graph with edges representing W and vertices representing a random vector + * with unit 1-norm + */ + def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = { + val r = g.vertices.mapPartitionsWithIndex( + (part, iter) => { + val random = new XORShiftRandom(part) + iter.map { case (id, _) => + (id, random.nextGaussian()) + } + }, preservesPartitioning = true).cache() + val sum = r.values.map(math.abs).sum() + val v0 = r.mapValues(x => x / sum) + GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges) + } + + /** + * Runs power iteration. + * @param g input graph with edges representing the normalized affinity matrix (W) and vertices + * representing the initial vector of the power iterations. + * @param maxIterations maximum number of iterations + * @return a [[VertexRDD]] representing the pseudo-eigenvector + */ + def powerIter( + g: Graph[Double, Double], + maxIterations: Int): VertexRDD[Double] = { + // the default tolerance used in the PIC paper, with a lower bound 1e-8 + val tol = math.max(1e-5 / g.vertices.count(), 1e-8) + var prevDelta = Double.MaxValue + var diffDelta = Double.MaxValue + var curG = g + for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) { + val msgPrefix = s"Iteration $iter" + // multiply W by vt + val v = curG.aggregateMessages[Double]( + sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr), + mergeMsg = _ + _, + TripletFields.Dst).cache() + // normalize v + val norm = v.values.map(math.abs).sum() + logInfo(s"$msgPrefix: norm(v) = $norm.") + val v1 = v.mapValues(x => x / norm) + // compare difference + val delta = curG.joinVertices(v1) { case (_, x, y) => + math.abs(x - y) + }.vertices.values.sum() + logInfo(s"$msgPrefix: delta = $delta.") + diffDelta = math.abs(delta - prevDelta) + logInfo(s"$msgPrefix: diff(delta) = $diffDelta.") + // update v + curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges) + prevDelta = delta + } + curG.vertices + } + + /** + * Runs k-means clustering. + * @param v a [[VertexRDD]] representing the pseudo-eigenvector + * @param k number of clusters + * @return a [[VertexRDD]] representing the clustering assignments + */ + def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = { + val points = v.mapValues(x => Vectors.dense(x)).cache() + val model = new KMeans() + .setK(k) + .setRuns(5) + .setSeed(0L) + .run(points.values) + points.mapValues(p => model.predict(p)).cache() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala index 3260f27513c7f..a89eea0e21be2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala @@ -22,7 +22,6 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala index 3c2091732f9b0..6ae6917eae595 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/StandardScaler.scala @@ -18,15 +18,14 @@ package org.apache.spark.mllib.feature import org.apache.spark.Logging -import org.apache.spark.annotation.Experimental +import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.rdd.RDD /** * :: Experimental :: - * Standardizes features by removing the mean and scaling to unit variance using column summary + * Standardizes features by removing the mean and scaling to unit std using column summary * statistics on the samples in the training set. * * @param withMean False by default. Centers the data with mean before scaling. It will build a @@ -53,7 +52,11 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { val summary = data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) - new StandardScalerModel(withMean, withStd, summary.mean, summary.variance) + new StandardScalerModel( + Vectors.dense(summary.variance.toArray.map(v => math.sqrt(v))), + summary.mean, + withStd, + withMean) } } @@ -61,28 +64,43 @@ class StandardScaler(withMean: Boolean, withStd: Boolean) extends Logging { * :: Experimental :: * Represents a StandardScaler model that can transform vectors. * - * @param withMean whether to center the data before scaling - * @param withStd whether to scale the data to have unit standard deviation + * @param std column standard deviation values * @param mean column mean values - * @param variance column variance values + * @param withStd whether to scale the data to have unit standard deviation + * @param withMean whether to center the data before scaling */ @Experimental -class StandardScalerModel private[mllib] ( - val withMean: Boolean, - val withStd: Boolean, +class StandardScalerModel ( + val std: Vector, val mean: Vector, - val variance: Vector) extends VectorTransformer { - - require(mean.size == variance.size) + var withStd: Boolean, + var withMean: Boolean) extends VectorTransformer { - private lazy val factor: Array[Double] = { - val f = Array.ofDim[Double](variance.size) - var i = 0 - while (i < f.size) { - f(i) = if (variance(i) != 0.0) 1.0 / math.sqrt(variance(i)) else 0.0 - i += 1 + def this(std: Vector, mean: Vector) { + this(std, mean, withStd = std != null, withMean = mean != null) + require(this.withStd || this.withMean, + "at least one of std or mean vectors must be provided") + if (this.withStd && this.withMean) { + require(mean.size == std.size, + "mean and std vectors must have equal size if both are provided") } - f + } + + def this(std: Vector) = this(std, null) + + @DeveloperApi + def setWithMean(withMean: Boolean): this.type = { + require(!(withMean && this.mean == null),"cannot set withMean to true while mean is null") + this.withMean = withMean + this + } + + @DeveloperApi + def setWithStd(withStd: Boolean): this.type = { + require(!(withStd && this.std == null), + "cannot set withStd to true while std is null") + this.withStd = withStd + this } // Since `shift` will be only used in `withMean` branch, we have it as @@ -94,8 +112,8 @@ class StandardScalerModel private[mllib] ( * Applies standardization transformation on a vector. * * @param vector Vector to be standardized. - * @return Standardized vector. If the variance of a column is zero, it will return default `0.0` - * for the column with zero variance. + * @return Standardized vector. If the std of a column is zero, it will return default `0.0` + * for the column with zero std. */ override def transform(vector: Vector): Vector = { require(mean.size == vector.size) @@ -109,11 +127,9 @@ class StandardScalerModel private[mllib] ( val values = vs.clone() val size = values.size if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. - val localFactor = factor var i = 0 while (i < size) { - values(i) = (values(i) - localShift(i)) * localFactor(i) + values(i) = if (std(i) != 0.0) (values(i) - localShift(i)) * (1.0 / std(i)) else 0.0 i += 1 } } else { @@ -127,15 +143,13 @@ class StandardScalerModel private[mllib] ( case v => throw new IllegalArgumentException("Do not support vector type " + v.getClass) } } else if (withStd) { - // Having a local reference of `factor` to avoid overhead as the comment before. - val localFactor = factor vector match { case DenseVector(vs) => val values = vs.clone() val size = values.size var i = 0 while(i < size) { - values(i) *= localFactor(i) + values(i) *= (if (std(i) != 0.0) 1.0 / std(i) else 0.0) i += 1 } Vectors.dense(values) @@ -146,7 +160,7 @@ class StandardScalerModel private[mllib] ( val nnz = values.size var i = 0 while (i < nnz) { - values(i) *= localFactor(indices(i)) + values(i) *= (if (std(indices(i)) != 0.0) 1.0 / std(indices(i)) else 0.0) i += 1 } Vectors.sparse(size, indices, values) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d25a7cd5b439d..a3e40200bc063 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) + + if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + } + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala index 3414daccd7ca4..079f7ca564a92 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala @@ -257,80 +257,55 @@ private[spark] object BLAS extends Serializable with Logging { /** * C := alpha * A * B + beta * C - * @param transA whether to use the transpose of matrix A (true), or A itself (false). - * @param transB whether to use the transpose of matrix B (true), or B itself (false). * @param alpha a scalar to scale the multiplication A * B. * @param A the matrix A that will be left multiplied to B. Size of m x k. * @param B the matrix B that will be left multiplied by A. Size of k x n. * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. + * @param C the resulting matrix C. Size of m x n. C.isTransposed must be false. */ def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: Matrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { + require(!C.isTransposed, + "The matrix C cannot be the product of a transpose() call. C.isTransposed must be false.") if (alpha == 0.0) { logDebug("gemm: alpha is equal to 0. Returning C.") } else { A match { - case sparse: SparseMatrix => - gemm(transA, transB, alpha, sparse, B, beta, C) - case dense: DenseMatrix => - gemm(transA, transB, alpha, dense, B, beta, C) + case sparse: SparseMatrix => gemm(alpha, sparse, B, beta, C) + case dense: DenseMatrix => gemm(alpha, dense, B, beta, C) case _ => throw new IllegalArgumentException(s"gemm doesn't support matrix type ${A.getClass}.") } } } - /** - * C := alpha * A * B + beta * C - * - * @param alpha a scalar to scale the multiplication A * B. - * @param A the matrix A that will be left multiplied to B. Size of m x k. - * @param B the matrix B that will be left multiplied by A. Size of k x n. - * @param beta a scalar that can be used to scale matrix C. - * @param C the resulting matrix C. Size of m x n. - */ - def gemm( - alpha: Double, - A: Matrix, - B: DenseMatrix, - beta: Double, - C: DenseMatrix): Unit = { - gemm(false, false, alpha, A, B, beta, C) - } - /** * C := alpha * A * B + beta * C * For `DenseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: DenseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols - val tAstr = if (!transA) "N" else "T" - val tBstr = if (!transB) "N" else "T" - - require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") - require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") - require(nB == C.numCols, - s"The columns of C don't match the columns of B. C: ${C.numCols}, A: $nB") - - nativeBLAS.dgemm(tAstr, tBstr, mA, nB, kA, alpha, A.values, A.numRows, B.values, B.numRows, - beta, C.values, C.numRows) + val tAstr = if (A.isTransposed) "T" else "N" + val tBstr = if (B.isTransposed) "T" else "N" + val lda = if (!A.isTransposed) A.numRows else A.numCols + val ldb = if (!B.isTransposed) B.numRows else B.numCols + + require(A.numCols == B.numRows, + s"The columns of A don't match the rows of B. A: ${A.numCols}, B: ${B.numRows}") + require(A.numRows == C.numRows, + s"The rows of C don't match the rows of A. C: ${C.numRows}, A: ${A.numRows}") + require(B.numCols == C.numCols, + s"The columns of C don't match the columns of B. C: ${C.numCols}, A: ${B.numCols}") + nativeBLAS.dgemm(tAstr, tBstr, A.numRows, B.numCols, A.numCols, alpha, A.values, lda, + B.values, ldb, beta, C.values, C.numRows) } /** @@ -338,17 +313,15 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemm( - transA: Boolean, - transB: Boolean, alpha: Double, A: SparseMatrix, B: DenseMatrix, beta: Double, C: DenseMatrix): Unit = { - val mA: Int = if (!transA) A.numRows else A.numCols - val nB: Int = if (!transB) B.numCols else B.numRows - val kA: Int = if (!transA) A.numCols else A.numRows - val kB: Int = if (!transB) B.numRows else B.numCols + val mA: Int = A.numRows + val nB: Int = B.numCols + val kA: Int = A.numCols + val kB: Int = B.numRows require(kA == kB, s"The columns of A don't match the rows of B. A: $kA, B: $kB") require(mA == C.numRows, s"The rows of C don't match the rows of A. C: ${C.numRows}, A: $mA") @@ -358,23 +331,23 @@ private[spark] object BLAS extends Serializable with Logging { val Avals = A.values val Bvals = B.values val Cvals = C.values - val Arows = if (!transA) A.rowIndices else A.colPtrs - val Acols = if (!transA) A.colPtrs else A.rowIndices + val ArowIndices = A.rowIndices + val AcolPtrs = A.colPtrs // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (transA){ + if (A.isTransposed){ var colCounterForB = 0 - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var rowCounterForA = 0 val Cstart = colCounterForB * mA val Bstart = colCounterForB * kA while (rowCounterForA < mA) { - var i = Arows(rowCounterForA) - val indEnd = Arows(rowCounterForA + 1) + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * Bvals(Bstart + Acols(i)) + sum += Avals(i) * Bvals(Bstart + ArowIndices(i)) i += 1 } val Cindex = Cstart + rowCounterForA @@ -385,19 +358,19 @@ private[spark] object BLAS extends Serializable with Logging { } } else { while (colCounterForB < nB) { - var rowCounter = 0 + var rowCounterForA = 0 val Cstart = colCounterForB * mA - while (rowCounter < mA) { - var i = Arows(rowCounter) - val indEnd = Arows(rowCounter + 1) + while (rowCounterForA < mA) { + var i = AcolPtrs(rowCounterForA) + val indEnd = AcolPtrs(rowCounterForA + 1) var sum = 0.0 while (i < indEnd) { - sum += Avals(i) * B(colCounterForB, Acols(i)) + sum += Avals(i) * B(ArowIndices(i), colCounterForB) i += 1 } - val Cindex = Cstart + rowCounter + val Cindex = Cstart + rowCounterForA Cvals(Cindex) = beta * Cvals(Cindex) + sum * alpha - rowCounter += 1 + rowCounterForA += 1 } colCounterForB += 1 } @@ -410,17 +383,17 @@ private[spark] object BLAS extends Serializable with Logging { // Perform matrix multiplication and add to C. The rows of A are multiplied by the columns of // B, and added to C. var colCounterForB = 0 // the column to be updated in C - if (!transB) { // Expensive to put the check inside the loop + if (!B.isTransposed) { // Expensive to put the check inside the loop while (colCounterForB < nB) { var colCounterForA = 0 // The column of A to multiply with the row of B val Bstart = colCounterForB * kB val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) val Bval = Bvals(Bstart + colCounterForA) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -432,11 +405,11 @@ private[spark] object BLAS extends Serializable with Logging { var colCounterForA = 0 // The column of A to multiply with the row of B val Cstart = colCounterForB * mA while (colCounterForA < kA) { - var i = Acols(colCounterForA) - val indEnd = Acols(colCounterForA + 1) - val Bval = B(colCounterForB, colCounterForA) * alpha + var i = AcolPtrs(colCounterForA) + val indEnd = AcolPtrs(colCounterForA + 1) + val Bval = B(colCounterForA, colCounterForB) * alpha while (i < indEnd) { - Cvals(Cstart + Arows(i)) += Avals(i) * Bval + Cvals(Cstart + ArowIndices(i)) += Avals(i) * Bval i += 1 } colCounterForA += 1 @@ -449,7 +422,6 @@ private[spark] object BLAS extends Serializable with Logging { /** * y := alpha * A * x + beta * y - * @param trans whether to use the transpose of matrix A (true), or A itself (false). * @param alpha a scalar to scale the multiplication A * x. * @param A the matrix A that will be left multiplied to x. Size of m x n. * @param x the vector x that will be left multiplied by A. Size of n x 1. @@ -457,65 +429,43 @@ private[spark] object BLAS extends Serializable with Logging { * @param y the resulting vector y. Size of m x 1. */ def gemv( - trans: Boolean, alpha: Double, A: Matrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - - val mA: Int = if (!trans) A.numRows else A.numCols - val nx: Int = x.size - val nA: Int = if (!trans) A.numCols else A.numRows - - require(nA == nx, s"The columns of A don't match the number of elements of x. A: $nA, x: $nx") - require(mA == y.size, - s"The rows of A don't match the number of elements of y. A: $mA, y:${y.size}}") + require(A.numCols == x.size, + s"The columns of A don't match the number of elements of x. A: ${A.numCols}, x: ${x.size}") + require(A.numRows == y.size, + s"The rows of A don't match the number of elements of y. A: ${A.numRows}, y:${y.size}}") if (alpha == 0.0) { logDebug("gemv: alpha is equal to 0. Returning y.") } else { A match { case sparse: SparseMatrix => - gemv(trans, alpha, sparse, x, beta, y) + gemv(alpha, sparse, x, beta, y) case dense: DenseMatrix => - gemv(trans, alpha, dense, x, beta, y) + gemv(alpha, dense, x, beta, y) case _ => throw new IllegalArgumentException(s"gemv doesn't support matrix type ${A.getClass}.") } } } - /** - * y := alpha * A * x + beta * y - * - * @param alpha a scalar to scale the multiplication A * x. - * @param A the matrix A that will be left multiplied to x. Size of m x n. - * @param x the vector x that will be left multiplied by A. Size of n x 1. - * @param beta a scalar that can be used to scale vector y. - * @param y the resulting vector y. Size of m x 1. - */ - def gemv( - alpha: Double, - A: Matrix, - x: DenseVector, - beta: Double, - y: DenseVector): Unit = { - gemv(false, alpha, A, x, beta, y) - } - /** * y := alpha * A * x + beta * y * For `DenseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: DenseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val tStrA = if (!trans) "N" else "T" - nativeBLAS.dgemv(tStrA, A.numRows, A.numCols, alpha, A.values, A.numRows, x.values, 1, beta, + val tStrA = if (A.isTransposed) "T" else "N" + val mA = if (!A.isTransposed) A.numRows else A.numCols + val nA = if (!A.isTransposed) A.numCols else A.numRows + nativeBLAS.dgemv(tStrA, mA, nA, alpha, A.values, mA, x.values, 1, beta, y.values, 1) } @@ -524,24 +474,21 @@ private[spark] object BLAS extends Serializable with Logging { * For `SparseMatrix` A. */ private def gemv( - trans: Boolean, alpha: Double, A: SparseMatrix, x: DenseVector, beta: Double, y: DenseVector): Unit = { - val xValues = x.values val yValues = y.values - - val mA: Int = if (!trans) A.numRows else A.numCols - val nA: Int = if (!trans) A.numCols else A.numRows + val mA: Int = A.numRows + val nA: Int = A.numCols val Avals = A.values - val Arows = if (!trans) A.rowIndices else A.colPtrs - val Acols = if (!trans) A.colPtrs else A.rowIndices + val Arows = if (!A.isTransposed) A.rowIndices else A.colPtrs + val Acols = if (!A.isTransposed) A.colPtrs else A.rowIndices // Slicing is easy in this case. This is the optimal multiplication setting for sparse matrices - if (trans) { + if (A.isTransposed) { var rowCounter = 0 while (rowCounter < mA) { var i = Arows(rowCounter) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 5a7281ec6dc3c..ad7e86827b368 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -34,8 +34,17 @@ sealed trait Matrix extends Serializable { /** Number of columns. */ def numCols: Int + /** Flag that keeps track whether the matrix is transposed or not. False by default. */ + val isTransposed: Boolean = false + /** Converts to a dense array in column major. */ - def toArray: Array[Double] + def toArray: Array[Double] = { + val newArray = new Array[Double](numRows * numCols) + foreachActive { (i, j, v) => + newArray(j * numRows + i) = v + } + newArray + } /** Converts to a breeze matrix. */ private[mllib] def toBreeze: BM[Double] @@ -52,10 +61,13 @@ sealed trait Matrix extends Serializable { /** Get a deep copy of the matrix. */ def copy: Matrix + /** Transpose the Matrix. Returns a new `Matrix` instance sharing the same underlying data. */ + def transpose: Matrix + /** Convenience method for `Matrix`-`DenseMatrix` multiplication. */ def multiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numRows, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(false, false, 1.0, this, y, 0.0, C) + val C: DenseMatrix = DenseMatrix.zeros(numRows, y.numCols) + BLAS.gemm(1.0, this, y, 0.0, C) C } @@ -66,20 +78,6 @@ sealed trait Matrix extends Serializable { output } - /** Convenience method for `Matrix`^T^-`DenseMatrix` multiplication. */ - private[mllib] def transposeMultiply(y: DenseMatrix): DenseMatrix = { - val C: DenseMatrix = Matrices.zeros(numCols, y.numCols).asInstanceOf[DenseMatrix] - BLAS.gemm(true, false, 1.0, this, y, 0.0, C) - C - } - - /** Convenience method for `Matrix`^T^-`DenseVector` multiplication. */ - private[mllib] def transposeMultiply(y: DenseVector): DenseVector = { - val output = new DenseVector(new Array[Double](numCols)) - BLAS.gemv(true, 1.0, this, y, 0.0, output) - output - } - /** A human readable representation of the matrix */ override def toString: String = toBreeze.toString() @@ -92,6 +90,16 @@ sealed trait Matrix extends Serializable { * backing array. For example, an operation such as addition or subtraction will only be * performed on the non-zero values in a `SparseMatrix`. */ private[mllib] def update(f: Double => Double): Matrix + + /** + * Applies a function `f` to all the active elements of dense and sparse matrix. The ordering + * of the elements are not defined. + * + * @param f the function takes three parameters where the first two parameters are the row + * and column indices respectively with the type `Int`, and the final parameter is the + * corresponding value in the matrix with type `Double`. + */ + private[spark] def foreachActive(f: (Int, Int, Double) => Unit) } /** @@ -108,13 +116,35 @@ sealed trait Matrix extends Serializable { * @param numRows number of rows * @param numCols number of columns * @param values matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in + * row major. */ -class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) extends Matrix { +class DenseMatrix( + val numRows: Int, + val numCols: Int, + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == numRows * numCols, "The number of values supplied doesn't match the " + s"size of the matrix! values.length: ${values.length}, numRows * numCols: ${numRows * numCols}") - override def toArray: Array[Double] = values + /** + * Column-major dense matrix. + * The entry values are stored in a single array of doubles with columns listed in sequence. + * For example, the following matrix + * {{{ + * 1.0 2.0 + * 3.0 4.0 + * 5.0 6.0 + * }}} + * is stored as `[1.0, 3.0, 5.0, 2.0, 4.0, 6.0]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param values matrix entries in column major + */ + def this(numRows: Int, numCols: Int, values: Array[Double]) = + this(numRows, numCols, values, false) override def equals(o: Any) = o match { case m: DenseMatrix => @@ -122,13 +152,22 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) case _ => false } - private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values) + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BDM[Double](numRows, numCols, values) + } else { + val breezeMatrix = new BDM[Double](numCols, numRows, values) + breezeMatrix.t + } + } private[mllib] def apply(i: Int): Double = values(i) private[mllib] def apply(i: Int, j: Int): Double = values(index(i, j)) - private[mllib] def index(i: Int, j: Int): Int = i + numRows * j + private[mllib] def index(i: Int, j: Int): Int = { + if (!isTransposed) i + numRows * j else j + numCols * i + } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { values(index(i, j)) = v @@ -148,7 +187,38 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) this } - /** Generate a `SparseMatrix` from the given `DenseMatrix`. */ + override def transpose: Matrix = new DenseMatrix(numCols, numRows, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + // outer loop over columns + var j = 0 + while (j < numCols) { + var i = 0 + val indStart = j * numRows + while (i < numRows) { + f(i, j, values(indStart + i)) + i += 1 + } + j += 1 + } + } else { + // outer loop over rows + var i = 0 + while (i < numRows) { + var j = 0 + val indStart = i * numCols + while (j < numCols) { + f(i, j, values(indStart + j)) + j += 1 + } + i += 1 + } + } + } + + /** Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed + * set to false. */ def toSparse(): SparseMatrix = { val spVals: MArrayBuilder[Double] = new MArrayBuilder.ofDouble val colPtrs: Array[Int] = new Array[Int](numCols + 1) @@ -157,9 +227,8 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double]) var j = 0 while (j < numCols) { var i = 0 - val indStart = j * numRows while (i < numRows) { - val v = values(indStart + i) + val v = values(index(i, j)) if (v != 0.0) { rowIndices += i spVals += v @@ -271,49 +340,73 @@ object DenseMatrix { * @param rowIndices the row index of the entry. They must be in strictly increasing order for each * column * @param values non-zero matrix entries in column major + * @param isTransposed whether the matrix is transposed. If true, the matrix can be considered + * Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs, + * and `rowIndices` behave as colIndices, and `values` are stored in row major. */ class SparseMatrix( val numRows: Int, val numCols: Int, val colPtrs: Array[Int], val rowIndices: Array[Int], - val values: Array[Double]) extends Matrix { + val values: Array[Double], + override val isTransposed: Boolean) extends Matrix { require(values.length == rowIndices.length, "The number of row indices and values don't match! " + s"values.length: ${values.length}, rowIndices.length: ${rowIndices.length}") - require(colPtrs.length == numCols + 1, "The length of the column indices should be the " + - s"number of columns + 1. Currently, colPointers.length: ${colPtrs.length}, " + - s"numCols: $numCols") + // The Or statement is for the case when the matrix is transposed + require(colPtrs.length == numCols + 1 || colPtrs.length == numRows + 1, "The length of the " + + "column indices should be the number of columns + 1. Currently, colPointers.length: " + + s"${colPtrs.length}, numCols: $numCols") require(values.length == colPtrs.last, "The last value of colPtrs must equal the number of " + s"elements. values.length: ${values.length}, colPtrs.last: ${colPtrs.last}") - override def toArray: Array[Double] = { - val arr = new Array[Double](numRows * numCols) - var j = 0 - while (j < numCols) { - var i = colPtrs(j) - val indEnd = colPtrs(j + 1) - val offset = j * numRows - while (i < indEnd) { - val rowIndex = rowIndices(i) - arr(offset + rowIndex) = values(i) - i += 1 - } - j += 1 - } - arr + /** + * Column-major sparse matrix. + * The entry values are stored in Compressed Sparse Column (CSC) format. + * For example, the following matrix + * {{{ + * 1.0 0.0 4.0 + * 0.0 3.0 5.0 + * 2.0 0.0 6.0 + * }}} + * is stored as `values: [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]`, + * `rowIndices=[0, 2, 1, 0, 1, 2]`, `colPointers=[0, 2, 3, 6]`. + * + * @param numRows number of rows + * @param numCols number of columns + * @param colPtrs the index corresponding to the start of a new column + * @param rowIndices the row index of the entry. They must be in strictly increasing + * order for each column + * @param values non-zero matrix entries in column major + */ + def this( + numRows: Int, + numCols: Int, + colPtrs: Array[Int], + rowIndices: Array[Int], + values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false) + + private[mllib] def toBreeze: BM[Double] = { + if (!isTransposed) { + new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) + } else { + val breezeMatrix = new BSM[Double](values, numCols, numRows, colPtrs, rowIndices) + breezeMatrix.t + } } - private[mllib] def toBreeze: BM[Double] = - new BSM[Double](values, numRows, numCols, colPtrs, rowIndices) - private[mllib] def apply(i: Int, j: Int): Double = { val ind = index(i, j) if (ind < 0) 0.0 else values(ind) } private[mllib] def index(i: Int, j: Int): Int = { - Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + if (!isTransposed) { + Arrays.binarySearch(rowIndices, colPtrs(j), colPtrs(j + 1), i) + } else { + Arrays.binarySearch(rowIndices, colPtrs(i), colPtrs(i + 1), j) + } } private[mllib] def update(i: Int, j: Int, v: Double): Unit = { @@ -322,7 +415,7 @@ class SparseMatrix( throw new NoSuchElementException("The given row and column indices correspond to a zero " + "value. Only non-zero elements in Sparse Matrices can be updated.") } else { - values(index(i, j)) = v + values(ind) = v } } @@ -341,7 +434,38 @@ class SparseMatrix( this } - /** Generate a `DenseMatrix` from the given `SparseMatrix`. */ + override def transpose: Matrix = + new SparseMatrix(numCols, numRows, colPtrs, rowIndices, values, !isTransposed) + + private[spark] override def foreachActive(f: (Int, Int, Double) => Unit): Unit = { + if (!isTransposed) { + var j = 0 + while (j < numCols) { + var idx = colPtrs(j) + val idxEnd = colPtrs(j + 1) + while (idx < idxEnd) { + f(rowIndices(idx), j, values(idx)) + idx += 1 + } + j += 1 + } + } else { + var i = 0 + while (i < numRows) { + var idx = colPtrs(i) + val idxEnd = colPtrs(i + 1) + while (idx < idxEnd) { + val j = rowIndices(idx) + f(i, j, values(idx)) + idx += 1 + } + i += 1 + } + } + } + + /** Generate a `DenseMatrix` from the given `SparseMatrix`. The new matrix will have isTransposed + * set to false. */ def toDense(): DenseMatrix = { new DenseMatrix(numRows, numCols, toArray) } @@ -557,10 +681,9 @@ object Matrices { private[mllib] def fromBreeze(breeze: BM[Double]): Matrix = { breeze match { case dm: BDM[Double] => - require(dm.majorStride == dm.rows, - "Do not support stride size different from the number of rows.") - new DenseMatrix(dm.rows, dm.cols, dm.data) + new DenseMatrix(dm.rows, dm.cols, dm.data, dm.isTranspose) case sm: BSM[Double] => + // There is no isTranspose flag for sparse matrices in Breeze new SparseMatrix(sm.rows, sm.cols, sm.colPtrs, sm.rowIndices, sm.data) case _ => throw new UnsupportedOperationException( @@ -679,46 +802,28 @@ object Matrices { new DenseMatrix(numRows, numCols, matrices.flatMap(_.toArray)) } else { var startCol = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - val nCols = spMat.numCols - while (j < nCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i, j + startCol, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nCols = mat.numCols + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i, j + startCol, v) + cnt += 1 } - j += 1 - } - startCol += nCols - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startCol += nCols + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i, j + startCol, v)) } - i += 1 } - j += 1 - } - startCol += nCols - data + startCol += nCols + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } @@ -744,14 +849,12 @@ object Matrices { require(numCols == mat.numCols, "The number of rows of the matrices in this sequence, " + "don't match!") mat match { - case sparse: SparseMatrix => - hasSparse = true - case dense: DenseMatrix => + case sparse: SparseMatrix => hasSparse = true + case dense: DenseMatrix => // empty on purpose case _ => throw new IllegalArgumentException("Unsupported matrix format. Expected " + s"SparseMatrix or DenseMatrix. Instead got: ${mat.getClass}") } numRows += mat.numRows - } if (!hasSparse) { val allValues = new Array[Double](numRows * numCols) @@ -759,61 +862,37 @@ object Matrices { matrices.foreach { mat => var j = 0 val nRows = mat.numRows - val values = mat.toArray - while (j < numCols) { - var i = 0 + mat.foreachActive { (i, j, v) => val indStart = j * numRows + startRow - val subMatStart = j * nRows - while (i < nRows) { - allValues(indStart + i) = values(subMatStart + i) - i += 1 - } - j += 1 + allValues(indStart + i) = v } startRow += nRows } new DenseMatrix(numRows, numCols, allValues) } else { var startRow = 0 - val entries: Array[(Int, Int, Double)] = matrices.flatMap { - case spMat: SparseMatrix => - var j = 0 - val colPtrs = spMat.colPtrs - val rowIndices = spMat.rowIndices - val values = spMat.values - val data = new Array[(Int, Int, Double)](values.length) - while (j < numCols) { - var idx = colPtrs(j) - while (idx < colPtrs(j + 1)) { - val i = rowIndices(idx) - val v = values(idx) - data(idx) = (i + startRow, j, v) - idx += 1 + val entries: Array[(Int, Int, Double)] = matrices.flatMap { mat => + val nRows = mat.numRows + mat match { + case spMat: SparseMatrix => + val data = new Array[(Int, Int, Double)](spMat.values.length) + var cnt = 0 + spMat.foreachActive { (i, j, v) => + data(cnt) = (i + startRow, j, v) + cnt += 1 } - j += 1 - } - startRow += spMat.numRows - data - case dnMat: DenseMatrix => - val data = new ArrayBuffer[(Int, Int, Double)]() - var j = 0 - val nCols = dnMat.numCols - val nRows = dnMat.numRows - val values = dnMat.values - while (j < nCols) { - var i = 0 - val indStart = j * nRows - while (i < nRows) { - val v = values(indStart + i) + startRow += nRows + data + case dnMat: DenseMatrix => + val data = new ArrayBuffer[(Int, Int, Double)]() + dnMat.foreachActive { (i, j, v) => if (v != 0.0) { data.append((i + startRow, j, v)) } - i += 1 } - j += 1 - } - startRow += nRows - data + startRow += nRows + data + } } SparseMatrix.fromCOO(numRows, numCols, entries) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index b3022add38469..8f75e6f46e05d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.apache.spark.SparkException import org.apache.spark.mllib.util.NumericParser -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row} +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions.GenericMutableRow import org.apache.spark.sql.types._ /** @@ -77,7 +78,7 @@ sealed trait Vector extends Serializable { result = 31 * result + (bits ^ (bits >>> 32)).toInt } } - return result + result } /** @@ -110,7 +111,7 @@ sealed trait Vector extends Serializable { /** * User-defined type for [[Vector]] which allows easy interaction with SQL - * via [[org.apache.spark.sql.SchemaRDD]]. + * via [[org.apache.spark.sql.DataFrame]]. */ private[spark] class VectorUDT extends UserDefinedType[Vector] { @@ -371,18 +372,23 @@ object Vectors { squaredDistance += score * score } - case (v1: SparseVector, v2: DenseVector) if v1.indices.length / v1.size < 0.5 => + case (v1: SparseVector, v2: DenseVector) => squaredDistance = sqdist(v1, v2) - case (v1: DenseVector, v2: SparseVector) if v2.indices.length / v2.size < 0.5 => + case (v1: DenseVector, v2: SparseVector) => squaredDistance = sqdist(v2, v1) - // When a SparseVector is approximately dense, we treat it as a DenseVector - case (v1, v2) => - squaredDistance = v1.toArray.zip(v2.toArray).foldLeft(0.0){ (distance, elems) => - val score = elems._1 - elems._2 - distance + score * score + case (DenseVector(vv1), DenseVector(vv2)) => + var kv = 0 + val sz = vv1.size + while (kv < sz) { + val score = vv1(kv) - vv2(kv) + squaredDistance += score * score + kv += 1 } + case _ => + throw new IllegalArgumentException("Do not support vector type " + v1.getClass + + " and " + v2.getClass) } squaredDistance } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala new file mode 100644 index 0000000000000..3871152d065a7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrix.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import scala.collection.mutable.ArrayBuffer + +import breeze.linalg.{DenseMatrix => BDM} + +import org.apache.spark.{SparkException, Logging, Partitioner} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * A grid partitioner, which uses a regular grid to partition coordinates. + * + * @param rows Number of rows. + * @param cols Number of columns. + * @param rowsPerPart Number of rows per partition, which may be less at the bottom edge. + * @param colsPerPart Number of columns per partition, which may be less at the right edge. + */ +private[mllib] class GridPartitioner( + val rows: Int, + val cols: Int, + val rowsPerPart: Int, + val colsPerPart: Int) extends Partitioner { + + require(rows > 0) + require(cols > 0) + require(rowsPerPart > 0) + require(colsPerPart > 0) + + private val rowPartitions = math.ceil(rows * 1.0 / rowsPerPart).toInt + private val colPartitions = math.ceil(cols * 1.0 / colsPerPart).toInt + + override val numPartitions = rowPartitions * colPartitions + + /** + * Returns the index of the partition the input coordinate belongs to. + * + * @param key The coordinate (i, j) or a tuple (i, j, k), where k is the inner index used in + * multiplication. k is ignored in computing partitions. + * @return The index of the partition, which the coordinate belongs to. + */ + override def getPartition(key: Any): Int = { + key match { + case (i: Int, j: Int) => + getPartitionId(i, j) + case (i: Int, j: Int, _: Int) => + getPartitionId(i, j) + case _ => + throw new IllegalArgumentException(s"Unrecognized key: $key.") + } + } + + /** Partitions sub-matrices as blocks with neighboring sub-matrices. */ + private def getPartitionId(i: Int, j: Int): Int = { + require(0 <= i && i < rows, s"Row index $i out of range [0, $rows).") + require(0 <= j && j < cols, s"Column index $j out of range [0, $cols).") + i / rowsPerPart + j / colsPerPart * rowPartitions + } + + override def equals(obj: Any): Boolean = { + obj match { + case r: GridPartitioner => + (this.rows == r.rows) && (this.cols == r.cols) && + (this.rowsPerPart == r.rowsPerPart) && (this.colsPerPart == r.colsPerPart) + case _ => + false + } + } +} + +private[mllib] object GridPartitioner { + + /** Creates a new [[GridPartitioner]] instance. */ + def apply(rows: Int, cols: Int, rowsPerPart: Int, colsPerPart: Int): GridPartitioner = { + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } + + /** Creates a new [[GridPartitioner]] instance with the input suggested number of partitions. */ + def apply(rows: Int, cols: Int, suggestedNumPartitions: Int): GridPartitioner = { + require(suggestedNumPartitions > 0) + val scale = 1.0 / math.sqrt(suggestedNumPartitions) + val rowsPerPart = math.round(math.max(scale * rows, 1.0)).toInt + val colsPerPart = math.round(math.max(scale * cols, 1.0)).toInt + new GridPartitioner(rows, cols, rowsPerPart, colsPerPart) + } +} + +/** + * Represents a distributed matrix in blocks of local matrices. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + * @param nRows Number of rows of this matrix. If the supplied value is less than or equal to zero, + * the number of rows will be calculated when `numRows` is invoked. + * @param nCols Number of columns of this matrix. If the supplied value is less than or equal to + * zero, the number of columns will be calculated when `numCols` is invoked. + */ +class BlockMatrix( + val blocks: RDD[((Int, Int), Matrix)], + val rowsPerBlock: Int, + val colsPerBlock: Int, + private var nRows: Long, + private var nCols: Long) extends DistributedMatrix with Logging { + + private type MatrixBlock = ((Int, Int), Matrix) // ((blockRowIndex, blockColIndex), sub-matrix) + + /** + * Alternate constructor for BlockMatrix without the input of the number of rows and columns. + * + * @param blocks The RDD of sub-matrix blocks ((blockRowIndex, blockColIndex), sub-matrix) that + * form this distributed matrix. If multiple blocks with the same index exist, the + * results for operations like add and multiply will be unpredictable. + * @param rowsPerBlock Number of rows that make up each block. The blocks forming the final + * rows are not required to have the given number of rows + * @param colsPerBlock Number of columns that make up each block. The blocks forming the final + * columns are not required to have the given number of columns + */ + def this( + blocks: RDD[((Int, Int), Matrix)], + rowsPerBlock: Int, + colsPerBlock: Int) = { + this(blocks, rowsPerBlock, colsPerBlock, 0L, 0L) + } + + override def numRows(): Long = { + if (nRows <= 0L) estimateDim() + nRows + } + + override def numCols(): Long = { + if (nCols <= 0L) estimateDim() + nCols + } + + val numRowBlocks = math.ceil(numRows() * 1.0 / rowsPerBlock).toInt + val numColBlocks = math.ceil(numCols() * 1.0 / colsPerBlock).toInt + + private[mllib] def createPartitioner(): GridPartitioner = + GridPartitioner(numRowBlocks, numColBlocks, suggestedNumPartitions = blocks.partitions.size) + + private lazy val blockInfo = blocks.mapValues(block => (block.numRows, block.numCols)).cache() + + /** Estimates the dimensions of the matrix. */ + private def estimateDim(): Unit = { + val (rows, cols) = blockInfo.map { case ((blockRowIndex, blockColIndex), (m, n)) => + (blockRowIndex.toLong * rowsPerBlock + m, + blockColIndex.toLong * colsPerBlock + n) + }.reduce { (x0, x1) => + (math.max(x0._1, x1._1), math.max(x0._2, x1._2)) + } + if (nRows <= 0L) nRows = rows + assert(rows <= nRows, s"The number of rows $rows is more than claimed $nRows.") + if (nCols <= 0L) nCols = cols + assert(cols <= nCols, s"The number of columns $cols is more than claimed $nCols.") + } + + def validate(): Unit = { + logDebug("Validating BlockMatrix...") + // check if the matrix is larger than the claimed dimensions + estimateDim() + logDebug("BlockMatrix dimensions are okay...") + + // Check if there are multiple MatrixBlocks with the same index. + blockInfo.countByKey().foreach { case (key, cnt) => + if (cnt > 1) { + throw new SparkException(s"Found multiple MatrixBlocks with the indices $key. Please " + + "remove blocks with duplicate indices.") + } + } + logDebug("MatrixBlock indices are okay...") + // Check if each MatrixBlock (except edges) has the dimensions rowsPerBlock x colsPerBlock + // The first tuple is the index and the second tuple is the dimensions of the MatrixBlock + val dimensionMsg = s"dimensions different than rowsPerBlock: $rowsPerBlock, and " + + s"colsPerBlock: $colsPerBlock. Blocks on the right and bottom edges can have smaller " + + s"dimensions. You may use the repartition method to fix this issue." + blockInfo.foreach { case ((blockRowIndex, blockColIndex), (m, n)) => + if ((blockRowIndex < numRowBlocks - 1 && m != rowsPerBlock) || + (blockRowIndex == numRowBlocks - 1 && (m <= 0 || m > rowsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + if ((blockColIndex < numColBlocks - 1 && n != colsPerBlock) || + (blockColIndex == numColBlocks - 1 && (n <= 0 || n > colsPerBlock))) { + throw new SparkException(s"The MatrixBlock at ($blockRowIndex, $blockColIndex) has " + + dimensionMsg) + } + } + logDebug("MatrixBlock dimensions are okay...") + logDebug("BlockMatrix is valid!") + } + + /** Caches the underlying RDD. */ + def cache(): this.type = { + blocks.cache() + this + } + + /** Persists the underlying RDD with the specified storage level. */ + def persist(storageLevel: StorageLevel): this.type = { + blocks.persist(storageLevel) + this + } + + /** Converts to CoordinateMatrix. */ + def toCoordinateMatrix(): CoordinateMatrix = { + val entryRDD = blocks.flatMap { case ((blockRowIndex, blockColIndex), mat) => + val rowStart = blockRowIndex.toLong * rowsPerBlock + val colStart = blockColIndex.toLong * colsPerBlock + val entryValues = new ArrayBuffer[MatrixEntry]() + mat.foreachActive { (i, j, v) => + if (v != 0.0) entryValues.append(new MatrixEntry(rowStart + i, colStart + j, v)) + } + entryValues + } + new CoordinateMatrix(entryRDD, numRows(), numCols()) + } + + /** Converts to IndexedRowMatrix. The number of columns must be within the integer range. */ + def toIndexedRowMatrix(): IndexedRowMatrix = { + require(numCols() < Int.MaxValue, "The number of columns must be within the integer range. " + + s"numCols: ${numCols()}") + // TODO: This implementation may be optimized + toCoordinateMatrix().toIndexedRowMatrix() + } + + /** Collect the distributed matrix on the driver as a `DenseMatrix`. */ + def toLocalMatrix(): Matrix = { + require(numRows() < Int.MaxValue, "The number of rows of this matrix should be less than " + + s"Int.MaxValue. Currently numRows: ${numRows()}") + require(numCols() < Int.MaxValue, "The number of columns of this matrix should be less than " + + s"Int.MaxValue. Currently numCols: ${numCols()}") + require(numRows() * numCols() < Int.MaxValue, "The length of the values array must be " + + s"less than Int.MaxValue. Currently numRows * numCols: ${numRows() * numCols()}") + val m = numRows().toInt + val n = numCols().toInt + val mem = m * n / 125000 + if (mem > 500) logWarning(s"Storing this matrix will require $mem MB of memory!") + val localBlocks = blocks.collect() + val values = new Array[Double](m * n) + localBlocks.foreach { case ((blockRowIndex, blockColIndex), submat) => + val rowOffset = blockRowIndex * rowsPerBlock + val colOffset = blockColIndex * colsPerBlock + submat.foreachActive { (i, j, v) => + val indexOffset = (j + colOffset) * m + rowOffset + i + values(indexOffset) = v + } + } + new DenseMatrix(m, n, values) + } + + /** Transpose this `BlockMatrix`. Returns a new `BlockMatrix` instance sharing the + * same underlying data. Is a lazy operation. */ + def transpose: BlockMatrix = { + val transposedBlocks = blocks.map { case ((blockRowIndex, blockColIndex), mat) => + ((blockColIndex, blockRowIndex), mat.transpose) + } + new BlockMatrix(transposedBlocks, colsPerBlock, rowsPerBlock, nCols, nRows) + } + + /** Collects data and assembles a local dense breeze matrix (for test only). */ + private[mllib] def toBreeze(): BDM[Double] = { + val localMat = toLocalMatrix() + new BDM[Double](localMat.numRows, localMat.numCols, localMat.toArray) + } + + /** Adds two block matrices together. The matrices must have the same size and matching + * `rowsPerBlock` and `colsPerBlock` values. If one of the blocks that are being added are + * instances of [[SparseMatrix]], the resulting sub matrix will also be a [[SparseMatrix]], even + * if it is being added to a [[DenseMatrix]]. If two dense matrices are added, the output will + * also be a [[DenseMatrix]]. + */ + def add(other: BlockMatrix): BlockMatrix = { + require(numRows() == other.numRows(), "Both matrices must have the same number of rows. " + + s"A.numRows: ${numRows()}, B.numRows: ${other.numRows()}") + require(numCols() == other.numCols(), "Both matrices must have the same number of columns. " + + s"A.numCols: ${numCols()}, B.numCols: ${other.numCols()}") + if (rowsPerBlock == other.rowsPerBlock && colsPerBlock == other.colsPerBlock) { + val addedBlocks = blocks.cogroup(other.blocks, createPartitioner()) + .map { case ((blockRowIndex, blockColIndex), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), b.head) + } else if (b.isEmpty) { + new MatrixBlock((blockRowIndex, blockColIndex), a.head) + } else { + val result = a.head.toBreeze + b.head.toBreeze + new MatrixBlock((blockRowIndex, blockColIndex), Matrices.fromBreeze(result)) + } + } + new BlockMatrix(addedBlocks, rowsPerBlock, colsPerBlock, numRows(), numCols()) + } else { + throw new SparkException("Cannot add matrices with different block dimensions") + } + } + + /** Left multiplies this [[BlockMatrix]] to `other`, another [[BlockMatrix]]. The `colsPerBlock` + * of this matrix must equal the `rowsPerBlock` of `other`. If `other` contains + * [[SparseMatrix]], they will have to be converted to a [[DenseMatrix]]. The output + * [[BlockMatrix]] will only consist of blocks of [[DenseMatrix]]. This may cause + * some performance issues until support for multiplying two sparse matrices is added. + */ + def multiply(other: BlockMatrix): BlockMatrix = { + require(numCols() == other.numRows(), "The number of columns of A and the number of rows " + + s"of B must be equal. A.numCols: ${numCols()}, B.numRows: ${other.numRows()}. If you " + + "think they should be equal, try setting the dimensions of A and B explicitly while " + + "initializing them.") + if (colsPerBlock == other.rowsPerBlock) { + val resultPartitioner = GridPartitioner(numRowBlocks, other.numColBlocks, + math.max(blocks.partitions.length, other.blocks.partitions.length)) + // Each block of A must be multiplied with the corresponding blocks in each column of B. + // TODO: Optimize to send block to a partition once, similar to ALS + val flatA = blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(other.numColBlocks)(j => ((blockRowIndex, j, blockColIndex), block)) + } + // Each block of B must be multiplied with the corresponding blocks in each row of A. + val flatB = other.blocks.flatMap { case ((blockRowIndex, blockColIndex), block) => + Iterator.tabulate(numRowBlocks)(i => ((i, blockColIndex, blockRowIndex), block)) + } + val newBlocks: RDD[MatrixBlock] = flatA.cogroup(flatB, resultPartitioner) + .flatMap { case ((blockRowIndex, blockColIndex, _), (a, b)) => + if (a.size > 1 || b.size > 1) { + throw new SparkException("There are multiple MatrixBlocks with indices: " + + s"($blockRowIndex, $blockColIndex). Please remove them.") + } + if (a.nonEmpty && b.nonEmpty) { + val C = b.head match { + case dense: DenseMatrix => a.head.multiply(dense) + case sparse: SparseMatrix => a.head.multiply(sparse.toDense()) + case _ => throw new SparkException(s"Unrecognized matrix type ${b.head.getClass}.") + } + Iterator(((blockRowIndex, blockColIndex), C.toBreeze)) + } else { + Iterator() + } + }.reduceByKey(resultPartitioner, (a, b) => a + b) + .mapValues(Matrices.fromBreeze) + // TODO: Try to use aggregateByKey instead of reduceByKey to get rid of intermediate matrices + new BlockMatrix(newBlocks, rowsPerBlock, other.colsPerBlock, numRows(), other.numCols()) + } else { + throw new SparkException("colsPerBlock of A doesn't match rowsPerBlock of B. " + + s"A.colsPerBlock: $colsPerBlock, B.rowsPerBlock: ${other.rowsPerBlock}") + } + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index b60559c853a50..078d1fac44443 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -21,8 +21,7 @@ import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} /** * :: Experimental :: @@ -98,6 +97,46 @@ class CoordinateMatrix( toIndexedRowMatrix().toRowMatrix() } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + require(rowsPerBlock > 0, + s"rowsPerBlock needs to be greater than 0. rowsPerBlock: $rowsPerBlock") + require(colsPerBlock > 0, + s"colsPerBlock needs to be greater than 0. colsPerBlock: $colsPerBlock") + val m = numRows() + val n = numCols() + val numRowBlocks = math.ceil(m.toDouble / rowsPerBlock).toInt + val numColBlocks = math.ceil(n.toDouble / colsPerBlock).toInt + val partitioner = GridPartitioner(numRowBlocks, numColBlocks, entries.partitions.length) + + val blocks: RDD[((Int, Int), Matrix)] = entries.map { entry => + val blockRowIndex = (entry.i / rowsPerBlock).toInt + val blockColIndex = (entry.j / colsPerBlock).toInt + + val rowId = entry.i % rowsPerBlock + val colId = entry.j % colsPerBlock + + ((blockRowIndex, blockColIndex), (rowId.toInt, colId.toInt, entry.value)) + }.groupByKey(partitioner).map { case ((blockRowIndex, blockColIndex), entry) => + val effRows = math.min(m - blockRowIndex.toLong * rowsPerBlock, rowsPerBlock).toInt + val effCols = math.min(n - blockColIndex.toLong * colsPerBlock, colsPerBlock).toInt + ((blockRowIndex, blockColIndex), SparseMatrix.fromCOO(effRows, effCols, entry)) + } + new BlockMatrix(blocks, rowsPerBlock, colsPerBlock, m, n) + } + /** Determines the size by computing the max row/column index. */ private def computeSize() { // Reduce will throw an exception if `entries` is empty. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index c518271f04729..3be530fa07537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -75,6 +75,24 @@ class IndexedRowMatrix( new RowMatrix(rows.map(_.vector), 0L, nCols) } + /** Converts to BlockMatrix. Creates blocks of [[SparseMatrix]] with size 1024 x 1024. */ + def toBlockMatrix(): BlockMatrix = { + toBlockMatrix(1024, 1024) + } + + /** + * Converts to BlockMatrix. Creates blocks of [[SparseMatrix]]. + * @param rowsPerBlock The number of rows of each block. The blocks at the bottom edge may have + * a smaller value. Must be an integer value greater than 0. + * @param colsPerBlock The number of columns of each block. The blocks at the right edge may have + * a smaller value. Must be an integer value greater than 0. + * @return a [[BlockMatrix]] + */ + def toBlockMatrix(rowsPerBlock: Int, colsPerBlock: Int): BlockMatrix = { + // TODO: This implementation may be optimized + toCoordinateMatrix().toBlockMatrix(rowsPerBlock, colsPerBlock) + } + /** * Converts this matrix to a * [[org.apache.spark.mllib.linalg.distributed.CoordinateMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 02075edbabf85..961111507f2c2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -30,7 +30,6 @@ import org.apache.spark.Logging import org.apache.spark.SparkContext._ import org.apache.spark.annotation.Experimental import org.apache.spark.mllib.linalg._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -152,10 +151,10 @@ class RowMatrix( * storing the right singular vectors, is computed via matrix multiplication as * U = A * (V * S^-1^), if requested by user. The actual method to use is determined * automatically based on the cost: - * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute the Gramian - * matrix first and then compute its top eigenvalues and eigenvectors locally on the driver. - * This requires a single pass with O(n^2^) storage on each executor and on the driver, and - * O(n^2^ k) time on the driver. + * - If n is small (n < 100) or k is large compared with n (k > n / 2), we compute + * the Gramian matrix first and then compute its top eigenvalues and eigenvectors locally + * on the driver. This requires a single pass with O(n^2^) storage on each executor and + * on the driver, and O(n^2^ k) time on the driver. * - Otherwise, we compute (A' * A) * v in a distributive way and send it to ARPACK's DSAUPD to * compute (A' * A)'s top eigenvalues and eigenvectors on the driver node. This requires O(k) * passes, O(n) storage on each executor, and O(n k) storage on the driver. @@ -220,8 +219,12 @@ class RowMatrix( val computeMode = mode match { case "auto" => + if(k > 5000) { + logWarning(s"computing svd with k=$k and n=$n, please check necessity") + } + // TODO: The conditions below are not fully tested. - if (n < 100 || k > n / 2) { + if (n < 100 || (k > n / 2 && n <= 15000)) { // If n is small or k is large compared with n, we better compute the Gramian matrix first // and then compute its eigenvalues locally, instead of making multiple passes. if (k < n / 3) { @@ -246,6 +249,8 @@ class RowMatrix( val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] EigenValueDecomposition.symmetricEigs(v => G * v, n, k, tol, maxIter) case SVDMode.LocalLAPACK => + // breeze (v0.10) svd latent constraint, 7 * n * n + 4 * n < Int.MaxValue + require(n < 17515, s"$n exceeds the breeze svd capability") val G = computeGramianMatrix().toBreeze.asInstanceOf[BDM[Double]] val brzSvd.SVD(uFull: BDM[Double], sigmaSquaresFull: BDV[Double], _) = brzSvd(G) (sigmaSquaresFull, uFull) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 0857877951c82..4b7d0589c973b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -25,7 +25,6 @@ import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Vectors, Vector} -import org.apache.spark.mllib.rdd.RDDFunctions._ /** * Class used to solve an optimization problem using Gradient Descent. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala index d16d0daf08565..d5e4f4ccbff10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala @@ -26,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.axpy -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala index 57c0768084e41..78172843be56e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RDDFunctions.scala @@ -21,10 +21,7 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.HashPartitioner -import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils /** * Machine learning specific RDD functions. @@ -53,63 +50,25 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { * Reduces the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#reduce]] + * @see [[org.apache.spark.rdd.RDD#treeReduce]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeReduce]] instead. */ - def treeReduce(f: (T, T) => T, depth: Int = 2): T = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - val cleanF = self.context.clean(f) - val reducePartition: Iterator[T] => Option[T] = iter => { - if (iter.hasNext) { - Some(iter.reduceLeft(cleanF)) - } else { - None - } - } - val partiallyReduced = self.mapPartitions(it => Iterator(reducePartition(it))) - val op: (Option[T], Option[T]) => Option[T] = (c, x) => { - if (c.isDefined && x.isDefined) { - Some(cleanF(c.get, x.get)) - } else if (c.isDefined) { - c - } else if (x.isDefined) { - x - } else { - None - } - } - RDDFunctions.fromRDD(partiallyReduced).treeAggregate(Option.empty[T])(op, op, depth) - .getOrElse(throw new UnsupportedOperationException("empty collection")) - } + @deprecated("Use RDD.treeReduce instead.", "1.3.0") + def treeReduce(f: (T, T) => T, depth: Int = 2): T = self.treeReduce(f, depth) /** * Aggregates the elements of this RDD in a multi-level tree pattern. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] + * @see [[org.apache.spark.rdd.RDD#treeAggregate]] + * @deprecated Use [[org.apache.spark.rdd.RDD#treeAggregate]] instead. */ + @deprecated("Use RDD.treeAggregate instead.", "1.3.0") def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, combOp: (U, U) => U, depth: Int = 2): U = { - require(depth >= 1, s"Depth must be greater than or equal to 1 but got $depth.") - if (self.partitions.size == 0) { - return Utils.clone(zeroValue, self.context.env.closureSerializer.newInstance()) - } - val cleanSeqOp = self.context.clean(seqOp) - val cleanCombOp = self.context.clean(combOp) - val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = self.mapPartitions(it => Iterator(aggregatePartition(it))) - var numPartitions = partiallyAggregated.partitions.size - val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) - // If creating an extra level doesn't help reduce the wall-clock time, we stop tree aggregation. - while (numPartitions > scale + numPartitions / scale) { - numPartitions /= scale - val curNumPartitions = numPartitions - partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => - iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values - } - partiallyAggregated.reduce(cleanCombOp) + self.treeAggregate(zeroValue)(seqOp, combOp, depth) } } @@ -117,5 +76,5 @@ class RDDFunctions[T: ClassTag](self: RDD[T]) extends Serializable { object RDDFunctions { /** Implicit conversion from an RDD to RDDFunctions. */ - implicit def fromRDD[T: ClassTag](rdd: RDD[T]) = new RDDFunctions[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): RDDFunctions[T] = new RDDFunctions[T](rdd) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala new file mode 100644 index 0000000000000..5ed6477bae3b2 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/IsotonicRegression.scala @@ -0,0 +1,304 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import java.io.Serializable +import java.lang.{Double => JDouble} +import java.util.Arrays.binarySearch + +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} +import org.apache.spark.rdd.RDD + +/** + * Regression model for isotonic regression. + * + * @param boundaries Array of boundaries for which predictions are known. + * Boundaries must be sorted in increasing order. + * @param predictions Array of predictions associated to the boundaries at the same index. + * Results of isotonic regression and therefore monotone. + * @param isotonic indicates whether this is isotonic or antitonic. + */ +class IsotonicRegressionModel ( + val boundaries: Array[Double], + val predictions: Array[Double], + val isotonic: Boolean) extends Serializable { + + private val predictionOrd = if (isotonic) Ordering[Double] else Ordering[Double].reverse + + require(boundaries.length == predictions.length) + assertOrdered(boundaries) + assertOrdered(predictions)(predictionOrd) + + /** Asserts the input array is monotone with the given ordering. */ + private def assertOrdered(xs: Array[Double])(implicit ord: Ordering[Double]): Unit = { + var i = 1 + while (i < xs.length) { + require(ord.compare(xs(i - 1), xs(i)) <= 0, + s"Elements (${xs(i - 1)}, ${xs(i)}) are not ordered.") + i += 1 + } + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: RDD[Double]): RDD[Double] = { + testData.map(predict) + } + + /** + * Predict labels for provided features. + * Using a piecewise linear function. + * + * @param testData Features to be labeled. + * @return Predicted labels. + */ + def predict(testData: JavaDoubleRDD): JavaDoubleRDD = { + JavaDoubleRDD.fromRDD(predict(testData.rdd.retag.asInstanceOf[RDD[Double]])) + } + + /** + * Predict a single label. + * Using a piecewise linear function. + * + * @param testData Feature to be labeled. + * @return Predicted label. + * 1) If testData exactly matches a boundary then associated prediction is returned. + * In case there are multiple predictions with the same boundary then one of them + * is returned. Which one is undefined (same as java.util.Arrays.binarySearch). + * 2) If testData is lower or higher than all boundaries then first or last prediction + * is returned respectively. In case there are multiple predictions with the same + * boundary then the lowest or highest is returned respectively. + * 3) If testData falls between two values in boundary array then prediction is treated + * as piecewise linear function and interpolated value is returned. In case there are + * multiple values with the same boundary then the same rules as in 2) are used. + */ + def predict(testData: Double): Double = { + + def linearInterpolation(x1: Double, y1: Double, x2: Double, y2: Double, x: Double): Double = { + y1 + (y2 - y1) * (x - x1) / (x2 - x1) + } + + val foundIndex = binarySearch(boundaries, testData) + val insertIndex = -foundIndex - 1 + + // Find if the index was lower than all values, + // higher than all values, in between two values or exact match. + if (insertIndex == 0) { + predictions.head + } else if (insertIndex == boundaries.length){ + predictions.last + } else if (foundIndex < 0) { + linearInterpolation( + boundaries(insertIndex - 1), + predictions(insertIndex - 1), + boundaries(insertIndex), + predictions(insertIndex), + testData) + } else { + predictions(foundIndex) + } + } +} + +/** + * Isotonic regression. + * Currently implemented using parallelized pool adjacent violators algorithm. + * Only univariate (single feature) algorithm supported. + * + * Sequential PAV implementation based on: + * Tibshirani, Ryan J., Holger Hoefling, and Robert Tibshirani. + * "Nearly-isotonic regression." Technometrics 53.1 (2011): 54-61. + * Available from http://www.stat.cmu.edu/~ryantibs/papers/neariso.pdf + * + * Sequential PAV parallelization based on: + * Kearsley, Anthony J., Richard A. Tapia, and Michael W. Trosset. + * "An approach to parallelizing isotonic regression." + * Applied Mathematics and Parallel Computing. Physica-Verlag HD, 1996. 141-147. + * Available from http://softlib.rice.edu/pub/CRPC-TRs/reports/CRPC-TR96640.pdf + */ +class IsotonicRegression private (private var isotonic: Boolean) extends Serializable { + + /** + * Constructs IsotonicRegression instance with default parameter isotonic = true. + * + * @return New instance of IsotonicRegression. + */ + def this() = this(true) + + /** + * Sets the isotonic parameter. + * + * @param isotonic Isotonic (increasing) or antitonic (decreasing) sequence. + * @return This instance of IsotonicRegression. + */ + def setIsotonic(isotonic: Boolean): this.type = { + this.isotonic = isotonic + this + } + + /** + * Run IsotonicRegression algorithm to obtain isotonic regression model. + * + * @param input RDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. + * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: RDD[(Double, Double, Double)]): IsotonicRegressionModel = { + val preprocessedInput = if (isotonic) { + input + } else { + input.map(x => (-x._1, x._2, x._3)) + } + + val pooled = parallelPoolAdjacentViolators(preprocessedInput) + + val predictions = if (isotonic) pooled.map(_._1) else pooled.map(-_._1) + val boundaries = pooled.map(_._2) + + new IsotonicRegressionModel(boundaries, predictions, isotonic) + } + + /** + * Run pool adjacent violators algorithm to obtain isotonic regression model. + * + * @param input JavaRDD of tuples (label, feature, weight) where label is dependent variable + * for which we calculate isotonic regression, feature is independent variable + * and weight represents number of measures with default 1. + * If multiple labels share the same feature value then they are ordered before + * the algorithm is executed. + * @return Isotonic regression model. + */ + def run(input: JavaRDD[(JDouble, JDouble, JDouble)]): IsotonicRegressionModel = { + run(input.rdd.retag.asInstanceOf[RDD[(Double, Double, Double)]]) + } + + /** + * Performs a pool adjacent violators algorithm (PAV). + * Uses approach with single processing of data where violators + * in previously processed data created by pooling are fixed immediately. + * Uses optimization of discovering monotonicity violating sequences (blocks). + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def poolAdjacentViolators( + input: Array[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + + if (input.isEmpty) { + return Array.empty + } + + // Pools sub array within given bounds assigning weighted average value to all elements. + def pool(input: Array[(Double, Double, Double)], start: Int, end: Int): Unit = { + val poolSubArray = input.slice(start, end + 1) + + val weightedSum = poolSubArray.map(lp => lp._1 * lp._3).sum + val weight = poolSubArray.map(_._3).sum + + var i = start + while (i <= end) { + input(i) = (weightedSum / weight, input(i)._2, input(i)._3) + i = i + 1 + } + } + + var i = 0 + while (i < input.length) { + var j = i + + // Find monotonicity violating sequence, if any. + while (j < input.length - 1 && input(j)._1 > input(j + 1)._1) { + j = j + 1 + } + + // If monotonicity was not violated, move to next data point. + if (i == j) { + i = i + 1 + } else { + // Otherwise pool the violating sequence + // and check if pooling caused monotonicity violation in previously processed points. + while (i >= 0 && input(i)._1 > input(i + 1)._1) { + pool(input, i, j) + i = i - 1 + } + + i = j + } + } + + // For points having the same prediction, we only keep two boundary points. + val compressed = ArrayBuffer.empty[(Double, Double, Double)] + + var (curLabel, curFeature, curWeight) = input.head + var rightBound = curFeature + def merge(): Unit = { + compressed += ((curLabel, curFeature, curWeight)) + if (rightBound > curFeature) { + compressed += ((curLabel, rightBound, 0.0)) + } + } + i = 1 + while (i < input.length) { + val (label, feature, weight) = input(i) + if (label == curLabel) { + curWeight += weight + rightBound = feature + } else { + merge() + curLabel = label + curFeature = feature + curWeight = weight + rightBound = curFeature + } + i += 1 + } + merge() + + compressed.toArray + } + + /** + * Performs parallel pool adjacent violators algorithm. + * Performs Pool adjacent violators algorithm on each partition and then again on the result. + * + * @param input Input data of tuples (label, feature, weight). + * @return Result tuples (label, feature, weight) where labels were updated + * to form a monotone sequence as per isotonic regression definition. + */ + private def parallelPoolAdjacentViolators( + input: RDD[(Double, Double, Double)]): Array[(Double, Double, Double)] = { + val parallelStepResult = input + .sortBy(x => (x._2, x._1)) + .glom() + .flatMap(poolAdjacentViolators) + .collect() + .sortBy(x => (x._2, x._1)) // Sort again because collect() doesn't promise ordering. + poolAdjacentViolators(parallelStepResult) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala index 0ef9c6181a0a0..b6099259971b7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala @@ -29,8 +29,8 @@ object Algo extends Enumeration { val Classification, Regression = Value private[mllib] def fromString(name: String): Algo = name match { - case "classification" => Classification - case "regression" => Regression + case "classification" | "Classification" => Classification + case "regression" | "Regression" => Regression case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name") } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala index 951733fada6be..f1a6ed230186e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala @@ -183,7 +183,7 @@ private[tree] object DecisionTreeMetadata extends Logging { } /** - * Version of [[buildMetadata()]] for DecisionTree. + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. */ def buildMetadata( input: RDD[LabeledPoint], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala index 4bca9039ebe1d..e1169d9f66ea4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/loss/Loss.scala @@ -45,7 +45,7 @@ trait Loss extends Serializable { * purposes. * @param model Model of the weak learner. * @param data Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]. - * @return + * @return Measure of model error on data */ def computeError(model: TreeEnsembleModel, data: RDD[LabeledPoint]): Double diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index 69299c219878c..97f54aa62d31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -62,7 +62,7 @@ object LinearDataGenerator { * @param nPoints Number of points in sample. * @param seed Random seed * @param eps Epsilon scaling factor. - * @return + * @return Seq of input. */ def generateLinearInput( intercept: Double, diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 47f1f46c6c260..56a9dbdd58b64 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +37,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -65,7 +65,7 @@ public void pipeline() { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index 2eba83335bb58..f4ba23c44563e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { @@ -55,7 +55,7 @@ public void logisticRegression() { LogisticRegression lr = new LogisticRegression(); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } @@ -67,7 +67,7 @@ public void logisticRegressionWithSetters() { LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold .registerTempTable("prediction"); - SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); + DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction"); predictions.collectAsList(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index a9f1c4a2c3ca7..074b58c07df7a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,7 @@ import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.SchemaRDD; +import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient SchemaRDD dataset; + private transient DataFrame dataset; @Before public void setUp() { diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java new file mode 100644 index 0000000000000..d38fc91ace3cf --- /dev/null +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression; + +import java.io.Serializable; +import java.util.List; + +import scala.Tuple3; + +import com.google.common.collect.Lists; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; + +public class JavaIsotonicRegressionSuite implements Serializable { + private transient JavaSparkContext sc; + + private List> generateIsotonicInput(double[] labels) { + List> input = Lists.newArrayList(); + + for (int i = 1; i <= labels.length; i++) { + input.add(new Tuple3(labels[i-1], (double) i, 1d)); + } + + return input; + } + + private IsotonicRegressionModel runIsotonicRegression(double[] labels) { + JavaRDD> trainRDD = + sc.parallelize(generateIsotonicInput(labels), 2).cache(); + + return new IsotonicRegression().run(trainRDD); + } + + @Before + public void setUp() { + sc = new JavaSparkContext("local", "JavaLinearRegressionSuite"); + } + + @After + public void tearDown() { + sc.stop(); + sc = null; + } + + @Test + public void testIsotonicRegressionJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + Assert.assertArrayEquals( + new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + } + + @Test + public void testIsotonicRegressionPredictionsJavaRDD() { + IsotonicRegressionModel model = + runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); + + JavaDoubleRDD testRDD = sc.parallelizeDoubles(Lists.newArrayList(0.0, 1.0, 9.5, 12.0, 13.0)); + List predictions = model.predict(testRDD).collect(); + + Assert.assertTrue(predictions.get(0) == 1d); + Assert.assertTrue(predictions.get(1) == 1d); + Assert.assertTrue(predictions.get(2) == 10d); + Assert.assertTrue(predictions.get(3) == 12d); + Assert.assertTrue(predictions.get(4) == 12d); + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 4515084bc7ae9..2f175fb117941 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.ml.param.ParamMap -import org.apache.spark.sql.SchemaRDD +import org.apache.spark.sql.DataFrame class PipelineSuite extends FunSuite { @@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite { val estimator2 = mock[Estimator[MyModel]] val model2 = mock[MyModel] val transformer3 = mock[Transformer] - val dataset0 = mock[SchemaRDD] - val dataset1 = mock[SchemaRDD] - val dataset2 = mock[SchemaRDD] - val dataset3 = mock[SchemaRDD] - val dataset4 = mock[SchemaRDD] + val dataset0 = mock[DataFrame] + val dataset1 = mock[DataFrame] + val dataset2 = mock[DataFrame] + val dataset3 = mock[DataFrame] + val dataset4 = mock[DataFrame] when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0) when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1) @@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite { val estimator = mock[Estimator[MyModel]] val pipeline = new Pipeline() .setStages(Array(estimator, estimator)) - val dataset = mock[SchemaRDD] + val dataset = mock[DataFrame] intercept[IllegalArgumentException] { pipeline.fit(dataset) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e8030fef55b1d..33e40dc7410cc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -21,49 +21,43 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext { @transient var sqlContext: SQLContext = _ - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } test("logistic regression") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset) model.transform(dataset) - .select('label, 'prediction) + .select("label", "prediction") .collect() } test("logistic regression with setters") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression() .setMaxIter(10) .setRegParam(1.0) val model = lr.fit(dataset) model.transform(dataset, model.threshold -> 0.8) // overwrite threshold - .select('label, 'score, 'prediction) + .select("label", "score", "prediction") .collect() } test("logistic regression fit and transform with varargs") { - val sqlContext = this.sqlContext - import sqlContext._ val lr = new LogisticRegression val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0) model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability") - .select('label, 'probability, 'prediction) + .select("label", "probability", "prediction") .collect() } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index cdd4db1b5b7dc..07aff56fb7d2f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -155,7 +155,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { } test("RatingBlockBuilder") { - val emptyBuilder = new RatingBlockBuilder() + val emptyBuilder = new RatingBlockBuilder[Int]() assert(emptyBuilder.size === 0) val emptyBlock = emptyBuilder.build() assert(emptyBlock.srcIds.isEmpty) @@ -179,12 +179,12 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { test("UncompressedInBlock") { val encoder = new LocalIndexEncoder(10) - val uncompressed = new UncompressedInBlockBuilder(encoder) + val uncompressed = new UncompressedInBlockBuilder[Int](encoder) .add(0, Array(1, 0, 2), Array(0, 1, 4), Array(1.0f, 2.0f, 3.0f)) .add(1, Array(3, 0), Array(2, 5), Array(4.0f, 5.0f)) .build() - assert(uncompressed.size === 5) - val records = Seq.tabulate(uncompressed.size) { i => + assert(uncompressed.length === 5) + val records = Seq.tabulate(uncompressed.length) { i => val dstEncodedIndex = uncompressed.dstEncodedIndices(i) val dstBlockId = encoder.blockId(dstEncodedIndex) val dstLocalIndex = encoder.localIndex(dstEncodedIndex) @@ -228,15 +228,15 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItems: Int, rank: Int, noiseStd: Double = 0.0, - seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { val trainingFraction = 0.6 val testFraction = 0.3 val totalFraction = trainingFraction + testFraction val random = new Random(seed) val userFactors = genFactors(numUsers, rank, random) val itemFactors = genFactors(numItems, rank, random) - val training = ArrayBuffer.empty[Rating] - val test = ArrayBuffer.empty[Rating] + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { val x = random.nextDouble() if (x < totalFraction) { @@ -268,7 +268,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItems: Int, rank: Int, noiseStd: Double = 0.0, - seed: Long = 11L): (RDD[Rating], RDD[Rating]) = { + seed: Long = 11L): (RDD[Rating[Int]], RDD[Rating[Int]]) = { // The assumption of the implicit feedback model is that unobserved ratings are more likely to // be negatives. val positiveFraction = 0.8 @@ -279,8 +279,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val random = new Random(seed) val userFactors = genFactors(numUsers, rank, random) val itemFactors = genFactors(numItems, rank, random) - val training = ArrayBuffer.empty[Rating] - val test = ArrayBuffer.empty[Rating] + val training = ArrayBuffer.empty[Rating[Int]] + val test = ArrayBuffer.empty[Rating[Int]] for ((userId, userFactor) <- userFactors; (itemId, itemFactor) <- itemFactors) { val rating = blas.sdot(rank, userFactor, 1, itemFactor, 1) val threshold = if (rating > 0) positiveFraction else negativeFraction @@ -340,8 +340,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { * @param targetRMSE target test RMSE */ def testALS( - training: RDD[Rating], - test: RDD[Rating], + training: RDD[Rating[Int]], + test: RDD[Rating[Int]], rank: Int, maxIter: Int, regParam: Double, @@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { numItemBlocks: Int = 3, targetRMSE: Double = 0.05): Unit = { val sqlContext = this.sqlContext - import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute} + import sqlContext.createDataFrame val als = new ALS() .setRank(rank) .setRegParam(regParam) @@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { val alpha = als.getAlpha val model = als.fit(training) val predictions = model.transform(test) - .select('rating, 'prediction) + .select("rating", "prediction") .map { case Row(rating: Float, prediction: Float) => (rating.toDouble, prediction.toDouble) } @@ -432,4 +432,16 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging { testALS(training, test, maxIter = 4, rank = 2, regParam = 0.01, implicitPrefs = true, targetRMSE = 0.3) } + + test("using generic ID types") { + val (ratings, _) = genImplicitTestData(numUsers = 20, numItems = 40, rank = 2, noiseStd = 0.01) + + val longRatings = ratings.map(r => Rating(r.user.toLong, r.item.toLong, r.rating)) + val (longUserFactors, _) = ALS.train(longRatings, rank = 2, maxIter = 4) + assert(longUserFactors.first()._1.getClass === classOf[Long]) + + val strRatings = ratings.map(r => Rating(r.user.toString, r.item.toString, r.rating)) + val (strUserFactors, _) = ALS.train(strRatings, rank = 2, maxIter = 4) + assert(strUserFactors.first()._1.getClass === classOf[String]) + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 41cc13da4d5b1..761ea821ef7c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -23,16 +23,16 @@ import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{SQLContext, DataFrame} class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext { - @transient var dataset: SchemaRDD = _ + @transient var dataset: DataFrame = _ override def beforeAll(): Unit = { super.beforeAll() val sqlContext = new SQLContext(sc) - dataset = sqlContext.createSchemaRDD( + dataset = sqlContext.createDataFrame( sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala similarity index 94% rename from mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala rename to mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index 198997b5bb2b2..c2cd56ea40adc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GMMExpectationMaximizationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext { +class GaussianMixtureSuite extends FunSuite with MLlibTestSparkContext { test("single cluster") { val data = sc.parallelize(Array( Vectors.dense(6.0, 9.0), @@ -39,7 +39,7 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val seeds = Array(314589, 29032897, 50181, 494821, 4660) seeds.foreach { seed => - val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data) + val gmm = new GaussianMixture().setK(1).setSeed(seed).run(data) assert(gmm.weights(0) ~== Ew absTol 1E-5) assert(gmm.gaussians(0).mu ~== Emu absTol 1E-5) assert(gmm.gaussians(0).sigma ~== Esigma absTol 1E-5) @@ -68,7 +68,7 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContex val Emu = Array(Vectors.dense(-4.3673), Vectors.dense(5.1604)) val Esigma = Array(Matrices.dense(1, 1, Array(1.1098)), Matrices.dense(1, 1, Array(0.86644))) - val gmm = new GaussianMixtureEM() + val gmm = new GaussianMixture() .setK(2) .setInitialModel(initialGmm) .run(data) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala new file mode 100644 index 0000000000000..2bae465d392aa --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.clustering + +import scala.collection.mutable + +import org.scalatest.FunSuite + +import org.apache.spark.graphx.{Edge, Graph} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext { + + import org.apache.spark.mllib.clustering.PowerIterationClustering._ + + test("power iteration clustering") { + /* + We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for + edge (3, 4). + + 15-14 -13 -12 + | | + 4 . 3 - 2 11 + | | x | | + 5 0 - 1 10 + | | + 6 - 7 - 8 - 9 + */ + + val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), + (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge + (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0), + (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0)) + val model = new PowerIterationClustering() + .setK(2) + .run(sc.parallelize(similarities, 2)) + val predictions = Array.fill(2)(mutable.Set.empty[Long]) + model.assignments.collect().foreach { case (i, c) => + predictions(c) += i + } + assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet)) + } + + test("normalize and powerIter") { + /* + Test normalize() with the following graph: + + 0 - 3 + | \ | + 1 - 2 + + The affinity matrix (A) is + + 0 1 1 1 + 1 0 1 0 + 1 1 0 1 + 1 0 1 0 + + D is diag(3, 2, 3, 2) and hence W is + + 0 1/3 1/3 1/3 + 1/2 0 1/2 0 + 1/3 1/3 0 1/3 + 1/2 0 1/2 0 + */ + val similarities = Seq[(Long, Long, Double)]( + (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0)) + val expected = Array( + Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0), + Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0), + Array(1.0/2.0, 0.0, 1.0/2.0, 0.0)) + val w = normalize(sc.parallelize(similarities, 2)) + w.edges.collect().foreach { case Edge(i, j, x) => + assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14) + } + val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2) + val w0 = Graph(v0, w.edges) + val v1 = powerIter(w0, maxIterations = 1).collect() + val u = Array(0.3, 0.2, 0.7/3.0, 0.2) + val norm = u.sum + val u1 = u.map(x => x / norm) + v1.foreach { case (i, x) => + assert(x ~== u1(i.toInt) absTol 1e-14) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 4c93c0ca4f86c..7f94564b2a3ae 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -22,29 +22,114 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.rdd.RDDFunctions._ import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { + // When the input data is all constant, the variance is zero. The standardization against + // zero variance is not well-defined, but we decide to just set it into zero here. + val constantData = Array( + Vectors.dense(2.0), + Vectors.dense(2.0), + Vectors.dense(2.0) + ) + + val sparseData = Array( + Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), + Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), + Vectors.sparse(3, Seq((1, -5.1))), + Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), + Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), + Vectors.sparse(3, Seq((1, 1.9))) + ) + + val denseData = Array( + Vectors.dense(-2.0, 2.3, 0), + Vectors.dense(0.0, -1.0, -3.0), + Vectors.dense(0.0, -5.1, 0.0), + Vectors.dense(3.8, 0.0, 1.9), + Vectors.dense(1.7, -0.6, 0.0), + Vectors.dense(0.0, 1.9, 0.0) + ) + private def computeSummary(data: RDD[Vector]): MultivariateStatisticalSummary = { data.treeAggregate(new MultivariateOnlineSummarizer)( (aggregator, data) => aggregator.add(data), (aggregator1, aggregator2) => aggregator1.merge(aggregator2)) } + test("Standardization with dense input when means and stds are provided") { + + val dataRDD = sc.parallelize(denseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = denseData.map(equivalentModel1.transform) + val data2 = denseData.map(equivalentModel2.transform) + val data3 = denseData.map(equivalentModel3.transform) + + val data1RDD = equivalentModel1.transform(dataRDD) + val data2RDD = equivalentModel2.transform(dataRDD) + val data3RDD = equivalentModel3.transform(dataRDD) + + val summary = computeSummary(dataRDD) + val summary1 = computeSummary(data1RDD) + val summary2 = computeSummary(data2RDD) + val summary3 = computeSummary(data3RDD) + + assert((denseData, data1, data1RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((denseData, data3, data3RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data1, data1RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + assert((data3, data3RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary1.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary1.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(summary3.mean ~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary3.variance ~== summary.variance absTol 1E-5) + + assert(data1(0) ~== Vectors.dense(-1.31527964, 1.023470449, 0.11637768424) absTol 1E-5) + assert(data1(3) ~== Vectors.dense(1.637735298, 0.156973995, 1.32247368462) absTol 1E-5) + assert(data2(4) ~== Vectors.dense(0.865538862, -0.22604255, 0.0) absTol 1E-5) + assert(data2(5) ~== Vectors.dense(0.0, 0.71580142, 0.0) absTol 1E-5) + assert(data3(1) ~== Vectors.dense(-0.58333333, -0.58333333, -2.8166666666) absTol 1E-5) + assert(data3(5) ~== Vectors.dense(-0.58333333, 2.316666666, 0.18333333333) absTol 1E-5) + } + test("Standardization with dense input") { - val data = Array( - Vectors.dense(-2.0, 2.3, 0), - Vectors.dense(0.0, -1.0, -3.0), - Vectors.dense(0.0, -5.1, 0.0), - Vectors.dense(3.8, 0.0, 1.9), - Vectors.dense(1.7, -0.6, 0.0), - Vectors.dense(0.0, 1.9, 0.0) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(denseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -54,9 +139,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = denseData.map(model1.transform) + val data2 = denseData.map(model2.transform) + val data3 = denseData.map(model3.transform) val data1RDD = model1.transform(dataRDD) val data2RDD = model2.transform(dataRDD) @@ -67,19 +152,19 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val summary2 = computeSummary(data2RDD) val summary3 = computeSummary(data3RDD) - assert((data, data1, data1RDD.collect()).zipped.forall { + assert((denseData, data1, data1RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data2, data2RDD.collect()).zipped.forall { + assert((denseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false }, "The vector type should be preserved after standardization.") - assert((data, data3, data3RDD.collect()).zipped.forall { + assert((denseData, data3, data3RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -107,17 +192,58 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { } + test("Standardization with sparse input when means and stds are provided") { + + val dataRDD = sc.parallelize(sparseData, 3) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler() + val standardizer3 = new StandardScaler(withMean = true, withStd = false) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data2 = sparseData.map(equivalentModel2.transform) + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel1.transform) + } + } + + withClue("Standardization with mean can not be applied on sparse input.") { + intercept[IllegalArgumentException] { + sparseData.map(equivalentModel3.transform) + } + } + + val data2RDD = equivalentModel2.transform(dataRDD) + + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { + case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true + case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true + case _ => false + }, "The vector type should be preserved after standardization.") + + assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) + + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + + assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) + assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) + } + test("Standardization with sparse input") { - val data = Array( - Vectors.sparse(3, Seq((0, -2.0), (1, 2.3))), - Vectors.sparse(3, Seq((1, -1.0), (2, -3.0))), - Vectors.sparse(3, Seq((1, -5.1))), - Vectors.sparse(3, Seq((0, 3.8), (2, 1.9))), - Vectors.sparse(3, Seq((0, 1.7), (1, -0.6))), - Vectors.sparse(3, Seq((1, 1.9))) - ) - val dataRDD = sc.parallelize(data, 3) + val dataRDD = sc.parallelize(sparseData, 3) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler() @@ -127,25 +253,26 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data2 = data.map(model2.transform) + val data2 = sparseData.map(model2.transform) withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model1.transform) + sparseData.map(model1.transform) } } withClue("Standardization with mean can not be applied on sparse input.") { intercept[IllegalArgumentException] { - data.map(model3.transform) + sparseData.map(model3.transform) } } val data2RDD = model2.transform(dataRDD) - val summary2 = computeSummary(data2RDD) - assert((data, data2, data2RDD.collect()).zipped.forall { + val summary = computeSummary(data2RDD) + + assert((sparseData, data2, data2RDD.collect()).zipped.forall { case (v1: DenseVector, v2: DenseVector, v3: DenseVector) => true case (v1: SparseVector, v2: SparseVector, v3: SparseVector) => true case _ => false @@ -153,23 +280,44 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { assert((data2, data2RDD.collect()).zipped.forall((v1, v2) => v1 ~== v2 absTol 1E-5)) - assert(summary2.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) - assert(summary2.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) + assert(summary.mean !~== Vectors.dense(0.0, 0.0, 0.0) absTol 1E-5) + assert(summary.variance ~== Vectors.dense(1.0, 1.0, 1.0) absTol 1E-5) assert(data2(4) ~== Vectors.sparse(3, Seq((0, 0.865538862), (1, -0.22604255))) absTol 1E-5) assert(data2(5) ~== Vectors.sparse(3, Seq((1, 0.71580142))) absTol 1E-5) } + test("Standardization with constant input when means and stds are provided") { + + val dataRDD = sc.parallelize(constantData, 2) + + val standardizer1 = new StandardScaler(withMean = true, withStd = true) + val standardizer2 = new StandardScaler(withMean = true, withStd = false) + val standardizer3 = new StandardScaler(withMean = false, withStd = true) + + val model1 = standardizer1.fit(dataRDD) + val model2 = standardizer2.fit(dataRDD) + val model3 = standardizer3.fit(dataRDD) + + val equivalentModel1 = new StandardScalerModel(model1.std, model1.mean) + val equivalentModel2 = new StandardScalerModel(model2.std, model2.mean, true, false) + val equivalentModel3 = new StandardScalerModel(model3.std, model3.mean, false, true) + + val data1 = constantData.map(equivalentModel1.transform) + val data2 = constantData.map(equivalentModel2.transform) + val data3 = constantData.map(equivalentModel3.transform) + + assert(data1.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data2.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + assert(data3.forall(_.toArray.forall(_ == 0.0)), + "The variance is zero, so the transformed result should be 0.0") + } + test("Standardization with constant input") { - // When the input data is all constant, the variance is zero. The standardization against - // zero variance is not well-defined, but we decide to just set it into zero here. - val data = Array( - Vectors.dense(2.0), - Vectors.dense(2.0), - Vectors.dense(2.0) - ) - val dataRDD = sc.parallelize(data, 2) + val dataRDD = sc.parallelize(constantData, 2) val standardizer1 = new StandardScaler(withMean = true, withStd = true) val standardizer2 = new StandardScaler(withMean = true, withStd = false) @@ -179,9 +327,9 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { val model2 = standardizer2.fit(dataRDD) val model3 = standardizer3.fit(dataRDD) - val data1 = data.map(model1.transform) - val data2 = data.map(model2.transform) - val data3 = data.map(model3.transform) + val data1 = constantData.map(model1.transform) + val data2 = constantData.map(model2.transform) + val data3 = constantData.map(model3.transform) assert(data1.forall(_.toArray.forall(_ == 0.0)), "The variance is zero, so the transformed result should be 0.0") @@ -191,4 +339,29 @@ class StandardScalerSuite extends FunSuite with MLlibTestSparkContext { "The variance is zero, so the transformed result should be 0.0") } + test("StandardScalerModel argument nulls are properly handled") { + + withClue("model needs at least one of std or mean vectors") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, null) + } + } + withClue("model needs std to set withStd to true") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(null, Vectors.dense(0.0)) + model.setWithStd(true) + } + } + withClue("model needs mean to set withMean to true") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), null) + model.setWithMean(true) + } + } + withClue("model needs std and mean vectors to be equal size when both are provided") { + intercept[IllegalArgumentException] { + val model = new StandardScalerModel(Vectors.dense(0.0), Vectors.dense(0.0,1.0)) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 771878e925ea7..b0b78acd6df16 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -169,16 +169,17 @@ class BLASSuite extends FunSuite { } test("gemm") { - val dA = new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) val B = new DenseMatrix(3, 2, Array(1.0, 0.0, 0.0, 0.0, 2.0, 1.0)) val expected = new DenseMatrix(4, 2, Array(0.0, 1.0, 0.0, 0.0, 4.0, 0.0, 2.0, 3.0)) + val BTman = new DenseMatrix(2, 3, Array(1.0, 0.0, 0.0, 2.0, 0.0, 1.0)) + val BT = B.transpose - assert(dA multiply B ~== expected absTol 1e-15) - assert(sA multiply B ~== expected absTol 1e-15) + assert(dA.multiply(B) ~== expected absTol 1e-15) + assert(sA.multiply(B) ~== expected absTol 1e-15) val C1 = new DenseMatrix(4, 2, Array(1.0, 0.0, 2.0, 1.0, 0.0, 0.0, 1.0, 0.0)) val C2 = C1.copy @@ -188,6 +189,10 @@ class BLASSuite extends FunSuite { val C6 = C1.copy val C7 = C1.copy val C8 = C1.copy + val C9 = C1.copy + val C10 = C1.copy + val C11 = C1.copy + val C12 = C1.copy val expected2 = new DenseMatrix(4, 2, Array(2.0, 1.0, 4.0, 2.0, 4.0, 0.0, 4.0, 3.0)) val expected3 = new DenseMatrix(4, 2, Array(2.0, 2.0, 4.0, 2.0, 8.0, 0.0, 6.0, 6.0)) @@ -202,26 +207,40 @@ class BLASSuite extends FunSuite { withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemm(true, false, 1.0, dA, B, 2.0, C1) + gemm(1.0, dA.transpose, B, 2.0, C1) } } - val dAT = + val dATman = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) - val sAT = + val sATman = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply B ~== expected absTol 1e-15) - assert(sAT transposeMultiply B ~== expected absTol 1e-15) - - gemm(true, false, 1.0, dAT, B, 2.0, C5) - gemm(true, false, 1.0, sAT, B, 2.0, C6) - gemm(true, false, 2.0, dAT, B, 2.0, C7) - gemm(true, false, 2.0, sAT, B, 2.0, C8) + val dATT = dATman.transpose + val sATT = sATman.transpose + val BTT = BTman.transpose.asInstanceOf[DenseMatrix] + + assert(dATT.multiply(B) ~== expected absTol 1e-15) + assert(sATT.multiply(B) ~== expected absTol 1e-15) + assert(dATT.multiply(BTT) ~== expected absTol 1e-15) + assert(sATT.multiply(BTT) ~== expected absTol 1e-15) + + gemm(1.0, dATT, BTT, 2.0, C5) + gemm(1.0, sATT, BTT, 2.0, C6) + gemm(2.0, dATT, BTT, 2.0, C7) + gemm(2.0, sATT, BTT, 2.0, C8) + gemm(1.0, dA, BTT, 2.0, C9) + gemm(1.0, sA, BTT, 2.0, C10) + gemm(2.0, dA, BTT, 2.0, C11) + gemm(2.0, sA, BTT, 2.0, C12) assert(C5 ~== expected2 absTol 1e-15) assert(C6 ~== expected2 absTol 1e-15) assert(C7 ~== expected3 absTol 1e-15) assert(C8 ~== expected3 absTol 1e-15) + assert(C9 ~== expected2 absTol 1e-15) + assert(C10 ~== expected2 absTol 1e-15) + assert(C11 ~== expected3 absTol 1e-15) + assert(C12 ~== expected3 absTol 1e-15) } test("gemv") { @@ -233,17 +252,13 @@ class BLASSuite extends FunSuite { val x = new DenseVector(Array(1.0, 2.0, 3.0)) val expected = new DenseVector(Array(4.0, 1.0, 2.0, 9.0)) - assert(dA multiply x ~== expected absTol 1e-15) - assert(sA multiply x ~== expected absTol 1e-15) + assert(dA.multiply(x) ~== expected absTol 1e-15) + assert(sA.multiply(x) ~== expected absTol 1e-15) val y1 = new DenseVector(Array(1.0, 3.0, 1.0, 0.0)) val y2 = y1.copy val y3 = y1.copy val y4 = y1.copy - val y5 = y1.copy - val y6 = y1.copy - val y7 = y1.copy - val y8 = y1.copy val expected2 = new DenseVector(Array(6.0, 7.0, 4.0, 9.0)) val expected3 = new DenseVector(Array(10.0, 8.0, 6.0, 18.0)) @@ -257,25 +272,18 @@ class BLASSuite extends FunSuite { assert(y4 ~== expected3 absTol 1e-15) withClue("columns of A don't match the rows of B") { intercept[Exception] { - gemv(true, 1.0, dA, x, 2.0, y1) + gemv(1.0, dA.transpose, x, 2.0, y1) } } - val dAT = new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) val sAT = new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) - assert(dAT transposeMultiply x ~== expected absTol 1e-15) - assert(sAT transposeMultiply x ~== expected absTol 1e-15) - - gemv(true, 1.0, dAT, x, 2.0, y5) - gemv(true, 1.0, sAT, x, 2.0, y6) - gemv(true, 2.0, dAT, x, 2.0, y7) - gemv(true, 2.0, sAT, x, 2.0, y8) - assert(y5 ~== expected2 absTol 1e-15) - assert(y6 ~== expected2 absTol 1e-15) - assert(y7 ~== expected3 absTol 1e-15) - assert(y8 ~== expected3 absTol 1e-15) + val dATT = dAT.transpose + val sATT = sAT.transpose + + assert(dATT.multiply(x) ~== expected absTol 1e-15) + assert(sATT.multiply(x) ~== expected absTol 1e-15) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index 73a6d3a27d868..2031032373971 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -36,6 +36,11 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + // transposed matrix + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[DenseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(matTransposed.values.eq(breeze.data), "should not copy data") } test("sparse matrix to breeze") { @@ -58,5 +63,9 @@ class BreezeMatrixConversionSuite extends FunSuite { assert(mat.numRows === breeze.rows) assert(mat.numCols === breeze.cols) assert(mat.values.eq(breeze.data), "should not copy data") + val matTransposed = Matrices.fromBreeze(breeze.t).asInstanceOf[SparseMatrix] + assert(matTransposed.numRows === breeze.cols) + assert(matTransposed.numCols === breeze.rows) + assert(!matTransposed.values.eq(breeze.data), "has to copy data") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala index a35d0fe389fdd..b1ebfde0e5e57 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala @@ -22,6 +22,9 @@ import java.util.Random import org.mockito.Mockito.when import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar._ +import scala.collection.mutable.{Map => MutableMap} + +import org.apache.spark.mllib.util.TestingUtils._ class MatricesSuite extends FunSuite { test("dense matrix construction") { @@ -32,7 +35,6 @@ class MatricesSuite extends FunSuite { assert(mat.numRows === m) assert(mat.numCols === n) assert(mat.values.eq(values), "should not copy data") - assert(mat.toArray.eq(values), "toArray should not copy data") } test("dense matrix construction with wrong dimension") { @@ -161,6 +163,66 @@ class MatricesSuite extends FunSuite { assert(deMat1.toArray === deMat2.toArray) } + test("transpose") { + val dA = + new DenseMatrix(4, 3, Array(0.0, 1.0, 0.0, 0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 3.0)) + val sA = new SparseMatrix(4, 3, Array(0, 1, 3, 4), Array(1, 0, 2, 3), Array(1.0, 2.0, 1.0, 3.0)) + + val dAT = dA.transpose.asInstanceOf[DenseMatrix] + val sAT = sA.transpose.asInstanceOf[SparseMatrix] + val dATexpected = + new DenseMatrix(3, 4, Array(0.0, 2.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 3.0)) + val sATexpected = + new SparseMatrix(3, 4, Array(0, 1, 2, 3, 4), Array(1, 0, 1, 2), Array(2.0, 1.0, 1.0, 3.0)) + + assert(dAT.toBreeze === dATexpected.toBreeze) + assert(sAT.toBreeze === sATexpected.toBreeze) + assert(dA(1, 0) === dAT(0, 1)) + assert(dA(2, 1) === dAT(1, 2)) + assert(sA(1, 0) === sAT(0, 1)) + assert(sA(2, 1) === sAT(1, 2)) + + assert(!dA.toArray.eq(dAT.toArray), "has to have a new array") + assert(dA.values.eq(dAT.transpose.asInstanceOf[DenseMatrix].values), "should not copy array") + + assert(dAT.toSparse().toBreeze === sATexpected.toBreeze) + assert(sAT.toDense().toBreeze === dATexpected.toBreeze) + } + + test("foreachActive") { + val m = 3 + val n = 2 + val values = Array(1.0, 2.0, 4.0, 5.0) + val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) + val colPtrs = Array(0, 2, 4) + val rowIndices = Array(0, 1, 1, 2) + + val sp = new SparseMatrix(m, n, colPtrs, rowIndices, values) + val dn = new DenseMatrix(m, n, allValues) + + val dnMap = MutableMap[(Int, Int), Double]() + dn.foreachActive { (i, j, value) => + dnMap.put((i, j), value) + } + assert(dnMap.size === 6) + assert(dnMap(0, 0) === 1.0) + assert(dnMap(1, 0) === 2.0) + assert(dnMap(2, 0) === 0.0) + assert(dnMap(0, 1) === 0.0) + assert(dnMap(1, 1) === 4.0) + assert(dnMap(2, 1) === 5.0) + + val spMap = MutableMap[(Int, Int), Double]() + sp.foreachActive { (i, j, value) => + spMap.put((i, j), value) + } + assert(spMap.size === 4) + assert(spMap(0, 0) === 1.0) + assert(spMap(1, 0) === 2.0) + assert(spMap(1, 1) === 4.0) + assert(spMap(2, 1) === 5.0) + } + test("horzcat, vertcat, eye, speye") { val m = 3 val n = 2 @@ -168,9 +230,20 @@ class MatricesSuite extends FunSuite { val allValues = Array(1.0, 2.0, 0.0, 0.0, 4.0, 5.0) val colPtrs = Array(0, 2, 4) val rowIndices = Array(0, 1, 1, 2) + // transposed versions + val allValuesT = Array(1.0, 0.0, 2.0, 4.0, 0.0, 5.0) + val colPtrsT = Array(0, 1, 3, 4) + val rowIndicesT = Array(0, 0, 1, 1) val spMat1 = new SparseMatrix(m, n, colPtrs, rowIndices, values) val deMat1 = new DenseMatrix(m, n, allValues) + val spMat1T = new SparseMatrix(n, m, colPtrsT, rowIndicesT, values) + val deMat1T = new DenseMatrix(n, m, allValuesT) + + // should equal spMat1 & deMat1 respectively + val spMat1TT = spMat1T.transpose + val deMat1TT = deMat1T.transpose + val deMat2 = Matrices.eye(3) val spMat2 = Matrices.speye(3) val deMat3 = Matrices.eye(2) @@ -180,7 +253,6 @@ class MatricesSuite extends FunSuite { val spHorz2 = Matrices.horzcat(Array(spMat1, deMat2)) val spHorz3 = Matrices.horzcat(Array(deMat1, spMat2)) val deHorz1 = Matrices.horzcat(Array(deMat1, deMat2)) - val deHorz2 = Matrices.horzcat(Array[Matrix]()) assert(deHorz1.numRows === 3) @@ -195,8 +267,8 @@ class MatricesSuite extends FunSuite { assert(deHorz2.numCols === 0) assert(deHorz2.toArray.length === 0) - assert(deHorz1.toBreeze.toDenseMatrix === spHorz2.toBreeze.toDenseMatrix) - assert(spHorz2.toBreeze === spHorz3.toBreeze) + assert(deHorz1 ~== spHorz2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spHorz2 ~== spHorz3 absTol 1e-15) assert(spHorz(0, 0) === 1.0) assert(spHorz(2, 1) === 5.0) assert(spHorz(0, 2) === 1.0) @@ -212,6 +284,17 @@ class MatricesSuite extends FunSuite { assert(deHorz1(2, 4) === 1.0) assert(deHorz1(1, 4) === 0.0) + // containing transposed matrices + val spHorzT = Matrices.horzcat(Array(spMat1TT, spMat2)) + val spHorz2T = Matrices.horzcat(Array(spMat1TT, deMat2)) + val spHorz3T = Matrices.horzcat(Array(deMat1TT, spMat2)) + val deHorz1T = Matrices.horzcat(Array(deMat1TT, deMat2)) + + assert(deHorz1T ~== deHorz1 absTol 1e-15) + assert(spHorzT ~== spHorz absTol 1e-15) + assert(spHorz2T ~== spHorz2 absTol 1e-15) + assert(spHorz3T ~== spHorz3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.horzcat(Array(spMat1, spMat3)) } @@ -238,8 +321,8 @@ class MatricesSuite extends FunSuite { assert(deVert2.numCols === 0) assert(deVert2.toArray.length === 0) - assert(deVert1.toBreeze.toDenseMatrix === spVert2.toBreeze.toDenseMatrix) - assert(spVert2.toBreeze === spVert3.toBreeze) + assert(deVert1 ~== spVert2.asInstanceOf[SparseMatrix].toDense absTol 1e-15) + assert(spVert2 ~== spVert3 absTol 1e-15) assert(spVert(0, 0) === 1.0) assert(spVert(2, 1) === 5.0) assert(spVert(3, 0) === 1.0) @@ -251,6 +334,17 @@ class MatricesSuite extends FunSuite { assert(deVert1(3, 1) === 0.0) assert(deVert1(4, 1) === 1.0) + // containing transposed matrices + val spVertT = Matrices.vertcat(Array(spMat1TT, spMat3)) + val deVert1T = Matrices.vertcat(Array(deMat1TT, deMat3)) + val spVert2T = Matrices.vertcat(Array(spMat1TT, deMat3)) + val spVert3T = Matrices.vertcat(Array(deMat1TT, spMat3)) + + assert(deVert1T ~== deVert1 absTol 1e-15) + assert(spVertT ~== spVert absTol 1e-15) + assert(spVert2T ~== spVert2 absTol 1e-15) + assert(spVert3T ~== spVert3 absTol 1e-15) + intercept[IllegalArgumentException] { Matrices.vertcat(Array(spMat1, spMat2)) } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala new file mode 100644 index 0000000000000..949d1c9939570 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.linalg.distributed + +import java.{util => ju} + +import breeze.linalg.{DenseMatrix => BDM} +import org.scalatest.FunSuite + +import org.apache.spark.SparkException +import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class BlockMatrixSuite extends FunSuite with MLlibTestSparkContext { + + val m = 5 + val n = 4 + val rowPerPart = 2 + val colPerPart = 2 + val numPartitions = 3 + var gridBasedMat: BlockMatrix = _ + + override def beforeAll() { + super.beforeAll() + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + + gridBasedMat = new BlockMatrix(sc.parallelize(blocks, numPartitions), rowPerPart, colPerPart) + } + + test("size") { + assert(gridBasedMat.numRows() === m) + assert(gridBasedMat.numCols() === n) + } + + test("grid partitioner") { + val random = new ju.Random() + // This should generate a 4x4 grid of 1x2 blocks. + val part0 = GridPartitioner(4, 7, suggestedNumPartitions = 12) + val expected0 = Array( + Array(0, 0, 4, 4, 8, 8, 12), + Array(1, 1, 5, 5, 9, 9, 13), + Array(2, 2, 6, 6, 10, 10, 14), + Array(3, 3, 7, 7, 11, 11, 15)) + for (i <- 0 until 4; j <- 0 until 7) { + assert(part0.getPartition((i, j)) === expected0(i)(j)) + assert(part0.getPartition((i, j, random.nextInt())) === expected0(i)(j)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((-1, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((4, 0)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, -1)) + } + + intercept[IllegalArgumentException] { + part0.getPartition((0, 7)) + } + + val part1 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + val expected1 = Array( + Array(0, 2), + Array(1, 3)) + for (i <- 0 until 2; j <- 0 until 2) { + assert(part1.getPartition((i, j)) === expected1(i)(j)) + assert(part1.getPartition((i, j, random.nextInt())) === expected1(i)(j)) + } + + val part2 = GridPartitioner(2, 2, suggestedNumPartitions = 5) + assert(part0 !== part2) + assert(part1 === part2) + + val part3 = new GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + val expected3 = Array( + Array(0, 0, 2), + Array(1, 1, 3)) + for (i <- 0 until 2; j <- 0 until 3) { + assert(part3.getPartition((i, j)) === expected3(i)(j)) + assert(part3.getPartition((i, j, random.nextInt())) === expected3(i)(j)) + } + + val part4 = GridPartitioner(2, 3, rowsPerPart = 1, colsPerPart = 2) + assert(part3 === part4) + + intercept[IllegalArgumentException] { + new GridPartitioner(2, 2, rowsPerPart = 0, colsPerPart = 1) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, rowsPerPart = 1, colsPerPart = 0) + } + + intercept[IllegalArgumentException] { + GridPartitioner(2, 2, suggestedNumPartitions = 0) + } + } + + test("toCoordinateMatrix") { + val coordMat = gridBasedMat.toCoordinateMatrix() + assert(coordMat.numRows() === m) + assert(coordMat.numCols() === n) + assert(coordMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toIndexedRowMatrix") { + val rowMat = gridBasedMat.toIndexedRowMatrix() + assert(rowMat.numRows() === m) + assert(rowMat.numCols() === n) + assert(rowMat.toBreeze() === gridBasedMat.toBreeze()) + } + + test("toBreeze and toLocalMatrix") { + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val dense = Matrices.fromBreeze(expected).asInstanceOf[DenseMatrix] + assert(gridBasedMat.toLocalMatrix() === dense) + assert(gridBasedMat.toBreeze() === expected) + } + + test("add") { + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 0), new DenseMatrix(1, 2, Array(1.0, 0.0))), // Added block that doesn't exist in A + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val B = new BlockMatrix(rdd, rowPerPart, colPerPart) + + val expected = BDM( + (2.0, 0.0, 0.0, 0.0), + (0.0, 4.0, 2.0, 0.0), + (6.0, 2.0, 2.0, 0.0), + (0.0, 2.0, 4.0, 2.0), + (1.0, 0.0, 2.0, 10.0)) + + val AplusB = gridBasedMat.add(B) + assert(AplusB.numRows() === m) + assert(AplusB.numCols() === B.numCols()) + assert(AplusB.toBreeze() === expected) + + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m, n + 1) // columns don't match + intercept[IllegalArgumentException] { + gridBasedMat.add(C) + } + val largerBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(4, 4, new Array[Double](16))), + ((1, 0), new DenseMatrix(1, 4, Array(1.0, 0.0, 1.0, 5.0)))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4, m, n) + intercept[SparkException] { // partitioning doesn't match + gridBasedMat.add(C2) + } + // adding BlockMatrices composed of SparseMatrices + val sparseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), SparseMatrix.speye(4)) + val denseBlocks = for (i <- 0 until 4) yield ((i / 2, i % 2), DenseMatrix.eye(4)) + val sparseBM = new BlockMatrix(sc.makeRDD(sparseBlocks, 4), 4, 4, 8, 8) + val denseBM = new BlockMatrix(sc.makeRDD(denseBlocks, 4), 4, 4, 8, 8) + + assert(sparseBM.add(sparseBM).toBreeze() === sparseBM.add(denseBM).toBreeze()) + } + + test("multiply") { + // identity matrix + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 1.0)))) + val rdd = sc.parallelize(blocks, 2) + val B = new BlockMatrix(rdd, colPerPart, rowPerPart) + val expected = BDM( + (1.0, 0.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 0.0), + (3.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 1.0, 5.0)) + + val AtimesB = gridBasedMat.multiply(B) + assert(AtimesB.numRows() === m) + assert(AtimesB.numCols() === n) + assert(AtimesB.toBreeze() === expected) + val C = new BlockMatrix(rdd, rowPerPart, colPerPart, m + 1, n) // dimensions don't match + intercept[IllegalArgumentException] { + gridBasedMat.multiply(C) + } + val largerBlocks = Seq(((0, 0), DenseMatrix.eye(4))) + val C2 = new BlockMatrix(sc.parallelize(largerBlocks, numPartitions), 4, 4) + intercept[SparkException] { + // partitioning doesn't match + gridBasedMat.multiply(C2) + } + val rand = new ju.Random(42) + val largerAblocks = for (i <- 0 until 20) yield ((i % 5, i / 5), DenseMatrix.rand(6, 4, rand)) + val largerBblocks = for (i <- 0 until 16) yield ((i % 4, i / 4), DenseMatrix.rand(4, 4, rand)) + + // Try it with increased number of partitions + val largeA = new BlockMatrix(sc.parallelize(largerAblocks, 10), 6, 4) + val largeB = new BlockMatrix(sc.parallelize(largerBblocks, 8), 4, 4) + val largeC = largeA.multiply(largeB) + val localC = largeC.toLocalMatrix() + val result = largeA.toLocalMatrix().multiply(largeB.toLocalMatrix().asInstanceOf[DenseMatrix]) + assert(largeC.numRows() === largeA.numRows()) + assert(largeC.numCols() === largeB.numCols()) + assert(localC ~== result absTol 1e-8) + } + + test("validate") { + // No error + gridBasedMat.validate() + // Wrong MatrixBlock dimensions + val blocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 1), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 0), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val rdd = sc.parallelize(blocks, numPartitions) + val wrongRowPerParts = new BlockMatrix(rdd, rowPerPart + 1, colPerPart) + val wrongColPerParts = new BlockMatrix(rdd, rowPerPart, colPerPart + 1) + intercept[SparkException] { + wrongRowPerParts.validate() + } + intercept[SparkException] { + wrongColPerParts.validate() + } + // Wrong BlockMatrix dimensions + val wrongRowSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 4, 4) + intercept[AssertionError] { + wrongRowSize.validate() + } + val wrongColSize = new BlockMatrix(rdd, rowPerPart, colPerPart, 5, 2) + intercept[AssertionError] { + wrongColSize.validate() + } + // Duplicate indices + val duplicateBlocks: Seq[((Int, Int), Matrix)] = Seq( + ((0, 0), new DenseMatrix(2, 2, Array(1.0, 0.0, 0.0, 2.0))), + ((0, 0), new DenseMatrix(2, 2, Array(0.0, 1.0, 0.0, 0.0))), + ((1, 1), new DenseMatrix(2, 2, Array(3.0, 0.0, 1.0, 1.0))), + ((1, 1), new DenseMatrix(2, 2, Array(1.0, 2.0, 0.0, 1.0))), + ((2, 1), new DenseMatrix(1, 2, Array(1.0, 5.0)))) + val dupMatrix = new BlockMatrix(sc.parallelize(duplicateBlocks, numPartitions), 2, 2) + intercept[SparkException] { + dupMatrix.validate() + } + } + + test("transpose") { + val expected = BDM( + (1.0, 0.0, 3.0, 0.0, 0.0), + (0.0, 2.0, 1.0, 1.0, 0.0), + (0.0, 1.0, 1.0, 2.0, 1.0), + (0.0, 0.0, 0.0, 1.0, 5.0)) + + val AT = gridBasedMat.transpose + assert(AT.numRows() === gridBasedMat.numCols()) + assert(AT.numCols() === gridBasedMat.numRows()) + assert(AT.toBreeze() === expected) + + // make sure it works when matrices are cached as well + gridBasedMat.cache() + val AT2 = gridBasedMat.transpose + AT2.cache() + assert(AT2.toBreeze() === AT.toBreeze()) + val A = AT2.transpose + assert(A.toBreeze() === gridBasedMat.toBreeze()) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index 80bef814ce50d..04b36a9ef9990 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -100,4 +100,18 @@ class CoordinateMatrixSuite extends FunSuite with MLlibTestSparkContext { Vectors.dense(0.0, 9.0, 0.0, 0.0)) assert(rows === expected) } + + test("toBlockMatrix") { + val blockMat = mat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === mat.toBreeze()) + + intercept[IllegalArgumentException] { + mat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + mat.toBlockMatrix(2, 0) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index b86c2ca5ff136..2ab53cc13db71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -88,6 +88,21 @@ class IndexedRowMatrixSuite extends FunSuite with MLlibTestSparkContext { assert(coordMat.toBreeze() === idxRowMat.toBreeze()) } + test("toBlockMatrix") { + val idxRowMat = new IndexedRowMatrix(indexedRows) + val blockMat = idxRowMat.toBlockMatrix(2, 2) + assert(blockMat.numRows() === m) + assert(blockMat.numCols() === n) + assert(blockMat.toBreeze() === idxRowMat.toBreeze()) + + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(-1, 2) + } + intercept[IllegalArgumentException] { + idxRowMat.toBlockMatrix(2, 0) + } + } + test("multiply a local matrix") { val A = new IndexedRowMatrix(indexedRows) val B = Matrices.dense(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 5.0)) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index 681ce9263933b..6d6c0aa5be812 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -46,22 +46,4 @@ class RDDFunctionsSuite extends FunSuite with MLlibTestSparkContext { val expected = data.flatMap(x => x).sliding(3).toSeq.map(_.toSeq) assert(sliding === expected) } - - test("treeAggregate") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - def seqOp = (c: Long, x: Int) => c + x - def combOp = (c1: Long, c2: Long) => c1 + c2 - for (depth <- 1 until 10) { - val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) - } - } - - test("treeReduce") { - val rdd = sc.makeRDD(-1000 until 1000, 10) - for (depth <- 1 until 10) { - val sum = rdd.treeReduce(_ + _, depth) - assert(sum === -1000) - } - } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala new file mode 100644 index 0000000000000..7ef45248281e9 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/IsotonicRegressionSuite.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.regression + +import org.scalatest.{Matchers, FunSuite} + +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ + +class IsotonicRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers { + + private def round(d: Double) = { + Math.round(d * 100).toDouble / 100 + } + + private def generateIsotonicInput(labels: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, 1d)) + } + + private def generateIsotonicInput( + labels: Seq[Double], + weights: Seq[Double]): Seq[(Double, Double, Double)] = { + Seq.tabulate(labels.size)(i => (labels(i), i.toDouble, weights(i))) + } + + private def runIsotonicRegression( + labels: Seq[Double], + weights: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + val trainRDD = sc.parallelize(generateIsotonicInput(labels, weights)).cache() + new IsotonicRegression().setIsotonic(isotonic).run(trainRDD) + } + + private def runIsotonicRegression( + labels: Seq[Double], + isotonic: Boolean): IsotonicRegressionModel = { + runIsotonicRegression(labels, Array.fill(labels.size)(1d), isotonic) + } + + test("increasing isotonic regression") { + /* + The following result could be re-produced with sklearn. + + > from sklearn.isotonic import IsotonicRegression + > x = range(9) + > y = [1, 2, 3, 1, 6, 17, 16, 17, 18] + > ir = IsotonicRegression(x, y) + > print ir.predict(x) + + array([ 1. , 2. , 2. , 2. , 6. , 16.5, 16.5, 17. , 18. ]) + */ + val model = runIsotonicRegression(Seq(1, 2, 3, 1, 6, 17, 16, 17, 18), true) + + assert(Array.tabulate(9)(x => model.predict(x)) === Array(1, 2, 2, 2, 6, 16.5, 16.5, 17, 18)) + + assert(model.boundaries === Array(0, 1, 3, 4, 5, 6, 7, 8)) + assert(model.predictions === Array(1, 2, 2, 6, 16.5, 16.5, 17.0, 18.0)) + assert(model.isotonic) + } + + test("isotonic regression with size 0") { + val model = runIsotonicRegression(Seq(), true) + + assert(model.predictions === Array()) + } + + test("isotonic regression with size 1") { + val model = runIsotonicRegression(Seq(1), true) + + assert(model.predictions === Array(1.0)) + } + + test("isotonic regression strictly increasing sequence") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 5), true) + + assert(model.predictions === Array(1, 2, 3, 4, 5)) + } + + test("isotonic regression strictly decreasing sequence") { + val model = runIsotonicRegression(Seq(5, 4, 3, 2, 1), true) + + assert(model.boundaries === Array(0, 4)) + assert(model.predictions === Array(3, 3)) + } + + test("isotonic regression with last element violating monotonicity") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 3, 3)) + } + + test("isotonic regression with first element violating monotonicity") { + val model = runIsotonicRegression(Seq(4, 2, 3, 4, 5), true) + + assert(model.boundaries === Array(0, 2, 3, 4)) + assert(model.predictions === Array(3, 3, 4, 5)) + } + + test("isotonic regression with negative labels") { + val model = runIsotonicRegression(Seq(-1, -2, 0, 1, -1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(-1.5, -1.5, 0, 0)) + } + + test("isotonic regression with unordered input") { + val trainRDD = sc.parallelize(generateIsotonicInput(Seq(1, 2, 3, 4, 5)).reverse, 2).cache() + + val model = new IsotonicRegression().run(trainRDD) + assert(model.predictions === Array(1, 2, 3, 4, 5)) + } + + test("weighted isotonic regression") { + val model = runIsotonicRegression(Seq(1, 2, 3, 4, 2), Seq(1, 1, 1, 1, 2), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions === Array(1, 2, 2.75, 2.75)) + } + + test("weighted isotonic regression with weights lower than 1") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(1, 1, 1, 0.1, 0.1), true) + + assert(model.boundaries === Array(0, 1, 2, 4)) + assert(model.predictions.map(round) === Array(1, 2, 3.3/1.2, 3.3/1.2)) + } + + test("weighted isotonic regression with negative weights") { + val model = runIsotonicRegression(Seq(1, 2, 3, 2, 1), Seq(-1, 1, -3, 1, -5), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1.0, 10.0/6, 10.0/6)) + } + + test("weighted isotonic regression with zero weights") { + val model = runIsotonicRegression(Seq[Double](1, 2, 3, 2, 1), Seq[Double](0, 0, 0, 1, 0), true) + + assert(model.boundaries === Array(0.0, 1.0, 4.0)) + assert(model.predictions === Array(1, 2, 2)) + } + + test("isotonic regression prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + assert(model.predict(-2) === 1) + assert(model.predict(-1) === 1) + assert(model.predict(0.5) === 1.5) + assert(model.predict(0.75) === 1.75) + assert(model.predict(1) === 2) + assert(model.predict(2) === 10d/3) + assert(model.predict(9) === 10d/3) + } + + test("isotonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (2, 1, 1), (1, 1, 1), (4, 2, 1), (2, 2, 1), (6, 3, 1), (5, 3, 1)), 2).cache() + val model = new IsotonicRegression().run(trainRDD) + + assert(model.predict(0) === 1) + assert(model.predict(1.5) === 2) + assert(model.predict(2.5) === 4.5) + assert(model.predict(4) === 6) + } + + test("antitonic regression prediction with duplicate features") { + val trainRDD = sc.parallelize( + Seq[(Double, Double, Double)]( + (5, 1, 1), (6, 1, 1), (2, 2, 1), (4, 2, 1), (1, 3, 1), (2, 3, 1)), 2).cache() + val model = new IsotonicRegression().setIsotonic(false).run(trainRDD) + + assert(model.predict(0) === 6) + assert(model.predict(1.5) === 4.5) + assert(model.predict(2.5) === 2) + assert(model.predict(4) === 1) + } + + test("isotonic regression RDD prediction") { + val model = runIsotonicRegression(Seq(1, 2, 7, 1, 2), true) + + val testRDD = sc.parallelize(List(-2.0, -1.0, 0.5, 0.75, 1.0, 2.0, 9.0), 2).cache() + val predictions = testRDD.map(x => (x, model.predict(x))).collect().sortBy(_._1).map(_._2) + assert(predictions === Array(1, 1, 1.5, 1.75, 2, 10.0/3, 10.0/3)) + } + + test("antitonic regression prediction") { + val model = runIsotonicRegression(Seq(7, 5, 3, 5, 1), false) + + assert(model.predict(-2) === 7) + assert(model.predict(-1) === 7) + assert(model.predict(0.5) === 6) + assert(model.predict(0.75) === 5.5) + assert(model.predict(1) === 5) + assert(model.predict(2) === 4) + assert(model.predict(9) === 1) + } + + test("model construction") { + val model = new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = true) + assert(model.predict(-0.5) === 1.0) + assert(model.predict(0.0) === 1.0) + assert(model.predict(0.5) ~== 1.5 absTol 1e-14) + assert(model.predict(1.0) === 2.0) + assert(model.predict(1.5) === 2.0) + + intercept[IllegalArgumentException] { + // different array sizes. + new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered boundaries + new IsotonicRegressionModel(Array(1.0, 0.0), Array(1.0, 2.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (isotonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(2.0, 1.0), isotonic = true) + } + + intercept[IllegalArgumentException] { + // unordered predictions (antitonic) + new IsotonicRegressionModel(Array(0.0, 1.0), Array(1.0, 2.0), isotonic = false) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 3aa97e544680b..e8341a5d0d104 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -128,6 +128,11 @@ class GradientBoostedTreesSuite extends FunSuite with MLlibTestSparkContext { } } + test("SPARK-5496: BoostingStrategy.defaultParams should recognize Classification") { + for (algo <- Seq("classification", "Classification", "regression", "Regression")) { + BoostingStrategy.defaultParams(algo) + } + } } object GradientBoostedTreesSuite { diff --git a/network/common/pom.xml b/network/common/pom.xml index 245a96b8c4038..5a9bbe105d9f1 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -48,10 +48,15 @@ slf4j-api provided + com.google.guava guava - provided + compile @@ -87,11 +92,6 @@ maven-jar-plugin 2.2 - - - test-jar - - test-jar-on-test-compile test-compile @@ -101,6 +101,18 @@ + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.guava:guava + + + + diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml index 5bfa1ac9c373e..c2d0300ecd904 100644 --- a/network/shuffle/pom.xml +++ b/network/shuffle/pom.xml @@ -52,7 +52,6 @@ com.google.guava guava - provided diff --git a/pom.xml b/pom.xml index 05cb3797fc55b..b855f2371b7f0 100644 --- a/pom.xml +++ b/pom.xml @@ -136,7 +136,7 @@ 1.2.3 8.1.14.v20131031 0.5.0 - 3.0.0 + 3.1.0 1.7.6 0.7.1 @@ -521,27 +521,27 @@ ${derby.version} - com.codahale.metrics + io.dropwizard.metrics metrics-core ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-jvm ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-json ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-ganglia ${codahale.metrics.version} - com.codahale.metrics + io.dropwizard.metrics metrics-graphite ${codahale.metrics.version} @@ -1264,7 +1264,10 @@ - + org.apache.maven.plugins maven-shade-plugin @@ -1276,6 +1279,23 @@ org.spark-project.spark:unused + + + com.google.common + org.spark-project.guava + + + com/google/common/base/Absent* + com/google/common/base/Function + com/google/common/base/Optional* + com/google/common/base/Present* + com/google/common/base/Supplier + + + diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index bc5d81f12d746..14ba03ed4634b 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -52,6 +52,20 @@ object MimaExcludes { "org.apache.spark.mllib.linalg.Matrices.randn"), ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.mllib.linalg.Matrices.rand") + ) ++ Seq( + // SPARK-5321 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.SparseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.transpose"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.DenseMatrix.transposeMultiply"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix." + + "org$apache$spark$mllib$linalg$Matrix$_setter_$isTransposed_="), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.isTransposed"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.mllib.linalg.Matrix.foreachActive") ) ++ Seq( // SPARK-3325 ProblemFilters.exclude[MissingMethodProblem]( @@ -81,11 +95,30 @@ object MimaExcludes { ) ++ Seq( // SPARK-5166 Spark SQL API stabilization ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Transformer.transform"), - ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit") + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Transformer.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Pipeline.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.PipelineModel.transform"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Estimator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.ml.Evaluator.evaluate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidator.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.tuning.CrossValidatorModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScaler.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.feature.StandardScalerModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.transform"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.classification.LogisticRegression.fit"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.ml.evaluation.BinaryClassificationEvaluator.evaluate") ) ++ Seq( // SPARK-5270 ProblemFilters.exclude[MissingMethodProblem]( "org.apache.spark.api.java.JavaRDDLike.isEmpty") + ) ++ Seq( + // SPARK-5430 + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeReduce"), + ProblemFilters.exclude[MissingMethodProblem]( + "org.apache.spark.api.java.JavaRDDLike.treeAggregate") ) ++ Seq( // SPARK-5297 Java FileStream do not work with custom key/values ProblemFilters.exclude[MissingMethodProblem]( diff --git a/project/build.properties b/project/build.properties index 32a3aeefaf9fb..064ec843da9ea 100644 --- a/project/build.properties +++ b/project/build.properties @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # -sbt.version=0.13.6 +sbt.version=0.13.7 diff --git a/python/docs/conf.py b/python/docs/conf.py index e58d97ae6a746..b00dce95d65b4 100644 --- a/python/docs/conf.py +++ b/python/docs/conf.py @@ -55,9 +55,9 @@ # built documents. # # The short X.Y version. -version = '1.2-SNAPSHOT' +version = '1.3-SNAPSHOT' # The full version, including alpha/beta/rc tags. -release = '1.2-SNAPSHOT' +release = '1.3-SNAPSHOT' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. diff --git a/python/docs/index.rst b/python/docs/index.rst index 703bef644de28..d150de9d5c502 100644 --- a/python/docs/index.rst +++ b/python/docs/index.rst @@ -14,6 +14,7 @@ Contents: pyspark pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst new file mode 100644 index 0000000000000..f10d1339a9a8f --- /dev/null +++ b/python/docs/pyspark.ml.rst @@ -0,0 +1,29 @@ +pyspark.ml package +===================== + +Submodules +---------- + +pyspark.ml module +----------------- + +.. automodule:: pyspark.ml + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.feature module +------------------------- + +.. automodule:: pyspark.ml.feature + :members: + :undoc-members: + :inherited-members: + +pyspark.ml.classification module +-------------------------------- + +.. automodule:: pyspark.ml.classification + :members: + :undoc-members: + :inherited-members: diff --git a/python/docs/pyspark.rst b/python/docs/pyspark.rst index e81be3b6cb796..0df12c49ad033 100644 --- a/python/docs/pyspark.rst +++ b/python/docs/pyspark.rst @@ -9,6 +9,7 @@ Subpackages pyspark.sql pyspark.streaming + pyspark.ml pyspark.mllib Contents diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9556e4718e585..d3efcdf221d82 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -45,6 +45,7 @@ from pyspark.accumulators import Accumulator, AccumulatorParam from pyspark.broadcast import Broadcast from pyspark.serializers import MarshalSerializer, PickleSerializer +from pyspark.profiler import Profiler, BasicProfiler # for back compatibility from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row @@ -52,4 +53,5 @@ __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", "Accumulator", "AccumulatorParam", "MarshalSerializer", "PickleSerializer", + "Profiler", "BasicProfiler", ] diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index b8cdbbe3cf2b6..ccbca67656c8d 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -215,21 +215,6 @@ def addInPlace(self, value1, value2): COMPLEX_ACCUMULATOR_PARAM = AddingAccumulatorParam(0.0j) -class PStatsParam(AccumulatorParam): - """PStatsParam is used to merge pstats.Stats""" - - @staticmethod - def zero(value): - return None - - @staticmethod - def addInPlace(value1, value2): - if value1 is None: - return value2 - value1.add(value2) - return value1 - - class _UpdateRequestHandler(SocketServer.StreamRequestHandler): """ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 568e21f3803bf..c0dec16ac1b25 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -20,7 +20,6 @@ import sys from threading import Lock from tempfile import NamedTemporaryFile -import atexit from pyspark import accumulators from pyspark.accumulators import Accumulator @@ -33,6 +32,7 @@ from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD from pyspark.traceback_utils import CallSite, first_spark_call +from pyspark.profiler import ProfilerCollector, BasicProfiler from py4j.java_collections import ListConverter @@ -66,7 +66,7 @@ class SparkContext(object): def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, environment=None, batchSize=0, serializer=PickleSerializer(), conf=None, - gateway=None, jsc=None): + gateway=None, jsc=None, profiler_cls=BasicProfiler): """ Create a new SparkContext. At least the master and app name should be set, either through the named parameters here or through C{conf}. @@ -88,6 +88,9 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, :param conf: A L{SparkConf} object setting Spark properties. :param gateway: Use an existing gateway and JVM, otherwise a new JVM will be instantiated. + :param jsc: The JavaSparkContext instance (optional). + :param profiler_cls: A class of custom Profiler used to do profiling + (default is pyspark.profiler.BasicProfiler). >>> from pyspark.context import SparkContext @@ -102,14 +105,14 @@ def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None, SparkContext._ensure_initialized(self, gateway=gateway) try: self._do_init(master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc) + conf, jsc, profiler_cls) except: # If an error occurs, clean up in order to allow future SparkContext creation: self.stop() raise def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, serializer, - conf, jsc): + conf, jsc, profiler_cls): self.environment = environment or {} self._conf = conf or SparkConf(_jvm=self._jvm) self._batchSize = batchSize # -1 represents an unlimited batch size @@ -192,7 +195,11 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize, self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath() # profiling stats collected for each PythonRDD - self._profile_stats = [] + if self._conf.get("spark.python.profile", "false") == "true": + dump_path = self._conf.get("spark.python.profile.dump", None) + self.profiler_collector = ProfilerCollector(profiler_cls, dump_path) + else: + self.profiler_collector = None def _initialize_context(self, jconf): """ @@ -826,39 +833,14 @@ def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False): it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal) return list(mappedRDD._collect_iterator_through_file(it)) - def _add_profile(self, id, profileAcc): - if not self._profile_stats: - dump_path = self._conf.get("spark.python.profile.dump") - if dump_path: - atexit.register(self.dump_profiles, dump_path) - else: - atexit.register(self.show_profiles) - - self._profile_stats.append([id, profileAcc, False]) - def show_profiles(self): """ Print the profile stats to stdout """ - for i, (id, acc, showed) in enumerate(self._profile_stats): - stats = acc.value - if not showed and stats: - print "=" * 60 - print "Profile of RDD" % id - print "=" * 60 - stats.sort_stats("time", "cumulative").print_stats() - # mark it as showed - self._profile_stats[i][2] = True + self.profiler_collector.show_profiles() def dump_profiles(self, path): """ Dump the profile stats into directory `path` """ - if not os.path.exists(path): - os.makedirs(path) - for id, acc, _ in self._profile_stats: - stats = acc.value - if stats: - p = os.path.join(path, "rdd_%d.pstats" % id) - stats.dump_stats(p) - self._profile_stats = [] + self.profiler_collector.dump_profiles(path) def _test(): diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py index a975dc19cb78e..a0a028446d5fd 100644 --- a/python/pyspark/java_gateway.py +++ b/python/pyspark/java_gateway.py @@ -111,10 +111,9 @@ def run(self): java_import(gateway.jvm, "org.apache.spark.api.java.*") java_import(gateway.jvm, "org.apache.spark.api.python.*") java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*") - java_import(gateway.jvm, "org.apache.spark.sql.SQLContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.HiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.LocalHiveContext") - java_import(gateway.jvm, "org.apache.spark.sql.hive.TestHiveContext") + # TODO(davies): move into sql + java_import(gateway.jvm, "org.apache.spark.sql.*") + java_import(gateway.jvm, "org.apache.spark.sql.hive.*") java_import(gateway.jvm, "scala.Tuple2") return gateway diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py new file mode 100644 index 0000000000000..47fed80f42e13 --- /dev/null +++ b/python/pyspark/ml/__init__.py @@ -0,0 +1,21 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param import * +from pyspark.ml.pipeline import * + +__all__ = ["Param", "Params", "Transformer", "Estimator", "Pipeline"] diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py new file mode 100644 index 0000000000000..6bd2aa8e47837 --- /dev/null +++ b/python/pyspark/ml/classification.py @@ -0,0 +1,76 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaEstimator, JavaModel +from pyspark.ml.param.shared import HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,\ + HasRegParam + + +__all__ = ['LogisticRegression', 'LogisticRegressionModel'] + + +@inherit_doc +class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, + HasRegParam): + """ + Logistic regression. + + >>> from pyspark.sql import Row + >>> from pyspark.mllib.linalg import Vectors + >>> dataset = sqlCtx.inferSchema(sc.parallelize([ \ + Row(label=1.0, features=Vectors.dense(1.0)), \ + Row(label=0.0, features=Vectors.sparse(1, [], []))])) + >>> lr = LogisticRegression() \ + .setMaxIter(5) \ + .setRegParam(0.01) + >>> model = lr.fit(dataset) + >>> test0 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.dense(-1.0))])) + >>> print model.transform(test0).head().prediction + 0.0 + >>> test1 = sqlCtx.inferSchema(sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))])) + >>> print model.transform(test1).head().prediction + 1.0 + """ + _java_class = "org.apache.spark.ml.classification.LogisticRegression" + + def _create_model(self, java_model): + return LogisticRegressionModel(java_model) + + +class LogisticRegressionModel(JavaModel): + """ + Model fitted by LogisticRegression. + """ + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py new file mode 100644 index 0000000000000..e088acd0ca82d --- /dev/null +++ b/python/pyspark/ml/feature.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.util import inherit_doc +from pyspark.ml.wrapper import JavaTransformer + +__all__ = ['Tokenizer', 'HashingTF'] + + +@inherit_doc +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol): + """ + A tokenizer that converts the input string to lowercase and then + splits it by white spaces. + + >>> from pyspark.sql import Row + >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(text="a b c")])) + >>> tokenizer = Tokenizer() \ + .setInputCol("text") \ + .setOutputCol("words") + >>> print tokenizer.transform(dataset).head() + Row(text=u'a b c', words=[u'a', u'b', u'c']) + >>> print tokenizer.transform(dataset, {tokenizer.outputCol: "tokens"}).head() + Row(text=u'a b c', tokens=[u'a', u'b', u'c']) + """ + + _java_class = "org.apache.spark.ml.feature.Tokenizer" + + +@inherit_doc +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): + """ + Maps a sequence of terms to their term frequencies using the + hashing trick. + + >>> from pyspark.sql import Row + >>> dataset = sqlCtx.inferSchema(sc.parallelize([Row(words=["a", "b", "c"])])) + >>> hashingTF = HashingTF() \ + .setNumFeatures(10) \ + .setInputCol("words") \ + .setOutputCol("features") + >>> print hashingTF.transform(dataset).head().features + (10,[7,8,9],[1.0,1.0,1.0]) + >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"} + >>> print hashingTF.transform(dataset, params).head().vector + (5,[2,3,4],[1.0,1.0,1.0]) + """ + + _java_class = "org.apache.spark.ml.feature.HashingTF" + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.feature tests") + sqlCtx = SQLContext(sc) + globs['sc'] = sc + globs['sqlCtx'] = sqlCtx + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py new file mode 100644 index 0000000000000..5566792cead48 --- /dev/null +++ b/python/pyspark/ml/param/__init__.py @@ -0,0 +1,82 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark.ml.util import Identifiable + + +__all__ = ['Param', 'Params'] + + +class Param(object): + """ + A param with self-contained documentation and optionally default value. + """ + + def __init__(self, parent, name, doc, defaultValue=None): + if not isinstance(parent, Identifiable): + raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__) + self.parent = parent + self.name = str(name) + self.doc = str(doc) + self.defaultValue = defaultValue + + def __str__(self): + return str(self.parent) + "-" + self.name + + def __repr__(self): + return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \ + (self.parent, self.name, self.doc, self.defaultValue) + + +class Params(Identifiable): + """ + Components that take parameters. This also provides an internal + param map to store parameter values attached to the instance. + """ + + __metaclass__ = ABCMeta + + def __init__(self): + super(Params, self).__init__() + #: embedded param map + self.paramMap = {} + + @property + def params(self): + """ + Returns all params. The default implementation uses + :py:func:`dir` to get all attributes of type + :py:class:`Param`. + """ + return filter(lambda attr: isinstance(attr, Param), + [getattr(self, x) for x in dir(self) if x != "params"]) + + def _merge_params(self, params): + paramMap = self.paramMap.copy() + paramMap.update(params) + return paramMap + + @staticmethod + def _dummy(): + """ + Returns a dummy Params instance used as a placeholder to generate docs. + """ + dummy = Params() + dummy.uid = "undefined" + return dummy diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_gen_shared_params.py new file mode 100644 index 0000000000000..5eb81106f116c --- /dev/null +++ b/python/pyspark/ml/param/_gen_shared_params.py @@ -0,0 +1,98 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +header = """# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +#""" + + +def _gen_param_code(name, doc, defaultValue): + """ + Generates Python code for a shared param class. + + :param name: param name + :param doc: param doc + :param defaultValue: string representation of the param + :return: code string + """ + # TODO: How to correctly inherit instance attributes? + template = '''class Has$Name(Params): + """ + Params with $name. + """ + + # a placeholder to make it appear in the generated doc + $name = Param(Params._dummy(), "$name", "$doc", $defaultValue) + + def __init__(self): + super(Has$Name, self).__init__() + #: param for $doc + self.$name = Param(self, "$name", "$doc", $defaultValue) + + def set$Name(self, value): + """ + Sets the value of :py:attr:`$name`. + """ + self.paramMap[self.$name] = value + return self + + def get$Name(self): + """ + Gets the value of $name or its default value. + """ + if self.$name in self.paramMap: + return self.paramMap[self.$name] + else: + return self.$name.defaultValue''' + + upperCamelName = name[0].upper() + name[1:] + return template \ + .replace("$name", name) \ + .replace("$Name", upperCamelName) \ + .replace("$doc", doc) \ + .replace("$defaultValue", defaultValue) + +if __name__ == "__main__": + print header + print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n" + print "from pyspark.ml.param import Param, Params\n\n" + shared = [ + ("maxIter", "max number of iterations", "100"), + ("regParam", "regularization constant", "0.1"), + ("featuresCol", "features column name", "'features'"), + ("labelCol", "label column name", "'label'"), + ("predictionCol", "prediction column name", "'prediction'"), + ("inputCol", "input column name", "'input'"), + ("outputCol", "output column name", "'output'"), + ("numFeatures", "number of features", "1 << 18")] + code = [] + for name, doc, defaultValue in shared: + code.append(_gen_param_code(name, doc, defaultValue)) + print "\n\n\n".join(code) diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py new file mode 100644 index 0000000000000..586822f2de423 --- /dev/null +++ b/python/pyspark/ml/param/shared.py @@ -0,0 +1,260 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# DO NOT MODIFY. The code is generated by _gen_shared_params.py. + +from pyspark.ml.param import Param, Params + + +class HasMaxIter(Params): + """ + Params with maxIter. + """ + + # a placeholder to make it appear in the generated doc + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100) + + def __init__(self): + super(HasMaxIter, self).__init__() + #: param for max number of iterations + self.maxIter = Param(self, "maxIter", "max number of iterations", 100) + + def setMaxIter(self, value): + """ + Sets the value of :py:attr:`maxIter`. + """ + self.paramMap[self.maxIter] = value + return self + + def getMaxIter(self): + """ + Gets the value of maxIter or its default value. + """ + if self.maxIter in self.paramMap: + return self.paramMap[self.maxIter] + else: + return self.maxIter.defaultValue + + +class HasRegParam(Params): + """ + Params with regParam. + """ + + # a placeholder to make it appear in the generated doc + regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1) + + def __init__(self): + super(HasRegParam, self).__init__() + #: param for regularization constant + self.regParam = Param(self, "regParam", "regularization constant", 0.1) + + def setRegParam(self, value): + """ + Sets the value of :py:attr:`regParam`. + """ + self.paramMap[self.regParam] = value + return self + + def getRegParam(self): + """ + Gets the value of regParam or its default value. + """ + if self.regParam in self.paramMap: + return self.paramMap[self.regParam] + else: + return self.regParam.defaultValue + + +class HasFeaturesCol(Params): + """ + Params with featuresCol. + """ + + # a placeholder to make it appear in the generated doc + featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features') + + def __init__(self): + super(HasFeaturesCol, self).__init__() + #: param for features column name + self.featuresCol = Param(self, "featuresCol", "features column name", 'features') + + def setFeaturesCol(self, value): + """ + Sets the value of :py:attr:`featuresCol`. + """ + self.paramMap[self.featuresCol] = value + return self + + def getFeaturesCol(self): + """ + Gets the value of featuresCol or its default value. + """ + if self.featuresCol in self.paramMap: + return self.paramMap[self.featuresCol] + else: + return self.featuresCol.defaultValue + + +class HasLabelCol(Params): + """ + Params with labelCol. + """ + + # a placeholder to make it appear in the generated doc + labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label') + + def __init__(self): + super(HasLabelCol, self).__init__() + #: param for label column name + self.labelCol = Param(self, "labelCol", "label column name", 'label') + + def setLabelCol(self, value): + """ + Sets the value of :py:attr:`labelCol`. + """ + self.paramMap[self.labelCol] = value + return self + + def getLabelCol(self): + """ + Gets the value of labelCol or its default value. + """ + if self.labelCol in self.paramMap: + return self.paramMap[self.labelCol] + else: + return self.labelCol.defaultValue + + +class HasPredictionCol(Params): + """ + Params with predictionCol. + """ + + # a placeholder to make it appear in the generated doc + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction') + + def __init__(self): + super(HasPredictionCol, self).__init__() + #: param for prediction column name + self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction') + + def setPredictionCol(self, value): + """ + Sets the value of :py:attr:`predictionCol`. + """ + self.paramMap[self.predictionCol] = value + return self + + def getPredictionCol(self): + """ + Gets the value of predictionCol or its default value. + """ + if self.predictionCol in self.paramMap: + return self.paramMap[self.predictionCol] + else: + return self.predictionCol.defaultValue + + +class HasInputCol(Params): + """ + Params with inputCol. + """ + + # a placeholder to make it appear in the generated doc + inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input') + + def __init__(self): + super(HasInputCol, self).__init__() + #: param for input column name + self.inputCol = Param(self, "inputCol", "input column name", 'input') + + def setInputCol(self, value): + """ + Sets the value of :py:attr:`inputCol`. + """ + self.paramMap[self.inputCol] = value + return self + + def getInputCol(self): + """ + Gets the value of inputCol or its default value. + """ + if self.inputCol in self.paramMap: + return self.paramMap[self.inputCol] + else: + return self.inputCol.defaultValue + + +class HasOutputCol(Params): + """ + Params with outputCol. + """ + + # a placeholder to make it appear in the generated doc + outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output') + + def __init__(self): + super(HasOutputCol, self).__init__() + #: param for output column name + self.outputCol = Param(self, "outputCol", "output column name", 'output') + + def setOutputCol(self, value): + """ + Sets the value of :py:attr:`outputCol`. + """ + self.paramMap[self.outputCol] = value + return self + + def getOutputCol(self): + """ + Gets the value of outputCol or its default value. + """ + if self.outputCol in self.paramMap: + return self.paramMap[self.outputCol] + else: + return self.outputCol.defaultValue + + +class HasNumFeatures(Params): + """ + Params with numFeatures. + """ + + # a placeholder to make it appear in the generated doc + numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18) + + def __init__(self): + super(HasNumFeatures, self).__init__() + #: param for number of features + self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18) + + def setNumFeatures(self, value): + """ + Sets the value of :py:attr:`numFeatures`. + """ + self.paramMap[self.numFeatures] = value + return self + + def getNumFeatures(self): + """ + Gets the value of numFeatures or its default value. + """ + if self.numFeatures in self.paramMap: + return self.paramMap[self.numFeatures] + else: + return self.numFeatures.defaultValue diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py new file mode 100644 index 0000000000000..2d239f8c802a0 --- /dev/null +++ b/python/pyspark/ml/pipeline.py @@ -0,0 +1,154 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta, abstractmethod + +from pyspark.ml.param import Param, Params +from pyspark.ml.util import inherit_doc + + +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] + + +@inherit_doc +class Estimator(Params): + """ + Abstract class for estimators that fit models to data. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def fit(self, dataset, params={}): + """ + Fits a model to the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: an optional param map that overwrites embedded + params + :returns: fitted model + """ + raise NotImplementedError() + + +@inherit_doc +class Transformer(Params): + """ + Abstract class for transformers that transform one dataset into + another. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def transform(self, dataset, params={}): + """ + Transforms the input dataset with optional parameters. + + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: an optional param map that overwrites embedded + params + :returns: transformed dataset + """ + raise NotImplementedError() + + +@inherit_doc +class Pipeline(Estimator): + """ + A simple pipeline, which acts as an estimator. A Pipeline consists + of a sequence of stages, each of which is either an + :py:class:`Estimator` or a :py:class:`Transformer`. When + :py:meth:`Pipeline.fit` is called, the stages are executed in + order. If a stage is an :py:class:`Estimator`, its + :py:meth:`Estimator.fit` method will be called on the input + dataset to fit a model. Then the model, which is a transformer, + will be used to transform the dataset as the input to the next + stage. If a stage is a :py:class:`Transformer`, its + :py:meth:`Transformer.transform` method will be called to produce + the dataset for the next stage. The fitted model from a + :py:class:`Pipeline` is an :py:class:`PipelineModel`, which + consists of fitted models and transformers, corresponding to the + pipeline stages. If there are no stages, the pipeline acts as an + identity transformer. + """ + + def __init__(self): + super(Pipeline, self).__init__() + #: Param for pipeline stages. + self.stages = Param(self, "stages", "pipeline stages") + + def setStages(self, value): + """ + Set pipeline stages. + :param value: a list of transformers or estimators + :return: the pipeline instance + """ + self.paramMap[self.stages] = value + return self + + def getStages(self): + """ + Get pipeline stages. + """ + if self.stages in self.paramMap: + return self.paramMap[self.stages] + + def fit(self, dataset, params={}): + paramMap = self._merge_params(params) + stages = paramMap[self.stages] + for stage in stages: + if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)): + raise ValueError( + "Cannot recognize a pipeline stage of type %s." % type(stage).__name__) + indexOfLastEstimator = -1 + for i, stage in enumerate(stages): + if isinstance(stage, Estimator): + indexOfLastEstimator = i + transformers = [] + for i, stage in enumerate(stages): + if i <= indexOfLastEstimator: + if isinstance(stage, Transformer): + transformers.append(stage) + dataset = stage.transform(dataset, paramMap) + else: # must be an Estimator + model = stage.fit(dataset, paramMap) + transformers.append(model) + if i < indexOfLastEstimator: + dataset = model.transform(dataset, paramMap) + else: + transformers.append(stage) + return PipelineModel(transformers) + + +@inherit_doc +class PipelineModel(Transformer): + """ + Represents a compiled pipeline with transformers and fitted models. + """ + + def __init__(self, transformers): + super(PipelineModel, self).__init__() + self.transformers = transformers + + def transform(self, dataset, params={}): + paramMap = self._merge_params(params) + for t in self.transformers: + dataset = t.transform(dataset, paramMap) + return dataset diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py new file mode 100644 index 0000000000000..b627c2b4e930b --- /dev/null +++ b/python/pyspark/ml/tests.py @@ -0,0 +1,115 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Unit tests for Spark ML Python APIs. +""" + +import sys + +if sys.version_info[:2] <= (2, 6): + try: + import unittest2 as unittest + except ImportError: + sys.stderr.write('Please install unittest2 to test with Python 2.6 or earlier') + sys.exit(1) +else: + import unittest + +from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase +from pyspark.sql import DataFrame +from pyspark.ml.param import Param +from pyspark.ml.pipeline import Transformer, Estimator, Pipeline + + +class MockDataset(DataFrame): + + def __init__(self): + self.index = 0 + + +class MockTransformer(Transformer): + + def __init__(self): + super(MockTransformer, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + + def transform(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + dataset.index += 1 + return dataset + + +class MockEstimator(Estimator): + + def __init__(self): + super(MockEstimator, self).__init__() + self.fake = Param(self, "fake", "fake", None) + self.dataset_index = None + self.fake_param_value = None + self.model = None + + def fit(self, dataset, params={}): + self.dataset_index = dataset.index + if self.fake in params: + self.fake_param_value = params[self.fake] + model = MockModel() + self.model = model + return model + + +class MockModel(MockTransformer, Transformer): + + def __init__(self): + super(MockModel, self).__init__() + + +class PipelineTests(PySparkTestCase): + + def test_pipeline(self): + dataset = MockDataset() + estimator0 = MockEstimator() + transformer1 = MockTransformer() + estimator2 = MockEstimator() + transformer3 = MockTransformer() + pipeline = Pipeline() \ + .setStages([estimator0, transformer1, estimator2, transformer3]) + pipeline_model = pipeline.fit(dataset, {estimator0.fake: 0, transformer1.fake: 1}) + self.assertEqual(0, estimator0.dataset_index) + self.assertEqual(0, estimator0.fake_param_value) + model0 = estimator0.model + self.assertEqual(0, model0.dataset_index) + self.assertEqual(1, transformer1.dataset_index) + self.assertEqual(1, transformer1.fake_param_value) + self.assertEqual(2, estimator2.dataset_index) + model2 = estimator2.model + self.assertIsNone(model2.dataset_index, "The model produced by the last estimator should " + "not be called during fit.") + dataset = pipeline_model.transform(dataset) + self.assertEqual(2, model0.dataset_index) + self.assertEqual(3, transformer1.dataset_index) + self.assertEqual(4, model2.dataset_index) + self.assertEqual(5, transformer3.dataset_index) + self.assertEqual(6, dataset.index) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py new file mode 100644 index 0000000000000..b1caa84b6306a --- /dev/null +++ b/python/pyspark/ml/util.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import uuid + + +def inherit_doc(cls): + for name, func in vars(cls).items(): + # only inherit docstring for public functions + if name.startswith("_"): + continue + if not func.__doc__: + for parent in cls.__bases__: + parent_func = getattr(parent, name, None) + if parent_func and getattr(parent_func, "__doc__", None): + func.__doc__ = parent_func.__doc__ + break + return cls + + +class Identifiable(object): + """ + Object with a unique ID. + """ + + def __init__(self): + #: A unique id for the object. The default implementation + #: concatenates the class name, "-", and 8 random hex chars. + self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8] + + def __repr__(self): + return self.uid diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py new file mode 100644 index 0000000000000..9e12ddc3d9b8f --- /dev/null +++ b/python/pyspark/ml/wrapper.py @@ -0,0 +1,149 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from abc import ABCMeta + +from pyspark import SparkContext +from pyspark.sql import DataFrame +from pyspark.ml.param import Params +from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.ml.util import inherit_doc + + +def _jvm(): + """ + Returns the JVM view associated with SparkContext. Must be called + after SparkContext is initialized. + """ + jvm = SparkContext._jvm + if jvm: + return jvm + else: + raise AttributeError("Cannot load _jvm from SparkContext. Is SparkContext initialized?") + + +@inherit_doc +class JavaWrapper(Params): + """ + Utility class to help create wrapper classes from Java/Scala + implementations of pipeline components. + """ + + __metaclass__ = ABCMeta + + #: Fully-qualified class name of the wrapped Java component. + _java_class = None + + def _java_obj(self): + """ + Returns or creates a Java object. + """ + java_obj = _jvm() + for name in self._java_class.split("."): + java_obj = getattr(java_obj, name) + return java_obj() + + def _transfer_params_to_java(self, params, java_obj): + """ + Transforms the embedded params and additional params to the + input Java object. + :param params: additional params (overwriting embedded values) + :param java_obj: Java object to receive the params + """ + paramMap = self._merge_params(params) + for param in self.params: + if param in paramMap: + java_obj.set(param.name, paramMap[param]) + + def _empty_java_param_map(self): + """ + Returns an empty Java ParamMap reference. + """ + return _jvm().org.apache.spark.ml.param.ParamMap() + + def _create_java_param_map(self, params, java_obj): + paramMap = self._empty_java_param_map() + for param, value in params.items(): + if param.parent is self: + paramMap.put(java_obj.getParam(param.name), value) + return paramMap + + +@inherit_doc +class JavaEstimator(Estimator, JavaWrapper): + """ + Base class for :py:class:`Estimator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def _create_model(self, java_model): + """ + Creates a model from the input Java model reference. + """ + return JavaModel(java_model) + + def _fit_java(self, dataset, params={}): + """ + Fits a Java model to the input dataset. + :param dataset: input dataset, which is an instance of + :py:class:`pyspark.sql.SchemaRDD` + :param params: additional params (overwriting embedded values) + :return: fitted Java model + """ + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.fit(dataset._jdf, self._empty_java_param_map()) + + def fit(self, dataset, params={}): + java_model = self._fit_java(dataset, params) + return self._create_model(java_model) + + +@inherit_doc +class JavaTransformer(Transformer, JavaWrapper): + """ + Base class for :py:class:`Transformer`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def transform(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java({}, java_obj) + java_param_map = self._create_java_param_map(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf, java_param_map), + dataset.sql_ctx) + + +@inherit_doc +class JavaModel(JavaTransformer): + """ + Base class for :py:class:`Model`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def __init__(self, java_model): + super(JavaTransformer, self).__init__() + self._java_model = java_model + + def _java_obj(self): + return self._java_model diff --git a/python/pyspark/mllib/stat/__init__.py b/python/pyspark/mllib/stat/__init__.py new file mode 100644 index 0000000000000..799d260c096b1 --- /dev/null +++ b/python/pyspark/mllib/stat/__init__.py @@ -0,0 +1,24 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Python package for statistical functions in MLlib. +""" + +from pyspark.mllib.stat._statistics import * + +__all__ = ["Statistics", "MultivariateStatisticalSummary"] diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat/_statistics.py similarity index 88% rename from python/pyspark/mllib/stat.py rename to python/pyspark/mllib/stat/_statistics.py index c8af777a8b00d..218ac148ca992 100644 --- a/python/pyspark/mllib/stat.py +++ b/python/pyspark/mllib/stat/_statistics.py @@ -15,17 +15,14 @@ # limitations under the License. # -""" -Python package for statistical functions in MLlib. -""" - from pyspark import RDD from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import Matrix, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.stat.test import ChiSqTestResult -__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics'] +__all__ = ['MultivariateStatisticalSummary', 'Statistics'] class MultivariateStatisticalSummary(JavaModelWrapper): @@ -53,54 +50,6 @@ def min(self): return self.call("min").toArray() -class ChiSqTestResult(JavaModelWrapper): - """ - .. note:: Experimental - - Object containing the test results for the chi-squared hypothesis test. - """ - @property - def method(self): - """ - Name of the test method - """ - return self._java_model.method() - - @property - def pValue(self): - """ - The probability of obtaining a test statistic result at least as - extreme as the one that was actually observed, assuming that the - null hypothesis is true. - """ - return self._java_model.pValue() - - @property - def degreesOfFreedom(self): - """ - Returns the degree(s) of freedom of the hypothesis test. - Return type should be Number(e.g. Int, Double) or tuples of Numbers. - """ - return self._java_model.degreesOfFreedom() - - @property - def statistic(self): - """ - Test statistic. - """ - return self._java_model.statistic() - - @property - def nullHypothesis(self): - """ - Null hypothesis of the test. - """ - return self._java_model.nullHypothesis() - - def __str__(self): - return self._java_model.toString() - - class Statistics(object): @staticmethod diff --git a/python/pyspark/mllib/stat/test.py b/python/pyspark/mllib/stat/test.py new file mode 100644 index 0000000000000..762506e952b43 --- /dev/null +++ b/python/pyspark/mllib/stat/test.py @@ -0,0 +1,69 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.mllib.common import JavaModelWrapper + + +__all__ = ["ChiSqTestResult"] + + +class ChiSqTestResult(JavaModelWrapper): + """ + .. note:: Experimental + + Object containing the test results for the chi-squared hypothesis test. + """ + @property + def method(self): + """ + Name of the test method + """ + return self._java_model.method() + + @property + def pValue(self): + """ + The probability of obtaining a test statistic result at least as + extreme as the one that was actually observed, assuming that the + null hypothesis is true. + """ + return self._java_model.pValue() + + @property + def degreesOfFreedom(self): + """ + Returns the degree(s) of freedom of the hypothesis test. + Return type should be Number(e.g. Int, Double) or tuples of Numbers. + """ + return self._java_model.degreesOfFreedom() + + @property + def statistic(self): + """ + Test statistic. + """ + return self._java_model.statistic() + + @property + def nullHypothesis(self): + """ + Null hypothesis of the test. + """ + return self._java_model.nullHypothesis() + + def __str__(self): + return self._java_model.toString() diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index ccfa0e9d37e9c..4ac6f37cdd0c6 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -195,7 +195,7 @@ def test_gmm_deterministic(self): def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(0.0, [1, 0, 0]), LabeledPoint(1.0, [0, 1, 1]), @@ -224,18 +224,31 @@ def test_classification(self): self.assertTrue(nb_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 3} # feature 0 has 3 categories - dt_model = \ - DecisionTree.trainClassifier(rdd, numClasses=2, - categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainClassifier( + rdd, numClasses=2, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainClassifier( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + def test_regression(self): from pyspark.mllib.regression import LinearRegressionWithSGD, LassoWithSGD, \ RidgeRegressionWithSGD - from pyspark.mllib.tree import DecisionTree + from pyspark.mllib.tree import DecisionTree, RandomForest, GradientBoostedTrees data = [ LabeledPoint(-1.0, [0, -1]), LabeledPoint(1.0, [0, 1]), @@ -264,13 +277,27 @@ def test_regression(self): self.assertTrue(rr_model.predict(features[3]) > 0) categoricalFeaturesInfo = {0: 2} # feature 0 has 2 categories - dt_model = \ - DecisionTree.trainRegressor(rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + dt_model = DecisionTree.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) self.assertTrue(dt_model.predict(features[0]) <= 0) self.assertTrue(dt_model.predict(features[1]) > 0) self.assertTrue(dt_model.predict(features[2]) <= 0) self.assertTrue(dt_model.predict(features[3]) > 0) + rf_model = RandomForest.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo, numTrees=100) + self.assertTrue(rf_model.predict(features[0]) <= 0) + self.assertTrue(rf_model.predict(features[1]) > 0) + self.assertTrue(rf_model.predict(features[2]) <= 0) + self.assertTrue(rf_model.predict(features[3]) > 0) + + gbt_model = GradientBoostedTrees.trainRegressor( + rdd, categoricalFeaturesInfo=categoricalFeaturesInfo) + self.assertTrue(gbt_model.predict(features[0]) <= 0) + self.assertTrue(gbt_model.predict(features[1]) > 0) + self.assertTrue(gbt_model.predict(features[2]) <= 0) + self.assertTrue(gbt_model.predict(features[3]) > 0) + class StatTests(PySparkTestCase): # SPARK-4023 diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py index 66702478474dc..aae48f213246b 100644 --- a/python/pyspark/mllib/tree.py +++ b/python/pyspark/mllib/tree.py @@ -24,16 +24,48 @@ from pyspark.mllib.linalg import _convert_to_vector from pyspark.mllib.regression import LabeledPoint -__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest'] +__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', + 'RandomForest', 'GradientBoostedTrees'] -class DecisionTreeModel(JavaModelWrapper): +class TreeEnsembleModel(JavaModelWrapper): + def predict(self, x): + """ + Predict values for a single data point or an RDD of points using + the model trained. + """ + if isinstance(x, RDD): + return self.call("predict", x.map(_convert_to_vector)) + + else: + return self.call("predict", _convert_to_vector(x)) + + def numTrees(self): + """ + Get number of trees in ensemble. + """ + return self.call("numTrees") + + def totalNumNodes(self): + """ + Get total number of nodes, summed over all trees in the ensemble. + """ + return self.call("totalNumNodes") + + def __repr__(self): + """ Summary of model """ + return self._java_model.toString() + + def toDebugString(self): + """ Full model """ + return self._java_model.toDebugString() + +class DecisionTreeModel(JavaModelWrapper): """ - A decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + A decision tree model for classification or regression. """ def predict(self, x): """ @@ -64,12 +96,10 @@ def toDebugString(self): class DecisionTree(object): - """ - Learning algorithm for a decision tree model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a decision tree model for classification or regression. """ @classmethod @@ -186,51 +216,19 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain) -class RandomForestModel(JavaModelWrapper): +class RandomForestModel(TreeEnsembleModel): """ - Represents a random forest model. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Represents a random forest model. """ - def predict(self, x): - """ - Predict values for a single data point or an RDD of points using - the model trained. - """ - if isinstance(x, RDD): - return self.call("predict", x.map(_convert_to_vector)) - - else: - return self.call("predict", _convert_to_vector(x)) - - def numTrees(self): - """ - Get number of trees in forest. - """ - return self.call("numTrees") - - def totalNumNodes(self): - """ - Get total number of nodes, summed over all trees in the forest. - """ - return self.call("totalNumNodes") - - def __repr__(self): - """ Summary of model """ - return self._java_model.toString() - - def toDebugString(self): - """ Full model """ - return self._java_model.toDebugString() class RandomForest(object): """ - Learning algorithm for a random forest model for classification or regression. + .. note:: Experimental - EXPERIMENTAL: This is an experimental API. - It will probably be modified in future. + Learning algorithm for a random forest model for classification or regression. """ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird") @@ -383,6 +381,137 @@ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetSt featureSubsetStrategy, impurity, maxDepth, maxBins, seed) +class GradientBoostedTreesModel(TreeEnsembleModel): + """ + .. note:: Experimental + + Represents a gradient-boosted tree model. + """ + + +class GradientBoostedTrees(object): + """ + .. note:: Experimental + + Learning algorithm for a gradient boosted trees model for classification or regression. + """ + + @classmethod + def _train(cls, data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth): + first = data.first() + assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint" + model = callMLlibFunc("trainGradientBoostedTreesModel", data, algo, categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + return GradientBoostedTreesModel(model) + + @classmethod + def trainClassifier(cls, data, categoricalFeaturesInfo, + loss="logLoss", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for classification. + + :param data: Training dataset: RDD of LabeledPoint. Labels should take values {0, 1}. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient boosting. + Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1] + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 3) + :return: GradientBoostedTreesModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> + >>> data = [ + ... LabeledPoint(0.0, [0.0]), + ... LabeledPoint(0.0, [1.0]), + ... LabeledPoint(1.0, [2.0]), + ... LabeledPoint(1.0, [3.0]) + ... ] + >>> + >>> model = GradientBoostedTrees.trainClassifier(sc.parallelize(data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 300 + >>> print model, # it already has newline + TreeEnsembleModel classifier with 100 trees + >>> model.predict([2.0]) + 1.0 + >>> model.predict([0.0]) + 0.0 + >>> rdd = sc.parallelize([[2.0], [0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "classification", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + @classmethod + def trainRegressor(cls, data, categoricalFeaturesInfo, + loss="leastSquaresError", numIterations=100, learningRate=0.1, maxDepth=3): + """ + Method to train a gradient-boosted trees model for regression. + + :param data: Training dataset: RDD of LabeledPoint. Labels are + real numbers. + :param categoricalFeaturesInfo: Map storing arity of categorical + features. E.g., an entry (n -> k) indicates that feature + n is categorical with k categories indexed from 0: + {0, 1, ..., k-1}. + :param loss: Loss function used for minimization during gradient boosting. + Supported: {"logLoss" (default), "leastSquaresError", "leastAbsoluteError"}. + :param numIterations: Number of iterations of boosting. + (default: 100) + :param learningRate: Learning rate for shrinking the contribution of each estimator. + The learning rate should be between in the interval (0, 1] + (default: 0.1) + :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 + leaf node; depth 1 means 1 internal node + 2 leaf nodes. + (default: 3) + :return: GradientBoostedTreesModel that can be used for prediction + + Example usage: + + >>> from pyspark.mllib.regression import LabeledPoint + >>> from pyspark.mllib.tree import GradientBoostedTrees + >>> from pyspark.mllib.linalg import SparseVector + >>> + >>> sparse_data = [ + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})), + ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})), + ... LabeledPoint(1.0, SparseVector(2, {1: 2.0})) + ... ] + >>> + >>> model = GradientBoostedTrees.trainRegressor(sc.parallelize(sparse_data), {}) + >>> model.numTrees() + 100 + >>> model.totalNumNodes() + 102 + >>> model.predict(SparseVector(2, {1: 1.0})) + 1.0 + >>> model.predict(SparseVector(2, {0: 1.0})) + 0.0 + >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]]) + >>> model.predict(rdd).collect() + [1.0, 0.0] + """ + return cls._train(data, "regression", categoricalFeaturesInfo, + loss, numIterations, learningRate, maxDepth) + + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/profiler.py b/python/pyspark/profiler.py new file mode 100644 index 0000000000000..4408996db0790 --- /dev/null +++ b/python/pyspark/profiler.py @@ -0,0 +1,172 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import cProfile +import pstats +import os +import atexit + +from pyspark.accumulators import AccumulatorParam + + +class ProfilerCollector(object): + """ + This class keeps track of different profilers on a per + stage basis. Also this is used to create new profilers for + the different stages. + """ + + def __init__(self, profiler_cls, dump_path=None): + self.profiler_cls = profiler_cls + self.profile_dump_path = dump_path + self.profilers = [] + + def new_profiler(self, ctx): + """ Create a new profiler using class `profiler_cls` """ + return self.profiler_cls(ctx) + + def add_profiler(self, id, profiler): + """ Add a profiler for RDD `id` """ + if not self.profilers: + if self.profile_dump_path: + atexit.register(self.dump_profiles, self.profile_dump_path) + else: + atexit.register(self.show_profiles) + + self.profilers.append([id, profiler, False]) + + def dump_profiles(self, path): + """ Dump the profile stats into directory `path` """ + for id, profiler, _ in self.profilers: + profiler.dump(id, path) + self.profilers = [] + + def show_profiles(self): + """ Print the profile stats to stdout """ + for i, (id, profiler, showed) in enumerate(self.profilers): + if not showed and profiler: + profiler.show(id) + # mark it as showed + self.profilers[i][2] = True + + +class Profiler(object): + """ + .. note:: DeveloperApi + + PySpark supports custom profilers, this is to allow for different profilers to + be used as well as outputting to different formats than what is provided in the + BasicProfiler. + + A custom profiler has to define or inherit the following methods: + profile - will produce a system profile of some sort. + stats - return the collected stats. + dump - dumps the profiles to a path + add - adds a profile to the existing accumulated profile + + The profiler class is chosen when creating a SparkContext + + >>> from pyspark import SparkConf, SparkContext + >>> from pyspark import BasicProfiler + >>> class MyCustomProfiler(BasicProfiler): + ... def show(self, id): + ... print "My custom profiles for RDD:%s" % id + ... + >>> conf = SparkConf().set("spark.python.profile", "true") + >>> sc = SparkContext('local', 'test', conf=conf, profiler_cls=MyCustomProfiler) + >>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) + [0, 2, 4, 6, 8, 10, 12, 14, 16, 18] + >>> sc.show_profiles() + My custom profiles for RDD:1 + My custom profiles for RDD:2 + >>> sc.stop() + """ + + def __init__(self, ctx): + pass + + def profile(self, func): + """ Do profiling on the function `func`""" + raise NotImplemented + + def stats(self): + """ Return the collected profiling stats (pstats.Stats)""" + raise NotImplemented + + def show(self, id): + """ Print the profile stats to stdout, id is the RDD id """ + stats = self.stats() + if stats: + print "=" * 60 + print "Profile of RDD" % id + print "=" * 60 + stats.sort_stats("time", "cumulative").print_stats() + + def dump(self, id, path): + """ Dump the profile into path, id is the RDD id """ + if not os.path.exists(path): + os.makedirs(path) + stats = self.stats() + if stats: + p = os.path.join(path, "rdd_%d.pstats" % id) + stats.dump_stats(p) + + +class PStatsParam(AccumulatorParam): + """PStatsParam is used to merge pstats.Stats""" + + @staticmethod + def zero(value): + return None + + @staticmethod + def addInPlace(value1, value2): + if value1 is None: + return value2 + value1.add(value2) + return value1 + + +class BasicProfiler(Profiler): + """ + BasicProfiler is the default profiler, which is implemented based on + cProfile and Accumulator + """ + def __init__(self, ctx): + Profiler.__init__(self, ctx) + # Creates a new accumulator for combining the profiles of different + # partitions of a stage + self._accumulator = ctx.accumulator(None, PStatsParam) + + def profile(self, func): + """ Runs and profiles the method to_profile passed in. A profile object is returned. """ + pr = cProfile.Profile() + pr.runcall(func) + st = pstats.Stats(pr) + st.stream = None # make it picklable + st.strip_dirs() + + # Adds a new profile to the existing accumulated value + self._accumulator.add(st) + + def stats(self): + return self._accumulator.value + + +if __name__ == "__main__": + import doctest + doctest.testmod() diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index f4cfe4845dc20..2f8a0edfe9644 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -29,9 +29,8 @@ import heapq import bisect import random -from math import sqrt, log, isinf, isnan +from math import sqrt, log, isinf, isnan, pow, ceil -from pyspark.accumulators import PStatsParam from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ BatchedSerializer, CloudPickleSerializer, PairDeserializer, \ PickleSerializer, pack_long, AutoBatchedSerializer @@ -727,6 +726,43 @@ def func(iterator): return reduce(f, vals) raise ValueError("Can not reduce() empty RDD") + def treeReduce(self, f, depth=2): + """ + Reduces the elements of this RDD in a multi-level tree pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeReduce(add) + -5 + >>> rdd.treeReduce(add, 1) + -5 + >>> rdd.treeReduce(add, 2) + -5 + >>> rdd.treeReduce(add, 5) + -5 + >>> rdd.treeReduce(add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + zeroValue = None, True # Use the second entry to indicate whether this is a dummy value. + + def op(x, y): + if x[1]: + return y + elif y[1]: + return x + else: + return f(x[0], y[0]), False + + reduced = self.map(lambda x: (x, False)).treeAggregate(zeroValue, op, op, depth) + if reduced[1]: + raise ValueError("Cannot reduce empty RDD.") + return reduced[0] + def fold(self, zeroValue, op): """ Aggregate the elements of each partition, and then the results for all @@ -778,6 +814,58 @@ def func(iterator): return self.mapPartitions(func).fold(zeroValue, combOp) + def treeAggregate(self, zeroValue, seqOp, combOp, depth=2): + """ + Aggregates the elements of this RDD in a multi-level tree + pattern. + + :param depth: suggested depth of the tree (default: 2) + + >>> add = lambda x, y: x + y + >>> rdd = sc.parallelize([-5, -4, -3, -2, -1, 1, 2, 3, 4], 10) + >>> rdd.treeAggregate(0, add, add) + -5 + >>> rdd.treeAggregate(0, add, add, 1) + -5 + >>> rdd.treeAggregate(0, add, add, 2) + -5 + >>> rdd.treeAggregate(0, add, add, 5) + -5 + >>> rdd.treeAggregate(0, add, add, 10) + -5 + """ + if depth < 1: + raise ValueError("Depth cannot be smaller than 1 but got %d." % depth) + + if self.getNumPartitions() == 0: + return zeroValue + + def aggregatePartition(iterator): + acc = zeroValue + for obj in iterator: + acc = seqOp(acc, obj) + yield acc + + partiallyAggregated = self.mapPartitions(aggregatePartition) + numPartitions = partiallyAggregated.getNumPartitions() + scale = max(int(ceil(pow(numPartitions, 1.0 / depth))), 2) + # If creating an extra level doesn't help reduce the wall-clock time, we stop the tree + # aggregation. + while numPartitions > scale + numPartitions / scale: + numPartitions /= scale + curNumPartitions = numPartitions + + def mapPartition(i, iterator): + for obj in iterator: + yield (i % curNumPartitions, obj) + + partiallyAggregated = partiallyAggregated \ + .mapPartitionsWithIndex(mapPartition) \ + .reduceByKey(combOp, curNumPartitions) \ + .values() + + return partiallyAggregated.reduce(combOp) + def max(self, key=None): """ Find the maximum item in this RDD. @@ -1634,8 +1722,8 @@ def groupByKey(self, numPartitions=None): Hash-partitions the resulting RDD with into numPartitions partitions. Note: If you are grouping in order to perform an aggregation (such as a - sum or average) over each key, using reduceByKey will provide much - better performance. + sum or average) over each key, using reduceByKey or aggregateByKey will + provide much better performance. >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> map((lambda (x,y): (x, list(y))), sorted(x.groupByKey().collect())) @@ -2059,6 +2147,20 @@ def countApproxDistinct(self, relativeSD=0.05): hashRDD = self.map(lambda x: portable_hash(x) & 0xFFFFFFFF) return hashRDD._to_java_object_rdd().countApproxDistinct(relativeSD) + def toLocalIterator(self): + """ + Return an iterator that contains all of the elements in this RDD. + The iterator will consume as much memory as the largest partition in this RDD. + >>> rdd = sc.parallelize(range(10)) + >>> [x for x in rdd.toLocalIterator()] + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + """ + partitions = xrange(self.getNumPartitions()) + for partition in partitions: + rows = self.context.runJob(self, lambda x: x, [partition]) + for row in rows: + yield row + class PipelinedRDD(RDD): @@ -2118,9 +2220,13 @@ def _jrdd(self): return self._jrdd_val if self._bypass_serializer: self._jrdd_deserializer = NoOpSerializer() - enable_profile = self.ctx._conf.get("spark.python.profile", "false") == "true" - profileStats = self.ctx.accumulator(None, PStatsParam) if enable_profile else None - command = (self.func, profileStats, self._prev_jrdd_deserializer, + + if self.ctx.profiler_collector: + profiler = self.ctx.profiler_collector.new_profiler(self.ctx) + else: + profiler = None + + command = (self.func, profiler, self._prev_jrdd_deserializer, self._jrdd_deserializer) # the serialized command will be compressed by broadcast ser = CloudPickleSerializer() @@ -2143,9 +2249,9 @@ def _jrdd(self): broadcast_vars, self.ctx._javaAccumulator) self._jrdd_val = python_rdd.asJavaRDD() - if enable_profile: + if profiler: self._id = self._jrdd_val.id() - self.ctx._add_profile(self._id, profileStats) + self.ctx.profiler_collector.add_profiler(self._id, profiler) return self._jrdd_val def id(self): diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 1990323249cf6..3f2d7ac82585f 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -20,15 +20,19 @@ - L{SQLContext} Main entry point for SQL functionality. - - L{SchemaRDD} + - L{DataFrame} A Resilient Distributed Dataset (RDD) with Schema information for the data contained. In - addition to normal RDD operations, SchemaRDDs also support SQL. + addition to normal RDD operations, DataFrames also support SQL. + - L{GroupedDataFrame} + - L{Column} + Column is a DataFrame with a single column. - L{Row} A Row of data returned by a Spark SQL query. - L{HiveContext} Main entry point for accessing data stored in Apache Hive.. """ +import sys import itertools import decimal import datetime @@ -36,6 +40,9 @@ import warnings import json import re +import random +import os +from tempfile import NamedTemporaryFile from array import array from operator import itemgetter from itertools import imap @@ -43,6 +50,7 @@ from py4j.protocol import Py4JError from py4j.java_collections import ListConverter, MapConverter +from pyspark.context import SparkContext from pyspark.rdd import RDD from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \ CloudPickleSerializer, UTF8Deserializer @@ -54,7 +62,8 @@ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", - "SQLContext", "HiveContext", "SchemaRDD", "Row"] + "SQLContext", "HiveContext", "DataFrame", "GroupedDataFrame", "Column", "Row", + "SchemaRDD"] class DataType(object): @@ -922,7 +931,7 @@ def _parse_schema_abstract(s): def _infer_schema_type(obj, dataType): """ - Fill the dataType with types infered from obj + Fill the dataType with types inferred from obj >>> schema = _parse_schema_abstract("a b c d") >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) @@ -1171,7 +1180,7 @@ def Dict(d): class Row(tuple): - """ Row in SchemaRDD """ + """ Row in DataFrame """ __DATATYPE__ = dataType __FIELDS__ = tuple(f.name for f in dataType.fields) __slots__ = () @@ -1198,7 +1207,7 @@ class SQLContext(object): """Main entry point for Spark SQL functionality. - A SQLContext can be used create L{SchemaRDD}, register L{SchemaRDD} as + A SQLContext can be used create L{DataFrame}, register L{DataFrame} as tables, execute SQL over tables, cache tables, and read parquet files. """ @@ -1209,8 +1218,8 @@ def __init__(self, sparkContext, sqlContext=None): :param sqlContext: An optional JVM Scala SQLContext. If set, we do not instatiate a new SQLContext in the JVM, instead we make all calls to this object. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.inferSchema(srdd) # doctest: +IGNORE_EXCEPTION_DETAIL + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.inferSchema(df) # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... TypeError:... @@ -1225,12 +1234,12 @@ def __init__(self, sparkContext, sqlContext=None): >>> allTypes = sc.parallelize([Row(i=1, s="string", d=1.0, l=1L, ... b=True, list=[1, 2, 3], dict={"s": 0}, row=Row(a=1), ... time=datetime(2014, 8, 1, 14, 1, 5))]) - >>> srdd = sqlCtx.inferSchema(allTypes) - >>> srdd.registerTempTable("allTypes") + >>> df = sqlCtx.inferSchema(allTypes) + >>> df.registerTempTable("allTypes") >>> sqlCtx.sql('select i+1, d+1, not b, list[1], dict["s"], time, row.a ' ... 'from allTypes where b and i > 0').collect() [Row(c0=2, c1=2.0, c2=False, c3=2, c4=0...8, 1, 14, 1, 5), a=1)] - >>> srdd.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, + >>> df.map(lambda x: (x.i, x.s, x.d, x.l, x.b, x.time, ... x.row.a, x.list)).collect() [(1, u'string', 1.0, 1, True, ...(2014, 8, 1, 14, 1, 5), 1, [1, 2, 3])] """ @@ -1309,23 +1318,23 @@ def inferSchema(self, rdd, samplingRatio=None): ... [Row(field1=1, field2="row1"), ... Row(field1=2, field2="row2"), ... Row(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') >>> NestedRow = Row("f1", "f2") >>> nestedRdd1 = sc.parallelize([ ... NestedRow(array('i', [1, 2]), {"row1": 1.0}), ... NestedRow(array('i', [2, 3]), {"row2": 2.0})]) - >>> srdd = sqlCtx.inferSchema(nestedRdd1) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd1) + >>> df.collect() [Row(f1=[1, 2], f2={u'row1': 1.0}), ..., f2={u'row2': 2.0})] >>> nestedRdd2 = sc.parallelize([ ... NestedRow([[1, 2], [2, 3]], [1, 2]), ... NestedRow([[2, 3], [3, 4]], [2, 3])]) - >>> srdd = sqlCtx.inferSchema(nestedRdd2) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(nestedRdd2) + >>> df.collect() [Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), ..., f2=[2, 3])] >>> from collections import namedtuple @@ -1334,13 +1343,13 @@ def inferSchema(self, rdd, samplingRatio=None): ... [CustomRow(field1=1, field2="row1"), ... CustomRow(field1=2, field2="row2"), ... CustomRow(field1=3, field2="row3")]) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect()[0] + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect()[0] Row(field1=1, field2=u'row1') """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") first = rdd.first() if not first: @@ -1384,10 +1393,10 @@ def applySchema(self, rdd, schema): >>> rdd2 = sc.parallelize([(1, "row1"), (2, "row2"), (3, "row3")]) >>> schema = StructType([StructField("field1", IntegerType(), False), ... StructField("field2", StringType(), False)]) - >>> srdd = sqlCtx.applySchema(rdd2, schema) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT * from table1") - >>> srdd2.collect() + >>> df = sqlCtx.applySchema(rdd2, schema) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT * from table1") + >>> df2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] >>> from datetime import date, datetime @@ -1410,15 +1419,15 @@ def applySchema(self, rdd, schema): ... StructType([StructField("b", ShortType(), False)]), False), ... StructField("list", ArrayType(ByteType(), False), False), ... StructField("null", DoubleType(), True)]) - >>> srdd = sqlCtx.applySchema(rdd, schema) - >>> results = srdd.map( + >>> df = sqlCtx.applySchema(rdd, schema) + >>> results = df.map( ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, ... x.time, x.map["a"], x.struct.b, x.list, x.null)) >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) - >>> srdd.registerTempTable("table2") + >>> df.registerTempTable("table2") >>> sqlCtx.sql( ... "SELECT byte1 - 1 AS byte1, byte2 + 1 AS byte2, " + ... "short1 + 1 AS short1, short2 - 1 AS short2, int - 1 AS int, " + @@ -1431,13 +1440,13 @@ def applySchema(self, rdd, schema): >>> abstract = "byte short float time map{} struct(b) list[]" >>> schema = _parse_schema_abstract(abstract) >>> typedSchema = _infer_schema_type(rdd.first(), schema) - >>> srdd = sqlCtx.applySchema(rdd, typedSchema) - >>> srdd.collect() + >>> df = sqlCtx.applySchema(rdd, typedSchema) + >>> df.collect() [Row(byte=127, short=-32768, float=1.0, time=..., list=[1, 2, 3])] """ - if isinstance(rdd, SchemaRDD): - raise TypeError("Cannot apply schema to SchemaRDD") + if isinstance(rdd, DataFrame): + raise TypeError("Cannot apply schema to DataFrame") if not isinstance(schema, StructType): raise TypeError("schema should be StructType") @@ -1457,8 +1466,8 @@ def applySchema(self, rdd, schema): rdd = rdd.map(converter) jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd()) - srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json()) + return DataFrame(df, self) def registerRDDAsTable(self, rdd, tableName): """Registers the given RDD as a temporary table in the catalog. @@ -1466,34 +1475,34 @@ def registerRDDAsTable(self, rdd, tableName): Temporary tables exist only during the lifetime of this instance of SQLContext. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") """ - if (rdd.__class__ is SchemaRDD): - srdd = rdd._jschema_rdd.baseSchemaRDD() - self._ssql_ctx.registerRDDAsTable(srdd, tableName) + if (rdd.__class__ is DataFrame): + df = rdd._jdf + self._ssql_ctx.registerRDDAsTable(df, tableName) else: - raise ValueError("Can only register SchemaRDD as table") + raise ValueError("Can only register DataFrame as table") def parquetFile(self, path): - """Loads a Parquet file, returning the result as a L{SchemaRDD}. + """Loads a Parquet file, returning the result as a L{DataFrame}. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - jschema_rdd = self._ssql_ctx.parquetFile(path) - return SchemaRDD(jschema_rdd, self) + jdf = self._ssql_ctx.parquetFile(path) + return DataFrame(jdf, self) def jsonFile(self, path, schema=None, samplingRatio=1.0): """ Loads a text file storing one JSON object per line as a - L{SchemaRDD}. + L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1508,23 +1517,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): >>> for json in jsonStrings: ... print>>ofn, json >>> ofn.close() - >>> srdd1 = sqlCtx.jsonFile(jsonFile) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonFile(jsonFile) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonFile(jsonFile, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonFile(jsonFile, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22,..., f4=[Row(field7=u'row2')]) @@ -1536,23 +1545,23 @@ def jsonFile(self, path, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonFile(jsonFile, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonFile(jsonFile, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)] """ if schema is None: - srdd = self._ssql_ctx.jsonFile(path, samplingRatio) + df = self._ssql_ctx.jsonFile(path, samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonFile(path, scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonFile(path, scala_datatype) + return DataFrame(df, self) def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a L{SchemaRDD}. + """Loads an RDD storing one JSON object per string as a L{DataFrame}. If the schema is provided, applies the given schema to this JSON dataset. @@ -1560,23 +1569,23 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): Otherwise, it samples the dataset with ratio `samplingRatio` to determine the schema. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table1") - >>> for r in srdd2.collect(): + >>> for r in df2.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) Row(f1=None, f2=u'row3', f3=Row(field4=33, field5=[]), f4=None) - >>> srdd3 = sqlCtx.jsonRDD(json, srdd1.schema()) - >>> sqlCtx.registerRDDAsTable(srdd3, "table2") - >>> srdd4 = sqlCtx.sql( + >>> df3 = sqlCtx.jsonRDD(json, df1.schema()) + >>> sqlCtx.registerRDDAsTable(df3, "table2") + >>> df4 = sqlCtx.sql( ... "SELECT field1 AS f1, field2 as f2, field3 as f3, " ... "field6 as f4 from table2") - >>> for r in srdd4.collect(): + >>> for r in df4.collect(): ... print r Row(f1=1, f2=u'row1', f3=Row(field4=11, field5=None), f4=None) Row(f1=2, f2=None, f3=Row(field4=22..., f4=[Row(field7=u'row2')]) @@ -1588,12 +1597,12 @@ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): ... StructType([ ... StructField("field5", ... ArrayType(IntegerType(), False), True)]), False)]) - >>> srdd5 = sqlCtx.jsonRDD(json, schema) - >>> sqlCtx.registerRDDAsTable(srdd5, "table3") - >>> srdd6 = sqlCtx.sql( + >>> df5 = sqlCtx.jsonRDD(json, schema) + >>> sqlCtx.registerRDDAsTable(df5, "table3") + >>> df6 = sqlCtx.sql( ... "SELECT field2 AS f1, field3.field5 as f2, " ... "field3.field5[0] as f3 from table3") - >>> srdd6.collect() + >>> df6.collect() [Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)] >>> sqlCtx.jsonRDD(sc.parallelize(['{}', @@ -1615,33 +1624,33 @@ def func(iterator): keyed._bypass_serializer = True jrdd = keyed._jrdd.map(self._jvm.BytesToString()) if schema is None: - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) else: scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return SchemaRDD(srdd, self) + df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) + return DataFrame(df, self) def sql(self, sqlQuery): - """Return a L{SchemaRDD} representing the result of the given query. + """Return a L{DataFrame} representing the result of the given query. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") - >>> srdd2.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.sql("SELECT field1 AS f1, field2 as f2 from table1") + >>> df2.collect() [Row(f1=1, f2=u'row1'), Row(f1=2, f2=u'row2'), Row(f1=3, f2=u'row3')] """ - return SchemaRDD(self._ssql_ctx.sql(sqlQuery), self) + return DataFrame(self._ssql_ctx.sql(sqlQuery), self) def table(self, tableName): - """Returns the specified table as a L{SchemaRDD}. + """Returns the specified table as a L{DataFrame}. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> sqlCtx.registerRDDAsTable(srdd, "table1") - >>> srdd2 = sqlCtx.table("table1") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> sqlCtx.registerRDDAsTable(df, "table1") + >>> df2 = sqlCtx.table("table1") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - return SchemaRDD(self._ssql_ctx.table(tableName), self) + return DataFrame(self._ssql_ctx.table(tableName), self) def cacheTable(self, tableName): """Caches the specified table in-memory.""" @@ -1707,7 +1716,7 @@ def _create_row(fields, values): class Row(tuple): """ - A row in L{SchemaRDD}. The fields in it can be accessed like attributes. + A row in L{DataFrame}. The fields in it can be accessed like attributes. Row can be used to create a row object by using named arguments, the fields will be sorted by names. @@ -1785,125 +1794,119 @@ def __repr__(self): return "" % ", ".join(self) -def inherit_doc(cls): - for name, func in vars(cls).items(): - # only inherit docstring for public functions - if name.startswith("_"): - continue - if not func.__doc__: - for parent in cls.__bases__: - parent_func = getattr(parent, name, None) - if parent_func and getattr(parent_func, "__doc__", None): - func.__doc__ = parent_func.__doc__ - break - return cls +class DataFrame(object): + """A collection of rows that have the same columns. -@inherit_doc -class SchemaRDD(RDD): + A :class:`DataFrame` is equivalent to a relational table in Spark SQL, + and can be created using various functions in :class:`SQLContext`:: - """An RDD of L{Row} objects that has an associated schema. + people = sqlContext.parquetFile("...") - The underlying JVM object is a SchemaRDD, not a PythonRDD, so we can - utilize the relational query api exposed by Spark SQL. + Once created, it can be manipulated using the various domain-specific-language + (DSL) functions defined in: [[DataFrame]], [[Column]]. - For normal L{pyspark.rdd.RDD} operations (map, count, etc.) the - L{SchemaRDD} is not operated on directly, as it's underlying - implementation is an RDD composed of Java objects. Instead it is - converted to a PythonRDD in the JVM, on which Python operations can - be done. + To select a column from the data frame, use the apply method:: - This class receives raw tuples from Java but assigns a class to it in - all its data-collection methods (mapPartitionsWithIndex, collect, take, - etc) so that PySpark sees them as Row objects with named fields. + ageCol = people.age + + Note that the :class:`Column` type can also be manipulated + through its various functions:: + + # The following creates a new column that increases everybody's age by 10. + people.age + 10 + + + A more concrete example:: + + # To create DataFrame using SQLContext + people = sqlContext.parquetFile("...") + department = sqlContext.parquetFile("...") + + people.filter(people.age > 30).join(department, people.deptId == department.id)) \ + .groupBy(department.name, "gender").agg({"salary": "avg", "age": "max"}) """ - def __init__(self, jschema_rdd, sql_ctx): + def __init__(self, jdf, sql_ctx): + self._jdf = jdf self.sql_ctx = sql_ctx - self._sc = sql_ctx._sc - clsName = jschema_rdd.getClass().getName() - assert clsName.endswith("SchemaRDD"), "jschema_rdd must be SchemaRDD" - self._jschema_rdd = jschema_rdd - self._id = None + self._sc = sql_ctx and sql_ctx._sc self.is_cached = False - self.is_checkpointed = False - self.ctx = self.sql_ctx._sc - # the _jrdd is created by javaToPython(), serialized by pickle - self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer()) @property - def _jrdd(self): - """Lazy evaluation of PythonRDD object. + def rdd(self): + """Return the content of the :class:`DataFrame` as an :class:`RDD` + of :class:`Row`s. """ + if not hasattr(self, '_lazy_rdd'): + jrdd = self._jdf.javaToPython() + rdd = RDD(jrdd, self.sql_ctx._sc, BatchedSerializer(PickleSerializer())) + schema = self.schema() - Only done when a user calls methods defined by the - L{pyspark.rdd.RDD} super class (map, filter, etc.). - """ - if not hasattr(self, '_lazy_jrdd'): - self._lazy_jrdd = self._jschema_rdd.baseSchemaRDD().javaToPython() - return self._lazy_jrdd + def applySchema(it): + cls = _create_cls(schema) + return itertools.imap(cls, it) - def id(self): - if self._id is None: - self._id = self._jrdd.id() - return self._id + self._lazy_rdd = rdd.mapPartitions(applySchema) + + return self._lazy_rdd def limit(self, num): """Limit the result count to the number specified. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.limit(2).collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.limit(2).collect() [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] - >>> srdd.limit(0).collect() + >>> df.limit(0).collect() [] """ - rdd = self._jschema_rdd.baseSchemaRDD().limit(num) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.limit(num) + return DataFrame(jdf, self.sql_ctx) def toJSON(self, use_unicode=False): - """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row. + """Convert a DataFrame into a MappedRDD of JSON documents; one document per row. - >>> srdd1 = sqlCtx.jsonRDD(json) - >>> sqlCtx.registerRDDAsTable(srdd1, "table1") - >>> srdd2 = sqlCtx.sql( "SELECT * from table1") - >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' + >>> df1 = sqlCtx.jsonRDD(json) + >>> sqlCtx.registerRDDAsTable(df1, "table1") + >>> df2 = sqlCtx.sql( "SELECT * from table1") + >>> df2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}' True - >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1") - >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] + >>> df3 = sqlCtx.sql( "SELECT field3.field4 from table1") + >>> df3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}'] True """ - rdd = self._jschema_rdd.baseSchemaRDD().toJSON() + rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) def saveAsParquetFile(self, path): """Save the contents as a Parquet file, preserving the schema. Files that are written out using this method can be read back in as - a SchemaRDD using the L{SQLContext.parquetFile} method. + a DataFrame using the L{SQLContext.parquetFile} method. >>> import tempfile, shutil >>> parquetFile = tempfile.mkdtemp() >>> shutil.rmtree(parquetFile) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.saveAsParquetFile(parquetFile) - >>> srdd2 = sqlCtx.parquetFile(parquetFile) - >>> sorted(srdd2.collect()) == sorted(srdd.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.saveAsParquetFile(parquetFile) + >>> df2 = sqlCtx.parquetFile(parquetFile) + >>> sorted(df2.collect()) == sorted(df.collect()) True """ - self._jschema_rdd.saveAsParquetFile(path) + self._jdf.saveAsParquetFile(path) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. The lifetime of this temporary table is tied to the L{SQLContext} - that was used to create this SchemaRDD. + that was used to create this DataFrame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.registerTempTable("test") - >>> srdd2 = sqlCtx.sql("select * from test") - >>> sorted(srdd.collect()) == sorted(srdd2.collect()) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.registerTempTable("test") + >>> df2 = sqlCtx.sql("select * from test") + >>> sorted(df.collect()) == sorted(df2.collect()) True """ - self._jschema_rdd.registerTempTable(name) + self._jdf.registerTempTable(name) def registerAsTable(self, name): """DEPRECATED: use registerTempTable() instead""" @@ -1911,62 +1914,61 @@ def registerAsTable(self, name): self.registerTempTable(name) def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this SchemaRDD into the specified table. + """Inserts the contents of this DataFrame into the specified table. Optionally overwriting any existing data. """ - self._jschema_rdd.insertInto(tableName, overwrite) + self._jdf.insertInto(tableName, overwrite) def saveAsTable(self, tableName): - """Creates a new table with the contents of this SchemaRDD.""" - self._jschema_rdd.saveAsTable(tableName) + """Creates a new table with the contents of this DataFrame.""" + self._jdf.saveAsTable(tableName) def schema(self): - """Returns the schema of this SchemaRDD (represented by + """Returns the schema of this DataFrame (represented by a L{StructType}).""" - return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json()) - - def schemaString(self): - """Returns the output schema in the tree format.""" - return self._jschema_rdd.schemaString() + return _parse_datatype_json_string(self._jdf.schema().json()) def printSchema(self): """Prints out the schema in the tree format.""" - print self.schemaString() + print (self._jdf.schema().treeString()) def count(self): """Return the number of elements in this RDD. Unlike the base RDD implementation of count, this implementation - leverages the query optimizer to compute the count on the SchemaRDD, + leverages the query optimizer to compute the count on the DataFrame, which supports features such as filter pushdown. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.count() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.count() 3L - >>> srdd.count() == srdd.map(lambda x: x).count() + >>> df.count() == df.map(lambda x: x).count() True """ - return self._jschema_rdd.count() + return self._jdf.count() def collect(self): - """Return a list that contains all of the rows in this RDD. + """Return a list that contains all of the rows. Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of collect, this implementation - leverages the query optimizer to perform a collect on the SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.collect() + >>> df = sqlCtx.inferSchema(rdd) + >>> df.collect() [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')] """ - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator() + with SCCallSiteSync(self._sc) as css: + bytesInJava = self._jdf.javaToPython().collect().iterator() cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir) + tempFile.close() + self._sc._writeToFile(bytesInJava, tempFile.name) + # Read the data into Python and deserialize it: + with open(tempFile.name, 'rb') as tempFile: + rs = list(BatchedSerializer(PickleSerializer()).load_stream(tempFile)) + os.unlink(tempFile.name) + return [cls(r) for r in rs] def take(self, num): """Take the first num rows of the RDD. @@ -1974,130 +1976,561 @@ def take(self, num): Each object in the list is a Row, the fields can be accessed as attributes. - Unlike the base RDD implementation of take, this implementation - leverages the query optimizer to perform a collect on a SchemaRDD, - which supports features such as filter pushdown. - - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.take(2) + >>> df = sqlCtx.inferSchema(rdd) + >>> df.take(2) [Row(field1=1, field2=u'row1'), Row(field1=2, field2=u'row2')] """ return self.limit(num).collect() - # Convert each object in the RDD to a Row with the right class - # for this SchemaRDD, so that fields can be accessed as attributes. - def mapPartitionsWithIndex(self, f, preservesPartitioning=False): + def map(self, f): + """ Return a new RDD by applying a function to each Row, it's a + shorthand for df.rdd.map() """ - Return a new RDD by applying a function to each partition of this RDD, - while tracking the index of the original partition. + return self.rdd.map(f) - >>> rdd = sc.parallelize([1, 2, 3, 4], 4) - >>> def f(splitIndex, iterator): yield splitIndex - >>> rdd.mapPartitionsWithIndex(f).sum() - 6 + def mapPartitions(self, f, preservesPartitioning=False): """ - rdd = RDD(self._jrdd, self._sc, self._jrdd_deserializer) - - schema = self.schema() + Return a new RDD by applying a function to each partition. - def applySchema(_, it): - cls = _create_cls(schema) - return itertools.imap(cls, it) - - objrdd = rdd.mapPartitionsWithIndex(applySchema, preservesPartitioning) - return objrdd.mapPartitionsWithIndex(f, preservesPartitioning) + >>> rdd = sc.parallelize([1, 2, 3, 4], 4) + >>> def f(iterator): yield 1 + >>> rdd.mapPartitions(f).sum() + 4 + """ + return self.rdd.mapPartitions(f, preservesPartitioning) - # We override the default cache/persist/checkpoint behavior - # as we want to cache the underlying SchemaRDD object in the JVM, - # not the PythonRDD checkpointed by the super class def cache(self): + """ Persist with the default storage level (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - self._jschema_rdd.cache() + self._jdf.cache() return self def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + """ Set the storage level to persist its values across operations + after the first time it is computed. This can only be used to assign + a new storage level if the RDD does not have a storage level set yet. + If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + """ self.is_cached = True - javaStorageLevel = self.ctx._getJavaStorageLevel(storageLevel) - self._jschema_rdd.persist(javaStorageLevel) + javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel) + self._jdf.persist(javaStorageLevel) return self def unpersist(self, blocking=True): + """ Mark it as non-persistent, and remove all blocks for it from + memory and disk. + """ self.is_cached = False - self._jschema_rdd.unpersist(blocking) + self._jdf.unpersist(blocking) return self - def checkpoint(self): - self.is_checkpointed = True - self._jschema_rdd.checkpoint() + # def coalesce(self, numPartitions, shuffle=False): + # rdd = self._jdf.coalesce(numPartitions, shuffle, None) + # return DataFrame(rdd, self.sql_ctx) + + def repartition(self, numPartitions): + """ Return a new :class:`DataFrame` that has exactly `numPartitions` + partitions. + """ + rdd = self._jdf.repartition(numPartitions, None) + return DataFrame(rdd, self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ + Return a sampled subset of this DataFrame. + + >>> df = sqlCtx.inferSchema(rdd) + >>> df.sample(False, 0.5, 97).count() + 2L + """ + assert fraction >= 0.0, "Negative fraction value: %s" % fraction + seed = seed if seed is not None else random.randint(0, sys.maxint) + rdd = self._jdf.sample(withReplacement, fraction, long(seed)) + return DataFrame(rdd, self.sql_ctx) + + # def takeSample(self, withReplacement, num, seed=None): + # """Return a fixed-size sampled subset of this DataFrame. + # + # >>> df = sqlCtx.inferSchema(rdd) + # >>> df.takeSample(False, 2, 97) + # [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + # """ + # seed = seed if seed is not None else random.randint(0, sys.maxint) + # with SCCallSiteSync(self.context) as css: + # bytesInJava = self._jdf \ + # .takeSampleToPython(withReplacement, num, long(seed)) \ + # .iterator() + # cls = _create_cls(self.schema()) + # return map(cls, self._collect_iterator_through_file(bytesInJava)) + + @property + def dtypes(self): + """Return all column names and their data types as a list. + """ + return [(f.name, str(f.dataType)) for f in self.schema().fields] - def isCheckpointed(self): - return self._jschema_rdd.isCheckpointed() + @property + def columns(self): + """ Return all column names as a list. + """ + return [f.name for f in self.schema().fields] - def getCheckpointFile(self): - checkpointFile = self._jschema_rdd.getCheckpointFile() - if checkpointFile.isDefined(): - return checkpointFile.get() + def show(self): + raise NotImplemented - def coalesce(self, numPartitions, shuffle=False): - rdd = self._jschema_rdd.coalesce(numPartitions, shuffle, None) - return SchemaRDD(rdd, self.sql_ctx) + def join(self, other, joinExprs=None, joinType=None): + """ + Join with another DataFrame, using the given join expression. + The following performs a full outer join between `df1` and `df2`:: - def distinct(self, numPartitions=None): - if numPartitions is None: - rdd = self._jschema_rdd.distinct() + df1.join(df2, df1.key == df2.key, "outer") + + :param other: Right side of the join + :param joinExprs: Join expression + :param joinType: One of `inner`, `outer`, `left_outer`, `right_outer`, + `semijoin`. + """ + if joinType is None: + if joinExprs is None: + jdf = self._jdf.join(other._jdf) + else: + jdf = self._jdf.join(other._jdf, joinExprs) else: - rdd = self._jschema_rdd.distinct(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + jdf = self._jdf.join(other._jdf, joinExprs, joinType) + return DataFrame(jdf, self.sql_ctx) + + def sort(self, *cols): + """ Return a new [[DataFrame]] sorted by the specified column, + in ascending column. - def intersection(self, other): - if (other.__class__ is SchemaRDD): - rdd = self._jschema_rdd.intersection(other._jschema_rdd) - return SchemaRDD(rdd, self.sql_ctx) + :param cols: The columns or expressions used for sorting + """ + if not cols: + raise ValueError("should sort by at least one column") + for i, c in enumerate(cols): + if isinstance(c, basestring): + cols[i] = Column(c) + jcols = [c._jc for c in cols] + jdf = self._jdf.join(*jcols) + return DataFrame(jdf, self.sql_ctx) + + sortBy = sort + + def head(self, n=None): + """ Return the first `n` rows or the first row if n is None. """ + if n is None: + rs = self.head(1) + return rs[0] if rs else None + return self.take(n) + + def tail(self): + raise NotImplemented + + def __getitem__(self, item): + if isinstance(item, basestring): + return Column(self._jdf.apply(item)) + + # TODO projection + raise IndexError + + def __getattr__(self, name): + """ Return the column by given name """ + if name.startswith("__"): + raise AttributeError(name) + return Column(self._jdf.apply(name)) + + def alias(self, name): + """ Alias the current DataFrame """ + return DataFrame(getattr(self._jdf, "as")(name), self.sql_ctx) + + def select(self, *cols): + """ Selecting a set of expressions.:: + + df.select() + df.select('colA', 'colB') + df.select(df.colA, df.colB + 1) + + """ + if not cols: + cols = ["*"] + if isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only intersect with another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.select(self._jdf.toColumnArray(jcols)) + return DataFrame(jdf, self.sql_ctx) - def repartition(self, numPartitions): - rdd = self._jschema_rdd.repartition(numPartitions, None) - return SchemaRDD(rdd, self.sql_ctx) + def filter(self, condition): + """ Filtering rows using the given condition:: - def subtract(self, other, numPartitions=None): - if (other.__class__ is SchemaRDD): - if numPartitions is None: - rdd = self._jschema_rdd.subtract(other._jschema_rdd) - else: - rdd = self._jschema_rdd.subtract(other._jschema_rdd, - numPartitions) - return SchemaRDD(rdd, self.sql_ctx) + df.filter(df.age > 15) + df.where(df.age > 15) + + """ + return DataFrame(self._jdf.filter(condition._jc), self.sql_ctx) + + where = filter + + def groupBy(self, *cols): + """ Group the [[DataFrame]] using the specified columns, + so we can run aggregation on them. See :class:`GroupedDataFrame` + for all the available aggregate functions:: + + df.groupBy(df.department).avg() + df.groupBy("department", "gender").agg({ + "salary": "avg", + "age": "max", + }) + """ + if cols and isinstance(cols[0], basestring): + cols = [_create_column_from_name(n) for n in cols] else: - raise ValueError("Can only subtract another SchemaRDD") + cols = [c._jc for c in cols] + jcols = ListConverter().convert(cols, self._sc._gateway._gateway_client) + jdf = self._jdf.groupBy(self._jdf.toColumnArray(jcols)) + return GroupedDataFrame(jdf, self.sql_ctx) - def sample(self, withReplacement, fraction, seed=None): + def agg(self, *exprs): + """ Aggregate on the entire [[DataFrame]] without groups + (shorthand for df.groupBy.agg()):: + + df.agg({"age": "max", "salary": "avg"}) """ - Return a sampled subset of this SchemaRDD. + return self.groupBy().agg(*exprs) - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.sample(False, 0.5, 97).count() - 2L + def unionAll(self, other): + """ Return a new DataFrame containing union of rows in this + frame and another frame. + + This is equivalent to `UNION ALL` in SQL. """ - assert fraction >= 0.0, "Negative fraction value: %s" % fraction - seed = seed if seed is not None else random.randint(0, sys.maxint) - rdd = self._jschema_rdd.sample(withReplacement, fraction, long(seed)) - return SchemaRDD(rdd, self.sql_ctx) + return DataFrame(self._jdf.unionAll(other._jdf), self.sql_ctx) - def takeSample(self, withReplacement, num, seed=None): - """Return a fixed-size sampled subset of this SchemaRDD. + def intersect(self, other): + """ Return a new [[DataFrame]] containing rows only in + both this frame and another frame. - >>> srdd = sqlCtx.inferSchema(rdd) - >>> srdd.takeSample(False, 2, 97) - [Row(field1=3, field2=u'row3'), Row(field1=1, field2=u'row1')] + This is equivalent to `INTERSECT` in SQL. """ - seed = seed if seed is not None else random.randint(0, sys.maxint) - with SCCallSiteSync(self.context) as css: - bytesInJava = self._jschema_rdd.baseSchemaRDD() \ - .takeSampleToPython(withReplacement, num, long(seed)) \ - .iterator() - cls = _create_cls(self.schema()) - return map(cls, self._collect_iterator_through_file(bytesInJava)) + return DataFrame(self._jdf.intersect(other._jdf), self.sql_ctx) + + def subtract(self, other): + """ Return a new [[DataFrame]] containing rows in this frame + but not in another frame. + + This is equivalent to `EXCEPT` in SQL. + """ + return DataFrame(getattr(self._jdf, "except")(other._jdf), self.sql_ctx) + + def sample(self, withReplacement, fraction, seed=None): + """ Return a new DataFrame by sampling a fraction of rows. """ + if seed is None: + jdf = self._jdf.sample(withReplacement, fraction) + else: + jdf = self._jdf.sample(withReplacement, fraction, seed) + return DataFrame(jdf, self.sql_ctx) + + def addColumn(self, colName, col): + """ Return a new [[DataFrame]] by adding a column. """ + return self.select('*', col.alias(colName)) + + def removeColumn(self, colName): + raise NotImplemented + + +# Having SchemaRDD for backward compatibility (for docs) +class SchemaRDD(DataFrame): + """ + SchemaRDD is deprecated, please use DataFrame + """ + + +def dfapi(f): + def _api(self): + name = f.__name__ + jdf = getattr(self._jdf, name)() + return DataFrame(jdf, self.sql_ctx) + _api.__name__ = f.__name__ + _api.__doc__ = f.__doc__ + return _api + + +class GroupedDataFrame(object): + + """ + A set of methods for aggregations on a :class:`DataFrame`, + created by DataFrame.groupBy(). + """ + + def __init__(self, jdf, sql_ctx): + self._jdf = jdf + self.sql_ctx = sql_ctx + + def agg(self, *exprs): + """ Compute aggregates by specifying a map from column name + to aggregate methods. + + The available aggregate methods are `avg`, `max`, `min`, + `sum`, `count`. + + :param exprs: list or aggregate columns or a map from column + name to agregate methods. + """ + if len(exprs) == 1 and isinstance(exprs[0], dict): + jmap = MapConverter().convert(exprs[0], + self.sql_ctx._sc._gateway._gateway_client) + jdf = self._jdf.agg(jmap) + else: + # Columns + assert all(isinstance(c, Column) for c in exprs), "all exprs should be Columns" + jdf = self._jdf.agg(*exprs) + return DataFrame(jdf, self.sql_ctx) + + @dfapi + def count(self): + """ Count the number of rows for each group. """ + + @dfapi + def mean(self): + """Compute the average value for each numeric columns + for each group. This is an alias for `avg`.""" + + @dfapi + def avg(self): + """Compute the average value for each numeric columns + for each group.""" + + @dfapi + def max(self): + """Compute the max value for each numeric columns for + each group. """ + + @dfapi + def min(self): + """Compute the min value for each numeric column for + each group.""" + + @dfapi + def sum(self): + """Compute the sum for each numeric columns for each + group.""" + + +SCALA_METHOD_MAPPINGS = { + '=': '$eq', + '>': '$greater', + '<': '$less', + '+': '$plus', + '-': '$minus', + '*': '$times', + '/': '$div', + '!': '$bang', + '@': '$at', + '#': '$hash', + '%': '$percent', + '^': '$up', + '&': '$amp', + '~': '$tilde', + '?': '$qmark', + '|': '$bar', + '\\': '$bslash', + ':': '$colon', +} + + +def _create_column_from_literal(literal): + sc = SparkContext._active_spark_context + return sc._jvm.org.apache.spark.sql.Dsl.lit(literal) + + +def _create_column_from_name(name): + sc = SparkContext._active_spark_context + return sc._jvm.Column(name) + + +def _scalaMethod(name): + """ Translate operators into methodName in Scala + + For example: + >>> _scalaMethod('+') + '$plus' + >>> _scalaMethod('>=') + '$greater$eq' + >>> _scalaMethod('cast') + 'cast' + """ + return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name) + + +def _unary_op(name): + """ Create a method for given unary operator """ + def _(self): + return Column(getattr(self._jc, _scalaMethod(name))(), self._jdf, self.sql_ctx) + return _ + + +def _bin_op(name, pass_literal_through=False): + """ Create a method for given binary operator + + Keyword arguments: + pass_literal_through -- whether to pass literal value directly through to the JVM. + """ + def _(self, other): + if isinstance(other, Column): + jc = other._jc + else: + if pass_literal_through: + jc = other + else: + jc = _create_column_from_literal(other) + return Column(getattr(self._jc, _scalaMethod(name))(jc), self._jdf, self.sql_ctx) + return _ + + +def _reverse_op(name): + """ Create a method for binary operator (this object is on right side) + """ + def _(self, other): + return Column(getattr(_create_column_from_literal(other), _scalaMethod(name))(self._jc), + self._jdf, self.sql_ctx) + return _ + + +class Column(DataFrame): + + """ + A column in a DataFrame. + + `Column` instances can be created by: + {{{ + // 1. Select a column out of a DataFrame + df.colName + df["colName"] + + // 2. Create from an expression + df["colName"] + 1 + }}} + """ + + def __init__(self, jc, jdf=None, sql_ctx=None): + self._jc = jc + super(Column, self).__init__(jdf, sql_ctx) + + # arithmetic operators + __neg__ = _unary_op("unary_-") + __add__ = _bin_op("+") + __sub__ = _bin_op("-") + __mul__ = _bin_op("*") + __div__ = _bin_op("/") + __mod__ = _bin_op("%") + __radd__ = _bin_op("+") + __rsub__ = _reverse_op("-") + __rmul__ = _bin_op("*") + __rdiv__ = _reverse_op("/") + __rmod__ = _reverse_op("%") + __abs__ = _unary_op("abs") + abs = _unary_op("abs") + sqrt = _unary_op("sqrt") + + # logistic operators + __eq__ = _bin_op("===") + __ne__ = _bin_op("!==") + __lt__ = _bin_op("<") + __le__ = _bin_op("<=") + __ge__ = _bin_op(">=") + __gt__ = _bin_op(">") + # `and`, `or`, `not` cannot be overloaded in Python + And = _bin_op('&&') + Or = _bin_op('||') + Not = _unary_op('unary_!') + + # bitwise operators + __and__ = _bin_op("&") + __or__ = _bin_op("|") + __invert__ = _unary_op("unary_~") + __xor__ = _bin_op("^") + # __lshift__ = _bin_op("<<") + # __rshift__ = _bin_op(">>") + __rand__ = _bin_op("&") + __ror__ = _bin_op("|") + __rxor__ = _bin_op("^") + # __rlshift__ = _reverse_op("<<") + # __rrshift__ = _reverse_op(">>") + + # container operators + __contains__ = _bin_op("contains") + __getitem__ = _bin_op("getItem") + # __getattr__ = _bin_op("getField") + + # string methods + rlike = _bin_op("rlike", pass_literal_through=True) + like = _bin_op("like", pass_literal_through=True) + startswith = _bin_op("startsWith", pass_literal_through=True) + endswith = _bin_op("endsWith", pass_literal_through=True) + upper = _unary_op("upper") + lower = _unary_op("lower") + + def substr(self, startPos, pos): + if type(startPos) != type(pos): + raise TypeError("Can not mix the type") + if isinstance(startPos, (int, long)): + + jc = self._jc.substr(startPos, pos) + elif isinstance(startPos, Column): + jc = self._jc.substr(startPos._jc, pos._jc) + else: + raise TypeError("Unexpected type: %s" % type(startPos)) + return Column(jc, self._jdf, self.sql_ctx) + + __getslice__ = substr + + # order + asc = _unary_op("asc") + desc = _unary_op("desc") + + isNull = _unary_op("isNull") + isNotNull = _unary_op("isNotNull") + + # `as` is keyword + def alias(self, alias): + return Column(getattr(self._jsc, "as")(alias), self._jdf, self.sql_ctx) + + def cast(self, dataType): + if self.sql_ctx is None: + sc = SparkContext._active_spark_context + ssql_ctx = sc._jvm.SQLContext(sc._jsc.sc()) + else: + ssql_ctx = self.sql_ctx._ssql_ctx + jdt = ssql_ctx.parseDataType(dataType.json()) + return Column(self._jc.cast(jdt), self._jdf, self.sql_ctx) + + +def _aggregate_func(name): + """ Create a function for aggregator by name""" + def _(col): + sc = SparkContext._active_spark_context + if isinstance(col, Column): + jcol = col._jc + else: + jcol = _create_column_from_name(col) + jc = getattr(sc._jvm.org.apache.spark.sql.Dsl, name)(jcol) + return Column(jc) + return staticmethod(_) + + +class Aggregator(object): + """ + A collections of builtin aggregators + """ + max = _aggregate_func("max") + min = _aggregate_func("min") + avg = mean = _aggregate_func("mean") + sum = _aggregate_func("sum") + first = _aggregate_func("first") + last = _aggregate_func("last") + count = _aggregate_func("count") def _test(): diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index b474fcf5bfb7e..bec1961f26393 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -23,6 +23,7 @@ from fileinput import input from glob import glob import os +import pydoc import re import shutil import subprocess @@ -53,6 +54,7 @@ from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ UserDefinedType, DoubleType from pyspark import shuffle +from pyspark.profiler import BasicProfiler _have_scipy = False _have_numpy = False @@ -714,6 +716,25 @@ def test_sample(self): wr_s21 = rdd.sample(True, 0.4, 21).collect() self.assertNotEqual(set(wr_s11), set(wr_s21)) + def test_multiple_python_java_RDD_conversions(self): + # Regression test for SPARK-5361 + data = [ + (u'1', {u'director': u'David Lean'}), + (u'2', {u'director': u'Andrew Dominik'}) + ] + from pyspark.rdd import RDD + data_rdd = self.sc.parallelize(data) + data_java_rdd = data_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + + # conversion between python and java RDD threw exceptions + data_java_rdd = converted_rdd._to_java_object_rdd() + data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd) + converted_rdd = RDD(data_python_rdd, self.sc) + self.assertEqual(2, converted_rdd.count()) + class ProfilerTests(PySparkTestCase): @@ -724,16 +745,12 @@ def setUp(self): self.sc = SparkContext('local[4]', class_name, conf=conf) def test_profiler(self): + self.do_computation() - def heavy_foo(x): - for i in range(1 << 20): - x = 1 - rdd = self.sc.parallelize(range(100)) - rdd.foreach(heavy_foo) - profiles = self.sc._profile_stats - self.assertEqual(1, len(profiles)) - id, acc, _ = profiles[0] - stats = acc.value + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + id, profiler, _ = profilers[0] + stats = profiler.stats() self.assertTrue(stats is not None) width, stat_list = stats.get_print_list([]) func_names = [func_name for fname, n, func_name in stat_list] @@ -744,6 +761,31 @@ def heavy_foo(x): self.sc.dump_profiles(d) self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) + def test_custom_profiler(self): + class TestCustomProfiler(BasicProfiler): + def show(self, id): + self.result = "Custom formatting" + + self.sc.profiler_collector.profiler_cls = TestCustomProfiler + + self.do_computation() + + profilers = self.sc.profiler_collector.profilers + self.assertEqual(1, len(profilers)) + _, profiler, _ = profilers[0] + self.assertTrue(isinstance(profiler, TestCustomProfiler)) + + self.sc.show_profiles() + self.assertEqual("Custom formatting", profiler.result) + + def do_computation(self): + def heavy_foo(x): + for i in range(1 << 20): + x = 1 + + rdd = self.sc.parallelize(range(100)) + rdd.foreach(heavy_foo) + class ExamplePointUDT(UserDefinedType): """ @@ -806,6 +848,9 @@ def tearDownClass(cls): def setUp(self): self.sqlCtx = SQLContext(self.sc) + self.testData = [Row(key=i, value=str(i)) for i in range(100)] + rdd = self.sc.parallelize(self.testData) + self.df = self.sqlCtx.inferSchema(rdd) def test_udf(self): self.sqlCtx.registerFunction("twoArgs", lambda x, y: len(x) + y, IntegerType()) @@ -821,7 +866,7 @@ def test_udf2(self): def test_udf_with_array_type(self): d = [Row(l=range(3), d={"key": range(5)})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd).registerTempTable("test") + self.sqlCtx.inferSchema(rdd).registerTempTable("test") self.sqlCtx.registerFunction("copylist", lambda l: list(l), ArrayType(IntegerType())) self.sqlCtx.registerFunction("maplen", lambda d: len(d), IntegerType()) [(l1, l2)] = self.sqlCtx.sql("select copylist(l), maplen(d) from test").collect() @@ -839,68 +884,51 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - srdd = self.sqlCtx.jsonRDD(rdd) - srdd.count() - srdd.collect() - srdd.schemaString() - srdd.schema() + df = self.sqlCtx.jsonRDD(rdd) + df.count() + df.collect() + df.schema() # cache and checkpoint - self.assertFalse(srdd.is_cached) - srdd.persist() - srdd.unpersist() - srdd.cache() - self.assertTrue(srdd.is_cached) - self.assertFalse(srdd.isCheckpointed()) - self.assertEqual(None, srdd.getCheckpointFile()) - - srdd = srdd.coalesce(2, True) - srdd = srdd.repartition(3) - srdd = srdd.distinct() - srdd.intersection(srdd) - self.assertEqual(2, srdd.count()) - - srdd.registerTempTable("temp") - srdd = self.sqlCtx.sql("select foo from temp") - srdd.count() - srdd.collect() - - def test_distinct(self): - rdd = self.sc.parallelize(['{"a": 1}', '{"b": 2}', '{"c": 3}']*10, 10) - srdd = self.sqlCtx.jsonRDD(rdd) - self.assertEquals(srdd.getNumPartitions(), 10) - self.assertEquals(srdd.distinct().count(), 3) - result = srdd.distinct(5) - self.assertEquals(result.getNumPartitions(), 5) - self.assertEquals(result.count(), 3) + self.assertFalse(df.is_cached) + df.persist() + df.unpersist() + df.cache() + self.assertTrue(df.is_cached) + self.assertEqual(2, df.count()) + + df.registerTempTable("temp") + df = self.sqlCtx.sql("select foo from temp") + df.count() + df.collect() def test_apply_schema_to_row(self): - srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) - srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) - self.assertEqual(srdd.collect(), srdd2.collect()) + df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df2 = self.sqlCtx.applySchema(df.map(lambda x: x), df.schema()) + self.assertEqual(df.collect(), df2.collect()) rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) - srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) - self.assertEqual(10, srdd3.count()) + df3 = self.sqlCtx.applySchema(rdd, df.schema()) + self.assertEqual(10, df3.count()) def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - row = srdd.first() + df = self.sqlCtx.inferSchema(rdd) + row = df.head() self.assertEqual(1, len(row.l)) self.assertEqual(1, row.l[0].a) self.assertEqual("2", row.d["key"].d) - l = srdd.map(lambda x: x.l).first() + l = df.map(lambda x: x.l).first() self.assertEqual(1, len(l)) self.assertEqual('s', l[0].b) - d = srdd.map(lambda x: x.d).first() + d = df.map(lambda x: x.d).first() self.assertEqual(1, len(d)) self.assertEqual(1.0, d["key"].c) - row = srdd.map(lambda x: x.d["key"]).first() + row = df.map(lambda x: x.d["key"]).first() self.assertEqual(1.0, row.c) self.assertEqual("2", row.d) @@ -908,26 +936,26 @@ def test_infer_schema(self): d = [Row(l=[], d={}), Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - self.assertEqual([], srdd.map(lambda r: r.l).first()) - self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect()) - srdd.registerTempTable("test") + df = self.sqlCtx.inferSchema(rdd) + self.assertEqual([], df.map(lambda r: r.l).first()) + self.assertEqual([None, ""], df.map(lambda r: r.s).collect()) + df.registerTempTable("test") result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) - srdd2 = self.sqlCtx.inferSchema(rdd, 1.0) - self.assertEqual(srdd.schema(), srdd2.schema()) - self.assertEqual({}, srdd2.map(lambda r: r.d).first()) - self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect()) - srdd2.registerTempTable("test2") + df2 = self.sqlCtx.inferSchema(rdd, 1.0) + self.assertEqual(df.schema(), df2.schema()) + self.assertEqual({}, df2.map(lambda r: r.d).first()) + self.assertEqual([None, ""], df2.map(lambda r: r.s).collect()) + df2.registerTempTable("test2") result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'") - self.assertEqual(1, result.first()[0]) + self.assertEqual(1, result.head()[0]) def test_struct_in_map(self): d = [Row(m={Row(i=1): Row(s="")})] rdd = self.sc.parallelize(d) - srdd = self.sqlCtx.inferSchema(rdd) - k, v = srdd.first().m.items()[0] + df = self.sqlCtx.inferSchema(rdd) + k, v = df.head().m.items()[0] self.assertEqual(1, k.i) self.assertEqual("", v.s) @@ -935,9 +963,9 @@ def test_convert_row_to_dict(self): row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}) self.assertEqual(1, row.asDict()['l'][0].a) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - srdd.registerTempTable("test") - row = self.sqlCtx.sql("select l, d from test").first() + df = self.sqlCtx.inferSchema(rdd) + df.registerTempTable("test") + row = self.sqlCtx.sql("select l, d from test").head() self.assertEqual(1, row.asDict()["l"][0].a) self.assertEqual(1.0, row.asDict()['d']['key'].c) @@ -945,12 +973,12 @@ def test_infer_schema_with_udt(self): from pyspark.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd = self.sqlCtx.inferSchema(rdd) - schema = srdd.schema() + df = self.sqlCtx.inferSchema(rdd) + schema = df.schema() field = [f for f in schema.fields if f.name == "point"][0] self.assertEqual(type(field.dataType), ExamplePointUDT) - srdd.registerTempTable("labeled_point") - point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + df.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) def test_apply_schema_with_udt(self): @@ -959,21 +987,61 @@ def test_apply_schema_with_udt(self): rdd = self.sc.parallelize([row]) schema = StructType([StructField("label", DoubleType(), False), StructField("point", ExamplePointUDT(), False)]) - srdd = self.sqlCtx.applySchema(rdd, schema) - point = srdd.first().point + df = self.sqlCtx.applySchema(rdd, schema) + point = df.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_parquet_with_udt(self): from pyspark.tests import ExamplePoint row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) rdd = self.sc.parallelize([row]) - srdd0 = self.sqlCtx.inferSchema(rdd) + df0 = self.sqlCtx.inferSchema(rdd) output_dir = os.path.join(self.tempdir.name, "labeled_point") - srdd0.saveAsParquetFile(output_dir) - srdd1 = self.sqlCtx.parquetFile(output_dir) - point = srdd1.first().point + df0.saveAsParquetFile(output_dir) + df1 = self.sqlCtx.parquetFile(output_dir) + point = df1.head().point self.assertEquals(point, ExamplePoint(1.0, 2.0)) + def test_column_operators(self): + from pyspark.sql import Column, LongType + ci = self.df.key + cs = self.df.value + c = ci == cs + self.assertTrue(isinstance((- ci - 1 - 2) % 3 * 2.5 / 3.5, Column)) + rcc = (1 + ci), (1 - ci), (1 * ci), (1 / ci), (1 % ci) + self.assertTrue(all(isinstance(c, Column) for c in rcc)) + cb = [ci == 5, ci != 0, ci > 3, ci < 4, ci >= 0, ci <= 7, ci and cs, ci or cs] + self.assertTrue(all(isinstance(c, Column) for c in cb)) + cbit = (ci & ci), (ci | ci), (ci ^ ci), (~ci) + self.assertTrue(all(isinstance(c, Column) for c in cbit)) + css = cs.like('a'), cs.rlike('a'), cs.asc(), cs.desc(), cs.startswith('a'), cs.endswith('a') + self.assertTrue(all(isinstance(c, Column) for c in css)) + self.assertTrue(isinstance(ci.cast(LongType()), Column)) + + def test_column_select(self): + df = self.df + self.assertEqual(self.testData, df.select("*").collect()) + self.assertEqual(self.testData, df.select(df.key, df.value).collect()) + self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect()) + + def test_aggregator(self): + df = self.df + g = df.groupBy() + self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0])) + self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect()) + # TODO(davies): fix aggregators + from pyspark.sql import Aggregator as Agg + # self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first())) + + def test_help_command(self): + # Regression test for SPARK-5464 + rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) + df = self.sqlCtx.jsonRDD(rdd) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(df) + pydoc.render_doc(df.foo) + pydoc.render_doc(df.take(1)) + class InputFormatTests(ReusedPySparkTestCase): diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7e5343c973dc5..8a93c320ec5d3 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,8 +23,6 @@ import time import socket import traceback -import cProfile -import pstats from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry @@ -90,19 +88,15 @@ def main(infile, outfile): command = pickleSer._read_with_length(infile) if isinstance(command, Broadcast): command = pickleSer.loads(command.value) - (func, stats, deserializer, serializer) = command + (func, profiler, deserializer, serializer) = command init_time = time.time() def process(): iterator = deserializer.load_stream(infile) serializer.dump_stream(func(split_index, iterator), outfile) - if stats: - p = cProfile.Profile() - p.runcall(process) - st = pstats.Stats(p) - st.stream = None # make it picklable - stats.add(st.strip_dirs()) + if profiler: + profiler.profile(process) else: process() except Exception: diff --git a/python/run-tests b/python/run-tests index 9ee19ed6e6b26..e91f1a875d356 100755 --- a/python/run-tests +++ b/python/run-tests @@ -57,6 +57,7 @@ function run_core_tests() { PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py" PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py" run_test "pyspark/serializers.py" + run_test "pyspark/profiler.py" run_test "pyspark/shuffle.py" run_test "pyspark/tests.py" } @@ -75,12 +76,19 @@ function run_mllib_tests() { run_test "pyspark/mllib/rand.py" run_test "pyspark/mllib/recommendation.py" run_test "pyspark/mllib/regression.py" - run_test "pyspark/mllib/stat.py" + run_test "pyspark/mllib/stat/_statistics.py" run_test "pyspark/mllib/tree.py" run_test "pyspark/mllib/util.py" run_test "pyspark/mllib/tests.py" } +function run_ml_tests() { + echo "Run ml tests ..." + run_test "pyspark/ml/feature.py" + run_test "pyspark/ml/classification.py" + run_test "pyspark/ml/tests.py" +} + function run_streaming_tests() { echo "Run streaming tests ..." run_test "pyspark/streaming/util.py" @@ -102,6 +110,7 @@ $PYSPARK_PYTHON --version run_core_tests run_sql_tests run_mllib_tests +run_ml_tests run_streaming_tests # Try to test with PyPy diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 91c9c52c3c98a..e594ad868ea1c 100644 --- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -255,14 +255,14 @@ class ReplSuite extends FunSuite { assertDoesNotContain("Exception", output) } - test("SPARK-2576 importing SQLContext.createSchemaRDD.") { + test("SPARK-2576 importing SQLContext.createDataFrame.") { // We need to use local-cluster to test this case. val output = runInterpreter("local-cluster[1,1,512]", """ |val sqlContext = new org.apache.spark.sql.SQLContext(sc) - |import sqlContext.createSchemaRDD + |import sqlContext.createDataFrame |case class TestCaseClass(value: Int) - |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toSchemaRDD.collect + |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDataFrame.collect() """.stripMargin) assertDoesNotContain("error:", output) assertDoesNotContain("Exception", output) diff --git a/sbin/start-thriftserver.sh b/sbin/start-thriftserver.sh index 50e8e06418b07..070cc7a87e6f2 100755 --- a/sbin/start-thriftserver.sh +++ b/sbin/start-thriftserver.sh @@ -26,6 +26,8 @@ set -o posix # Figure out where Spark is installed FWDIR="$(cd "`dirname "$0"`"/..; pwd)" +# NOTE: This exact class name is matched downstream by SparkSubmit. +# Any changes need to be reflected there. CLASS="org.apache.spark.sql.hive.thriftserver.HiveThriftServer2" function usage { diff --git a/sql/README.md b/sql/README.md index d058a6b011d37..61a20916a92aa 100644 --- a/sql/README.md +++ b/sql/README.md @@ -44,7 +44,7 @@ Type in expressions to have them evaluated. Type :help for more information. scala> val query = sql("SELECT * FROM (SELECT * FROM src) a") -query: org.apache.spark.sql.SchemaRDD = +query: org.apache.spark.sql.DataFrame = == Query Plan == == Physical Plan == HiveTableScan [key#10,value#11], (MetastoreRelation default, src, None), None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 191d16fb10b5f..e0db587efb08d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -57,6 +57,11 @@ trait ScalaReflection { case (obj, udt: UserDefinedType[_]) => udt.serialize(obj) case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, arrayType.elementType)) + case (s: Array[_], arrayType: ArrayType) => if (arrayType.elementType.isPrimitive) { + s.toSeq + } else { + s.toSeq.map(convertToCatalyst(_, arrayType.elementType)) + } case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) => convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, mapType.valueType) } @@ -93,7 +98,7 @@ trait ScalaReflection { /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) + s.toAttributes } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ @@ -140,7 +145,9 @@ trait ScalaReflection { // Need to decide if we actually need a special type here. case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = true) case t if t <:< typeOf[Array[_]] => - sys.error(s"Only Array[Byte] supported now, use Seq instead of $t") + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< typeOf[Seq[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, nullable) = schemaFor(elementType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index eaadbe9fd5099..594a423146d77 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -50,6 +50,7 @@ class SqlParser extends AbstractSparkSQLParser { protected val CACHE = Keyword("CACHE") protected val CASE = Keyword("CASE") protected val CAST = Keyword("CAST") + protected val COALESCE = Keyword("COALESCE") protected val COUNT = Keyword("COUNT") protected val DECIMAL = Keyword("DECIMAL") protected val DESC = Keyword("DESC") @@ -295,6 +296,7 @@ class SqlParser extends AbstractSparkSQLParser { { case s ~ p => Substring(s, p, Literal(Integer.MAX_VALUE)) } | (SUBSTR | SUBSTRING) ~ "(" ~> expression ~ ("," ~> expression) ~ ("," ~> expression) <~ ")" ^^ { case s ~ p ~ l => Substring(s, p, l) } + | COALESCE ~ "(" ~> repsep(expression, ",") <~ ")" ^^ { case exprs => Coalesce(exprs) } | SQRT ~ "(" ~> expression <~ ")" ^^ { case exp => Sqrt(exp) } | ABS ~ "(" ~> expression <~ ")" ^^ { case exp => Abs(exp) } | ident ~ ("(" ~> repsep(expression, ",")) <~ ")" ^^ @@ -348,7 +350,7 @@ class SqlParser extends AbstractSparkSQLParser { ) protected lazy val baseExpression: Parser[Expression] = - ( "*" ^^^ Star(None) + ( "*" ^^^ UnresolvedStar(None) | primary ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7f4cc234dc9cd..cefd70acf3931 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -250,6 +250,12 @@ class Analyzer(catalog: Catalog, Project( projectList.flatMap { case s: Star => s.expand(child.output, resolver) + case Alias(f @ UnresolvedFunction(_, args), name) if containsStar(args) => + val expandedArgs = args.flatMap { + case s: Star => s.expand(child.output, resolver) + case o => o :: Nil + } + Alias(child = f.copy(children = expandedArgs), name)() :: Nil case o => o :: Nil }, child) @@ -273,10 +279,9 @@ class Analyzer(catalog: Catalog, case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") q transformExpressions { - case u @ UnresolvedAttribute(name) - if resolver(name, VirtualColumn.groupingIdName) && - q.isInstanceOf[GroupingAnalytics] => - // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics + case u @ UnresolvedAttribute(name) if resolver(name, VirtualColumn.groupingIdName) && + q.isInstanceOf[GroupingAnalytics] => + // Resolve the virtual column GROUPING__ID for the operator GroupingAnalytics q.asInstanceOf[GroupingAnalytics].gid case u @ UnresolvedAttribute(name) => // Leave unchanged if resolution fails. Hopefully will be resolved next round. @@ -299,7 +304,7 @@ class Analyzer(catalog: Catalog, * Returns true if `exprs` contains a [[Star]]. */ protected def containsStar(exprs: Seq[Expression]): Boolean = - exprs.collect { case _: Star => true}.nonEmpty + exprs.exists(_.collect { case _: Star => true }.nonEmpty) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 6ef8577fd04da..34ef7d28cc7f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -503,6 +503,22 @@ trait HiveTypeCoercion { // Hive lets you do aggregation of timestamps... for some reason case Sum(e @ TimestampType()) => Sum(Cast(e, DoubleType)) case Average(e @ TimestampType()) => Average(Cast(e, DoubleType)) + + // Coalesce should return the first non-null value, which could be any column + // from the list. So we need to make sure the return type is deterministic and + // compatible with every child column. + case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => + val dt: Option[DataType] = Some(NullType) + val types = es.map(_.dataType) + val rt = types.foldLeft(dt)((r, c) => r match { + case None => None + case Some(d) => findTightestCommonType(d, c) + }) + rt match { + case Some(finaldt) => Coalesce(es.map(Cast(_, finaldt))) + case None => + sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala index 22941edef2d46..4c5fb3f45bf49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/MultiInstanceRelation.scala @@ -47,7 +47,7 @@ object NewRelationInstances extends Rule[LogicalPlan] { .toSet plan transform { - case l: MultiInstanceRelation if multiAppearance contains l => l.newInstance + case l: MultiInstanceRelation if multiAppearance.contains(l) => l.newInstance() } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 71a738a0b2ca0..66060289189ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -50,7 +50,7 @@ case class UnresolvedAttribute(name: String) extends Attribute with trees.LeafNo override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance = this + override def newInstance() = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this override def withName(newName: String) = UnresolvedAttribute(name) @@ -77,15 +77,10 @@ case class UnresolvedFunction(name: String, children: Seq[Expression]) extends E /** * Represents all of the input attributes to a given relational operator, for example in - * "SELECT * FROM ...". - * - * @param table an optional table that should be the target of the expansion. If omitted all - * tables' columns are produced. + * "SELECT * FROM ...". A [[Star]] gets automatically expanded during analysis. */ -case class Star( - table: Option[String], - mapFunction: Attribute => Expression = identity[Attribute]) - extends Attribute with trees.LeafNode[Expression] { +trait Star extends Attribute with trees.LeafNode[Expression] { + self: Product => override def name = throw new UnresolvedException(this, "name") override def exprId = throw new UnresolvedException(this, "exprId") @@ -94,29 +89,53 @@ case class Star( override def qualifiers = throw new UnresolvedException(this, "qualifiers") override lazy val resolved = false - override def newInstance = this + override def newInstance() = this override def withNullability(newNullability: Boolean) = this override def withQualifiers(newQualifiers: Seq[String]) = this override def withName(newName: String) = this - def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { + // Star gets expanded at runtime so we never evaluate a Star. + override def eval(input: Row = null): EvaluatedType = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] +} + + +/** + * Represents all of the input attributes to a given relational operator, for example in + * "SELECT * FROM ...". + * + * @param table an optional table that should be the target of the expansion. If omitted all + * tables' columns are produced. + */ +case class UnresolvedStar(table: Option[String]) extends Star { + + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = { val expandedAttributes: Seq[Attribute] = table match { // If there is no table specified, use all input attributes. case None => input // If there is a table, pick out attributes that are part of this table. case Some(t) => input.filter(_.qualifiers.filter(resolver(_, t)).nonEmpty) } - val mappedAttributes = expandedAttributes.map(mapFunction).zip(input).map { + expandedAttributes.zip(input).map { case (n: NamedExpression, _) => n case (e, originalAttribute) => Alias(e, originalAttribute.name)(qualifiers = originalAttribute.qualifiers) } - mappedAttributes } - // Star gets expanded at runtime so we never evaluate a Star. - override def eval(input: Row = null): EvaluatedType = - throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") - override def toString = table.map(_ + ".").getOrElse("") + "*" } + + +/** + * Represents all the resolved input attributes to a given relational operator. This is used + * in the data frame DSL. + * + * @param expressions Expressions to expand. + */ +case class ResolvedStar(expressions: Seq[NamedExpression]) extends Star { + override def expand(input: Seq[Attribute], resolver: Resolver): Seq[NamedExpression] = expressions + override def toString = expressions.mkString("ResolvedStar(", ", ", ")") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 3035d934ff9f8..f388cd5972bac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -77,6 +77,9 @@ abstract class Attribute extends NamedExpression { * For example the SQL expression "1 + 1 AS a" could be represented as follows: * Alias(Add(Literal(1), Literal(1), "a")() * + * Note that exprId and qualifiers are in a separate parameter list because + * we only pattern match on child and name. + * * @param child the computation being performed * @param name the name to be associated with the result of computing [[child]]. * @param exprId A globally unique id used to check if an [[AttributeReference]] refers to this diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index 8df150e2f855f..73ec7a6d114f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -114,7 +114,6 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { } override def getString(i: Int): String = { - if (values(i) == null) sys.error("Failed to check null bit for primitive String value.") values(i).asInstanceOf[String] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 310d127506d68..b4c445b3badf1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -141,10 +141,11 @@ object PartialAggregation { // We need to pass all grouping expressions though so the grouping can happen a second // time. However some of them might be unnamed so we alias them allowing them to be // referenced in the second aggregation. - val namedGroupingExpressions: Map[Expression, NamedExpression] = groupingExpressions.map { - case n: NamedExpression => (n, n) - case other => (other, Alias(other, "PartialGroup")()) - }.toMap + val namedGroupingExpressions: Map[Expression, NamedExpression] = + groupingExpressions.filter(!_.isInstanceOf[Literal]).map { + case n: NamedExpression => (n, n) + case other => (other, Alias(other, "PartialGroup")()) + }.toMap // Replace aggregations with a new expression that computes the result from the already // computed partial evaluations and grouping values. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 613f4bb09daf5..5dc0539caec24 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -17,9 +17,24 @@ package org.apache.spark.sql.catalyst.plans +object JoinType { + def apply(typ: String): JoinType = typ.toLowerCase.replace("_", "") match { + case "inner" => Inner + case "outer" | "full" | "fullouter" => FullOuter + case "leftouter" | "left" => LeftOuter + case "rightouter" | "right" => RightOuter + case "leftsemi" => LeftSemi + } +} + sealed abstract class JoinType + case object Inner extends JoinType + case object LeftOuter extends JoinType + case object RightOuter extends JoinType + case object FullOuter extends JoinType + case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala index 19769986ef58c..d90af45b375e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/TestRelation.scala @@ -19,10 +19,14 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.types.{StructType, StructField} object LocalRelation { - def apply(output: Attribute*) = - new LocalRelation(output) + def apply(output: Attribute*): LocalRelation = new LocalRelation(output) + + def apply(output1: StructField, output: StructField*): LocalRelation = new LocalRelation( + StructType(output1 +: output).toAttributes + ) } case class LocalRelation(output: Seq[Attribute], data: Seq[Product] = Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala index 9f30f40a173e0..6ab99aa38877f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/dataTypes.scala @@ -930,13 +930,13 @@ case class MapType( * * This interface allows a user to make their own classes more interoperable with SparkSQL; * e.g., by creating a [[UserDefinedType]] for a class X, it becomes possible to create - * a SchemaRDD which has class X in the schema. + * a `DataFrame` which has class X in the schema. * * For SparkSQL to recognize UDTs, the UDT must be annotated with * [[SQLUserDefinedType]]. * - * The conversion via `serialize` occurs when instantiating a `SchemaRDD` from another RDD. - * The conversion via `deserialize` occurs when reading from a `SchemaRDD`. + * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. + * The conversion via `deserialize` occurs when reading from a `DataFrame`. */ @DeveloperApi abstract class UserDefinedType[UserType] extends DataType with Serializable { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5138942a55daa..d0f547d187ecb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -60,10 +60,12 @@ case class OptionalData( case class ComplexData( arrayField: Seq[Int], + arrayField1: Array[Int], arrayFieldContainsNull: Seq[java.lang.Integer], mapField: Map[Int, Long], mapFieldValueContainsNull: Map[Int, java.lang.Long], - structField: PrimitiveData) + structField: PrimitiveData, + nestedArrayField: Array[Array[Int]]) case class GenericData[A]( genericField: A) @@ -131,6 +133,10 @@ class ScalaReflectionSuite extends FunSuite { "arrayField", ArrayType(IntegerType, containsNull = false), nullable = true), + StructField( + "arrayField1", + ArrayType(IntegerType, containsNull = false), + nullable = true), StructField( "arrayFieldContainsNull", ArrayType(IntegerType, containsNull = true), @@ -153,7 +159,10 @@ class ScalaReflectionSuite extends FunSuite { StructField("shortField", ShortType, nullable = false), StructField("byteField", ByteType, nullable = false), StructField("booleanField", BooleanType, nullable = false))), - nullable = true))), + nullable = true), + StructField( + "nestedArrayField", + ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)))), nullable = true)) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 3aea337460d42..60060bf02913b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -51,7 +51,9 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { test("union project *") { val plan = (1 to 100) .map(_ => testRelation) - .fold[LogicalPlan](testRelation)((a,b) => a.select(Star(None)).select('a).unionAll(b.select(Star(None)))) + .fold[LogicalPlan](testRelation) { (a, b) => + a.select(UnresolvedStar(None)).select('a).unionAll(b.select(UnresolvedStar(None))) + } assert(caseInsensitiveAnalyze(plan).resolved) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index f5a502b43f80b..85798d0871fda 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -114,4 +114,31 @@ class HiveTypeCoercionSuite extends FunSuite { // Stringify boolean when casting to string. ruleTest(Cast(Literal(false), StringType), If(Literal(false), Literal("true"), Literal("false"))) } + + test("coalesce casts") { + val fac = new HiveTypeCoercion { }.FunctionArgumentConversion + def ruleTest(initial: Expression, transformed: Expression) { + val testRelation = LocalRelation(AttributeReference("a", IntegerType)()) + assert(fac(Project(Seq(Alias(initial, "a")()), testRelation)) == + Project(Seq(Alias(transformed, "a")()), testRelation)) + } + ruleTest( + Coalesce(Literal(1.0) + :: Literal(1) + :: Literal(1.0, FloatType) + :: Nil), + Coalesce(Cast(Literal(1.0), DoubleType) + :: Cast(Literal(1), DoubleType) + :: Cast(Literal(1.0, FloatType), DoubleType) + :: Nil)) + ruleTest( + Coalesce(Literal(1L) + :: Literal(1) + :: Literal(new java.math.BigDecimal("1000000000000000000000")) + :: Nil), + Coalesce(Cast(Literal(1L), DecimalType()) + :: Cast(Literal(1), DecimalType()) + :: Cast(Literal(new java.math.BigDecimal("1000000000000000000000")), DecimalType()) + :: Nil)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala index e715d9434a2ab..f1949aa5dd74b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/CacheManager.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.util.concurrent.locks.ReentrantReadWriteLock +import org.apache.spark.Logging import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.columnar.InMemoryRelation import org.apache.spark.storage.StorageLevel @@ -32,9 +33,10 @@ private case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryR * results when subsequent queries are executed. Data is cached using byte buffers stored in an * InMemoryRelation. This relation is automatically substituted query plans that return the * `sameResult` as the originally cached query. + * + * Internal to Spark SQL. */ -private[sql] trait CacheManager { - self: SQLContext => +private[sql] class CacheManager(sqlContext: SQLContext) extends Logging { @transient private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] @@ -43,13 +45,13 @@ private[sql] trait CacheManager { private val cacheLock = new ReentrantReadWriteLock /** Returns true if the table is currently cached in-memory. */ - def isCached(tableName: String): Boolean = lookupCachedData(table(tableName)).nonEmpty + def isCached(tableName: String): Boolean = lookupCachedData(sqlContext.table(tableName)).nonEmpty /** Caches the specified table in-memory. */ - def cacheTable(tableName: String): Unit = cacheQuery(table(tableName), Some(tableName)) + def cacheTable(tableName: String): Unit = cacheQuery(sqlContext.table(tableName), Some(tableName)) /** Removes the specified table from the in-memory cache. */ - def uncacheTable(tableName: String): Unit = uncacheQuery(table(tableName)) + def uncacheTable(tableName: String): Unit = uncacheQuery(sqlContext.table(tableName)) /** Acquires a read lock on the cache for the duration of `f`. */ private def readLock[A](f: => A): A = { @@ -80,7 +82,7 @@ private[sql] trait CacheManager { * the in-memory columnar representation of the underlying table is expensive. */ private[sql] def cacheQuery( - query: SchemaRDD, + query: DataFrame, tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { val planToCache = query.queryExecution.analyzed @@ -91,16 +93,16 @@ private[sql] trait CacheManager { CachedData( planToCache, InMemoryRelation( - conf.useCompression, - conf.columnBatchSize, + sqlContext.conf.useCompression, + sqlContext.conf.columnBatchSize, storageLevel, query.queryExecution.executedPlan, tableName)) } } - /** Removes the data for the given SchemaRDD from the cache */ - private[sql] def uncacheQuery(query: SchemaRDD, blocking: Boolean = true): Unit = writeLock { + /** Removes the data for the given [[DataFrame]] from the cache */ + private[sql] def uncacheQuery(query: DataFrame, blocking: Boolean = true): Unit = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) require(dataIndex >= 0, s"Table $query is not cached.") @@ -108,9 +110,9 @@ private[sql] trait CacheManager { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given SchemaRDD from the cache if it's cached */ + /** Tries to remove the data for the given [[DataFrame]] from the cache if it's cached */ private[sql] def tryUncacheQuery( - query: SchemaRDD, + query: DataFrame, blocking: Boolean = true): Boolean = writeLock { val planToCache = query.queryExecution.analyzed val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) @@ -122,8 +124,8 @@ private[sql] trait CacheManager { found } - /** Optionally returns cached data for the given SchemaRDD */ - private[sql] def lookupCachedData(query: SchemaRDD): Option[CachedData] = readLock { + /** Optionally returns cached data for the given [[DataFrame]] */ + private[sql] def lookupCachedData(query: DataFrame): Option[CachedData] = readLock { lookupCachedData(query.queryExecution.analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala new file mode 100644 index 0000000000000..174c403059510 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -0,0 +1,584 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.language.implicitConversions + +import org.apache.spark.sql.Dsl.lit +import org.apache.spark.sql.catalyst.analysis.{UnresolvedStar, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.types._ + + +object Column { + /** + * Creates a [[Column]] based on the given column name. Same as [[Dsl.col]]. + */ + def apply(colName: String): Column = new Column(colName) + + /** For internal pattern matching. */ + private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr) +} + + +/** + * A column in a [[DataFrame]]. + * + * `Column` instances can be created by: + * {{{ + * // 1. Select a column out of a DataFrame + * df("colName") + * + * // 2. Create a literal expression + * Literal(1) + * + * // 3. Create new columns from + * }}} + * + */ +// TODO: Improve documentation. +class Column( + sqlContext: Option[SQLContext], + plan: Option[LogicalPlan], + protected[sql] val expr: Expression) + extends DataFrame(sqlContext, plan) with ExpressionApi { + + /** Turns a Catalyst expression into a `Column`. */ + protected[sql] def this(expr: Expression) = this(None, None, expr) + + /** + * Creates a new `Column` expression based on a column or attribute name. + * The resolution of this is the same as SQL. For example: + * + * - "colName" becomes an expression selecting the column named "colName". + * - "*" becomes an expression selecting all columns. + * - "df.*" becomes an expression selecting all columns in data frame "df". + */ + def this(name: String) = this(name match { + case "*" => UnresolvedStar(None) + case _ if name.endsWith(".*") => UnresolvedStar(Some(name.substring(0, name.length - 2))) + case _ => UnresolvedAttribute(name) + }) + + override def isComputable: Boolean = sqlContext.isDefined && plan.isDefined + + /** + * An implicit conversion function internal to this class. This function creates a new Column + * based on an expression. If the expression itself is not named, it aliases the expression + * by calling it "col". + */ + private[this] implicit def toColumn(expr: Expression): Column = { + val projectedPlan = plan.map { p => + Project(Seq(expr match { + case named: NamedExpression => named + case unnamed: Expression => Alias(unnamed, "col")() + }), p) + } + new Column(sqlContext, projectedPlan, expr) + } + + /** + * Unary minus, i.e. negate the expression. + * {{{ + * // Select the amount column and negates all values. + * df.select( -df("amount") ) + * }}} + */ + override def unary_- : Column = UnaryMinus(expr) + + /** + * Bitwise NOT. + * {{{ + * // Select the flags column and negate every bit. + * df.select( ~df("flags") ) + * }}} + */ + override def unary_~ : Column = BitwiseNot(expr) + + /** + * Inversion of boolean expression, i.e. NOT. + * {{ + * // Select rows that are not active (isActive === false) + * df.select( !df("isActive") ) + * }} + */ + override def unary_! : Column = Not(expr) + + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def === (other: Column): Column = EqualTo(expr, other.expr) + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def === (literal: Any): Column = this === lit(literal) + + /** + * Equality test with an expression. + * {{{ + * // The following two both select rows in which colA equals colB. + * df.select( df("colA") === df("colB") ) + * df.select( df("colA".equalTo(df("colB")) ) + * }}} + */ + override def equalTo(other: Column): Column = this === other + + /** + * Equality test with a literal value. + * {{{ + * // The following two both select rows in which colA is "Zaharia". + * df.select( df("colA") === "Zaharia") + * df.select( df("colA".equalTo("Zaharia") ) + * }}} + */ + override def equalTo(literal: Any): Column = this === literal + + /** + * Inequality test with an expression. + * {{{ + * // The following two both select rows in which colA does not equal colB. + * df.select( df("colA") !== df("colB") ) + * df.select( !(df("colA") === df("colB")) ) + * }}} + */ + override def !== (other: Column): Column = Not(EqualTo(expr, other.expr)) + + /** + * Inequality test with a literal value. + * {{{ + * // The following two both select rows in which colA does not equal equal 15. + * df.select( df("colA") !== 15 ) + * df.select( !(df("colA") === 15) ) + * }}} + */ + override def !== (literal: Any): Column = this !== lit(literal) + + /** + * Greater than an expression. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > Literal(21) ) + * }}} + */ + override def > (other: Column): Column = GreaterThan(expr, other.expr) + + /** + * Greater than a literal value. + * {{{ + * // The following selects people older than 21. + * people.select( people("age") > 21 ) + * }}} + */ + override def > (literal: Any): Column = this > lit(literal) + + /** + * Less than an expression. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < Literal(21) ) + * }}} + */ + override def < (other: Column): Column = LessThan(expr, other.expr) + + /** + * Less than a literal value. + * {{{ + * // The following selects people younger than 21. + * people.select( people("age") < 21 ) + * }}} + */ + override def < (literal: Any): Column = this < lit(literal) + + /** + * Less than or equal to an expression. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= Literal(21) ) + * }}} + */ + override def <= (other: Column): Column = LessThanOrEqual(expr, other.expr) + + /** + * Less than or equal to a literal value. + * {{{ + * // The following selects people age 21 or younger than 21. + * people.select( people("age") <= 21 ) + * }}} + */ + override def <= (literal: Any): Column = this <= lit(literal) + + /** + * Greater than or equal to an expression. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= Literal(21) ) + * }}} + */ + override def >= (other: Column): Column = GreaterThanOrEqual(expr, other.expr) + + /** + * Greater than or equal to a literal value. + * {{{ + * // The following selects people age 21 or older than 21. + * people.select( people("age") >= 21 ) + * }}} + */ + override def >= (literal: Any): Column = this >= lit(literal) + + /** + * Equality test with an expression that is safe for null values. + */ + override def <=> (other: Column): Column = other match { + case null => EqualNullSafe(expr, lit(null).expr) + case _ => EqualNullSafe(expr, other.expr) + } + + /** + * Equality test with a literal value that is safe for null values. + */ + override def <=> (literal: Any): Column = this <=> lit(literal) + + /** + * True if the current expression is null. + */ + override def isNull: Column = IsNull(expr) + + /** + * True if the current expression is NOT null. + */ + override def isNotNull: Column = IsNotNull(expr) + + /** + * Boolean OR with an expression. + * {{{ + * // The following selects people that are in school or employed. + * people.select( people("inSchool") || people("isEmployed") ) + * }}} + */ + override def || (other: Column): Column = Or(expr, other.expr) + + /** + * Boolean OR with a literal value. + * {{{ + * // The following selects everything. + * people.select( people("inSchool") || true ) + * }}} + */ + override def || (literal: Boolean): Column = this || lit(literal) + + /** + * Boolean AND with an expression. + * {{{ + * // The following selects people that are in school and employed at the same time. + * people.select( people("inSchool") && people("isEmployed") ) + * }}} + */ + override def && (other: Column): Column = And(expr, other.expr) + + /** + * Boolean AND with a literal value. + * {{{ + * // The following selects people that are in school. + * people.select( people("inSchool") && true ) + * }}} + */ + override def && (literal: Boolean): Column = this && lit(literal) + + /** + * Bitwise AND with an expression. + */ + override def & (other: Column): Column = BitwiseAnd(expr, other.expr) + + /** + * Bitwise AND with a literal value. + */ + override def & (literal: Any): Column = this & lit(literal) + + /** + * Bitwise OR with an expression. + */ + override def | (other: Column): Column = BitwiseOr(expr, other.expr) + + /** + * Bitwise OR with a literal value. + */ + override def | (literal: Any): Column = this | lit(literal) + + /** + * Bitwise XOR with an expression. + */ + override def ^ (other: Column): Column = BitwiseXor(expr, other.expr) + + /** + * Bitwise XOR with a literal value. + */ + override def ^ (literal: Any): Column = this ^ lit(literal) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and weight. + * people.select( people("height") + people("weight") ) + * }}} + */ + override def + (other: Column): Column = Add(expr, other.expr) + + /** + * Sum of this expression and another expression. + * {{{ + * // The following selects the sum of a person's height and 10. + * people.select( people("height") + 10 ) + * }}} + */ + override def + (literal: Any): Column = this + lit(literal) + + /** + * Subtraction. Subtract the other expression from this expression. + * {{{ + * // The following selects the difference between people's height and their weight. + * people.select( people("height") - people("weight") ) + * }}} + */ + override def - (other: Column): Column = Subtract(expr, other.expr) + + /** + * Subtraction. Subtract a literal value from this expression. + * {{{ + * // The following selects a person's height and subtract it by 10. + * people.select( people("height") - 10 ) + * }}} + */ + override def - (literal: Any): Column = this - lit(literal) + + /** + * Multiplication of this expression and another expression. + * {{{ + * // The following multiplies a person's height by their weight. + * people.select( people("height") * people("weight") ) + * }}} + */ + override def * (other: Column): Column = Multiply(expr, other.expr) + + /** + * Multiplication this expression and a literal value. + * {{{ + * // The following multiplies a person's height by 10. + * people.select( people("height") * 10 ) + * }}} + */ + override def * (literal: Any): Column = this * lit(literal) + + /** + * Division this expression by another expression. + * {{{ + * // The following divides a person's height by their weight. + * people.select( people("height") / people("weight") ) + * }}} + */ + override def / (other: Column): Column = Divide(expr, other.expr) + + /** + * Division this expression by a literal value. + * {{{ + * // The following divides a person's height by 10. + * people.select( people("height") / 10 ) + * }}} + */ + override def / (literal: Any): Column = this / lit(literal) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (other: Column): Column = Remainder(expr, other.expr) + + /** + * Modulo (a.k.a. remainder) expression. + */ + override def % (literal: Any): Column = this % lit(literal) + + + /** + * A boolean expression that is evaluated to true if the value of this expression is contained + * by the evaluated values of the arguments. + */ + @scala.annotation.varargs + override def in(list: Column*): Column = In(expr, list.map(_.expr)) + + override def like(literal: String): Column = Like(expr, lit(literal).expr) + + override def rlike(literal: String): Column = RLike(expr, lit(literal).expr) + + /** + * An expression that gets an item at position `ordinal` out of an array. + */ + override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal)) + + /** + * An expression that gets a field by name in a [[StructField]]. + */ + override def getField(fieldName: String): Column = GetField(expr, fieldName) + + /** + * An expression that returns a substring. + * @param startPos expression for the starting position. + * @param len expression for the length of the substring. + */ + override def substr(startPos: Column, len: Column): Column = + Substring(expr, startPos.expr, len.expr) + + /** + * An expression that returns a substring. + * @param startPos starting position. + * @param len length of the substring. + */ + override def substr(startPos: Int, len: Int): Column = this.substr(lit(startPos), lit(len)) + + override def contains(other: Column): Column = Contains(expr, other.expr) + + override def contains(literal: Any): Column = this.contains(lit(literal)) + + + override def startsWith(other: Column): Column = StartsWith(expr, other.expr) + + override def startsWith(literal: String): Column = this.startsWith(lit(literal)) + + override def endsWith(other: Column): Column = EndsWith(expr, other.expr) + + override def endsWith(literal: String): Column = this.endsWith(lit(literal)) + + /** + * Gives the column an alias. + * {{{ + * // Renames colA to colB in select output. + * df.select($"colA".as("colB")) + * }}} + */ + override def as(alias: String): Column = Alias(expr, alias)() + + /** + * Casts the column to a different data type. + * {{{ + * // Casts colA to IntegerType. + * import org.apache.spark.sql.types.IntegerType + * df.select(df("colA").cast(IntegerType)) + * + * // equivalent to + * df.select(df("colA").cast("int")) + * }}} + */ + override def cast(to: DataType): Column = Cast(expr, to) + + /** + * Casts the column to a different data type, using the canonical string representation + * of the type. The supported types are: `string`, `boolean`, `byte`, `short`, `int`, `long`, + * `float`, `double`, `decimal`, `date`, `timestamp`. + * {{{ + * // Casts colA to integer. + * df.select(df("colA").cast("int")) + * }}} + */ + override def cast(to: String): Column = Cast(expr, to.toLowerCase match { + case "string" => StringType + case "boolean" => BooleanType + case "byte" => ByteType + case "short" => ShortType + case "int" => IntegerType + case "long" => LongType + case "float" => FloatType + case "double" => DoubleType + case "decimal" => DecimalType.Unlimited + case "date" => DateType + case "timestamp" => TimestampType + case _ => throw new RuntimeException(s"""Unsupported cast type: "$to"""") + }) + + override def desc: Column = SortOrder(expr, Descending) + + override def asc: Column = SortOrder(expr, Ascending) +} + + +class ColumnName(name: String) extends Column(name) { + + /** Creates a new AttributeReference of type boolean */ + def boolean: StructField = StructField(name, BooleanType) + + /** Creates a new AttributeReference of type byte */ + def byte: StructField = StructField(name, ByteType) + + /** Creates a new AttributeReference of type short */ + def short: StructField = StructField(name, ShortType) + + /** Creates a new AttributeReference of type int */ + def int: StructField = StructField(name, IntegerType) + + /** Creates a new AttributeReference of type long */ + def long: StructField = StructField(name, LongType) + + /** Creates a new AttributeReference of type float */ + def float: StructField = StructField(name, FloatType) + + /** Creates a new AttributeReference of type double */ + def double: StructField = StructField(name, DoubleType) + + /** Creates a new AttributeReference of type string */ + def string: StructField = StructField(name, StringType) + + /** Creates a new AttributeReference of type date */ + def date: StructField = StructField(name, DateType) + + /** Creates a new AttributeReference of type decimal */ + def decimal: StructField = StructField(name, DecimalType.Unlimited) + + /** Creates a new AttributeReference of type decimal */ + def decimal(precision: Int, scale: Int): StructField = + StructField(name, DecimalType(precision, scale)) + + /** Creates a new AttributeReference of type timestamp */ + def timestamp: StructField = StructField(name, TimestampType) + + /** Creates a new AttributeReference of type binary */ + def binary: StructField = StructField(name, BinaryType) + + /** Creates a new AttributeReference of type array */ + def array(dataType: DataType): StructField = StructField(name, ArrayType(dataType)) + + /** Creates a new AttributeReference of type map */ + def map(keyType: DataType, valueType: DataType): StructField = + map(MapType(keyType, valueType)) + + def map(mapType: MapType): StructField = StructField(name, mapType) + + /** Creates a new AttributeReference of type struct */ + def struct(fields: StructField*): StructField = struct(StructType(fields)) + + def struct(structType: StructType): StructField = StructField(name, structType) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala new file mode 100644 index 0000000000000..1096e396591df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -0,0 +1,664 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import java.util.{List => JList} + +import scala.language.implicitConversions +import scala.reflect.ClassTag +import scala.collection.JavaConversions._ + +import com.fasterxml.jackson.core.JsonFactory + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.python.SerDeUtil +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{ResolvedStar, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} +import org.apache.spark.sql.json.JsonRDD +import org.apache.spark.sql.types.{NumericType, StructType} +import org.apache.spark.util.Utils + + +/** + * A collection of rows that have the same columns. + * + * A [[DataFrame]] is equivalent to a relational table in Spark SQL, and can be created using + * various functions in [[SQLContext]]. + * {{{ + * val people = sqlContext.parquetFile("...") + * }}} + * + * Once created, it can be manipulated using the various domain-specific-language (DSL) functions + * defined in: [[DataFrame]] (this class), [[Column]], [[Dsl]] for the DSL. + * + * To select a column from the data frame, use the apply method: + * {{{ + * val ageCol = people("age") // in Scala + * Column ageCol = people.apply("age") // in Java + * }}} + * + * Note that the [[Column]] type can also be manipulated through its various functions. + * {{ + * // The following creates a new column that increases everybody's age by 10. + * people("age") + 10 // in Scala + * }} + * + * A more concrete example: + * {{{ + * // To create DataFrame using SQLContext + * val people = sqlContext.parquetFile("...") + * val department = sqlContext.parquetFile("...") + * + * people.filter("age" > 30) + * .join(department, people("deptId") === department("id")) + * .groupBy(department("name"), "gender") + * .agg(avg(people("salary")), max(people("age"))) + * }}} + */ +// TODO: Improve documentation. +class DataFrame protected[sql]( + val sqlContext: SQLContext, + private val baseLogicalPlan: LogicalPlan, + operatorsEnabled: Boolean) + extends DataFrameSpecificApi with RDDApi[Row] { + + protected[sql] def this(sqlContext: Option[SQLContext], plan: Option[LogicalPlan]) = + this(sqlContext.orNull, plan.orNull, sqlContext.isDefined && plan.isDefined) + + protected[sql] def this(sqlContext: SQLContext, plan: LogicalPlan) = this(sqlContext, plan, true) + + @transient protected[sql] lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) + + @transient protected[sql] val logicalPlan: LogicalPlan = baseLogicalPlan match { + // For various commands (like DDL) and queries with side effects, we force query optimization to + // happen right away to let these side effects take place eagerly. + case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + baseLogicalPlan + } + + /** + * An implicit conversion function internal to this class for us to avoid doing + * "new DataFrame(...)" everywhere. + */ + private implicit def logicalPlanToDataFrame(logicalPlan: LogicalPlan): DataFrame = { + new DataFrame(sqlContext, logicalPlan, true) + } + + /** Returns the list of numeric columns, useful for doing aggregation. */ + protected[sql] def numericColumns: Seq[Expression] = { + schema.fields.filter(_.dataType.isInstanceOf[NumericType]).map { n => + queryExecution.analyzed.resolve(n.name, sqlContext.analyzer.resolver).get + } + } + + /** Resolves a column name into a Catalyst [[NamedExpression]]. */ + protected[sql] def resolve(colName: String): NamedExpression = { + queryExecution.analyzed.resolve(colName, sqlContext.analyzer.resolver).getOrElse { + throw new RuntimeException( + s"""Cannot resolve column name "$colName" among (${schema.fieldNames.mkString(", ")})""") + } + } + + /** Left here for compatibility reasons. */ + @deprecated("1.3.0", "use toDataFrame") + def toSchemaRDD: DataFrame = this + + /** + * Returns the object itself. Used to force an implicit conversion from RDD to DataFrame in Scala. + */ + def toDataFrame: DataFrame = this + + /** + * Returns a new [[DataFrame]] with columns renamed. This can be quite convenient in conversion + * from a RDD of tuples into a [[DataFrame]] with meaningful names. For example: + * {{{ + * val rdd: RDD[(Int, String)] = ... + * rdd.toDataFrame // this implicit conversion creates a DataFrame with column name _1 and _2 + * rdd.toDataFrame("id", "name") // this creates a DataFrame with column name "id" and "name" + * }}} + */ + @scala.annotation.varargs + def toDataFrame(colName: String, colNames: String*): DataFrame = { + val newNames = colName +: colNames + require(schema.size == newNames.size, + "The number of columns doesn't match.\n" + + "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" + + "New column names: " + newNames.mkString(", ")) + + val newCols = schema.fieldNames.zip(newNames).map { case (oldName, newName) => + apply(oldName).as(newName) + } + select(newCols :_*) + } + + /** Returns the schema of this [[DataFrame]]. */ + override def schema: StructType = queryExecution.analyzed.schema + + /** Returns all column names and their data types as an array. */ + override def dtypes: Array[(String, String)] = schema.fields.map { field => + (field.name, field.dataType.toString) + } + + /** Returns all column names as an array. */ + override def columns: Array[String] = schema.fields.map(_.name) + + /** Prints the schema to the console in a nice tree format. */ + override def printSchema(): Unit = println(schema.treeString) + + /** + * Cartesian join with another [[DataFrame]]. + * + * Note that cartesian joins are very expensive without an extra filter that can be pushed down. + * + * @param right Right side of the join operation. + */ + override def join(right: DataFrame): DataFrame = { + Join(logicalPlan, right.logicalPlan, joinType = Inner, None) + } + + /** + * Inner join with another [[DataFrame]], using the given join expression. + * + * {{{ + * // The following two are equivalent: + * df1.join(df2, $"df1Key" === $"df2Key") + * df1.join(df2).where($"df1Key" === $"df2Key") + * }}} + */ + override def join(right: DataFrame, joinExprs: Column): DataFrame = { + Join(logicalPlan, right.logicalPlan, Inner, Some(joinExprs.expr)) + } + + /** + * Join with another [[DataFrame]], usin g the given join expression. The following performs + * a full outer join between `df1` and `df2`. + * + * {{{ + * df1.join(df2, "outer", $"df1Key" === $"df2Key") + * }}} + * + * @param right Right side of the join. + * @param joinExprs Join expression. + * @param joinType One of: `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`. + */ + override def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame = { + Join(logicalPlan, right.logicalPlan, JoinType(joinType), Some(joinExprs.expr)) + } + + /** + * Returns a new [[DataFrame]] sorted by the specified column, all in ascending order. + * {{{ + * // The following 3 are equivalent + * df.sort("sortcol") + * df.sort($"sortcol") + * df.sort($"sortcol".asc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortCol: String, sortCols: String*): DataFrame = { + orderBy(apply(sortCol), sortCols.map(apply) :_*) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. For example: + * {{{ + * df.sort($"col1", $"col2".desc) + * }}} + */ + @scala.annotation.varargs + override def sort(sortExpr: Column, sortExprs: Column*): DataFrame = { + val sortOrder: Seq[SortOrder] = (sortExpr +: sortExprs).map { col => + col.expr match { + case expr: SortOrder => + expr + case expr: Expression => + SortOrder(expr, Ascending) + } + } + Sort(sortOrder, global = true, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortCol: String, sortCols: String*): DataFrame = { + sort(sortCol, sortCols :_*) + } + + /** + * Returns a new [[DataFrame]] sorted by the given expressions. + * This is an alias of the `sort` function. + */ + @scala.annotation.varargs + override def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame = { + sort(sortExpr, sortExprs :_*) + } + + /** + * Selects column based on the column name and return it as a [[Column]]. + */ + override def apply(colName: String): Column = colName match { + case "*" => + new Column(ResolvedStar(schema.fieldNames.map(resolve))) + case _ => + val expr = resolve(colName) + new Column(Some(sqlContext), Some(Project(Seq(expr), logicalPlan)), expr) + } + + /** + * Selects a set of expressions, wrapped in a Product. + * {{{ + * // The following two are equivalent: + * df.apply(($"colA", $"colB" + 1)) + * df.select($"colA", $"colB" + 1) + * }}} + */ + override def apply(projection: Product): DataFrame = { + require(projection.productArity >= 1) + select(projection.productIterator.map { + case c: Column => c + case o: Any => new Column(Some(sqlContext), None, Literal(o)) + }.toSeq :_*) + } + + /** + * Returns a new [[DataFrame]] with an alias set. + */ + override def as(name: String): DataFrame = Subquery(name, logicalPlan) + + /** + * Selects a set of expressions. + * {{{ + * df.select($"colA", $"colB" + 1) + * }}} + */ + @scala.annotation.varargs + override def select(cols: Column*): DataFrame = { + val exprs = cols.zipWithIndex.map { + case (Column(expr: NamedExpression), _) => + expr + case (Column(expr: Expression), _) => + Alias(expr, expr.toString)() + } + Project(exprs.toSeq, logicalPlan) + } + + /** + * Selects a set of columns. This is a variant of `select` that can only select + * existing columns using column names (i.e. cannot construct expressions). + * + * {{{ + * // The following two are equivalent: + * df.select("colA", "colB") + * df.select($"colA", $"colB") + * }}} + */ + @scala.annotation.varargs + override def select(col: String, cols: String*): DataFrame = { + select((col +: cols).map(new Column(_)) :_*) + } + + /** + * Filters rows using the given condition. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def filter(condition: Column): DataFrame = { + Filter(condition.expr, logicalPlan) + } + + /** + * Filters rows using the given condition. This is an alias for `filter`. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def where(condition: Column): DataFrame = filter(condition) + + /** + * Filters rows using the given condition. This is a shorthand meant for Scala. + * {{{ + * // The following are equivalent: + * peopleDf.filter($"age" > 15) + * peopleDf.where($"age" > 15) + * peopleDf($"age" > 15) + * }}} + */ + override def apply(condition: Column): DataFrame = filter(condition) + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy($"department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(cols: Column*): GroupedDataFrame = { + new GroupedDataFrame(this, cols.map(_.expr)) + } + + /** + * Groups the [[DataFrame]] using the specified columns, so we can run aggregation on them. + * See [[GroupedDataFrame]] for all the available aggregate functions. + * + * This is a variant of groupBy that can only group by existing columns using column names + * (i.e. cannot construct expressions). + * + * {{{ + * // Compute the average for all numeric columns grouped by department. + * df.groupBy("department").avg() + * + * // Compute the max age and average salary, grouped by department and gender. + * df.groupBy($"department", $"gender").agg(Map( + * "salary" -> "avg", + * "age" -> "max" + * )) + * }}} + */ + @scala.annotation.varargs + override def groupBy(col1: String, cols: String*): GroupedDataFrame = { + val colNames: Seq[String] = col1 +: cols + new GroupedDataFrame(this, colNames.map(colName => resolve(colName))) + } + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + */ + override def agg(exprs: Map[String, String]): DataFrame = groupBy().agg(exprs) + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(Map("age" -> "max", "salary" -> "avg")) + * df.groupBy().agg(Map("age" -> "max", "salary" -> "avg")) + * }} + */ + override def agg(exprs: java.util.Map[String, String]): DataFrame = agg(exprs.toMap) + + /** + * Aggregates on the entire [[DataFrame]] without groups. + * {{ + * // df.agg(...) is a shorthand for df.groupBy().agg(...) + * df.agg(max($"age"), avg($"salary")) + * df.groupBy().agg(max($"age"), avg($"salary")) + * }} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = groupBy().agg(expr, exprs :_*) + + /** + * Returns a new [[DataFrame]] by taking the first `n` rows. The difference between this function + * and `head` is that `head` returns an array while `limit` returns a new [[DataFrame]]. + */ + override def limit(n: Int): DataFrame = Limit(Literal(n), logicalPlan) + + /** + * Returns a new [[DataFrame]] containing union of rows in this frame and another frame. + * This is equivalent to `UNION ALL` in SQL. + */ + override def unionAll(other: DataFrame): DataFrame = Union(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows only in both this frame and another frame. + * This is equivalent to `INTERSECT` in SQL. + */ + override def intersect(other: DataFrame): DataFrame = Intersect(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] containing rows in this frame but not in another frame. + * This is equivalent to `EXCEPT` in SQL. + */ + override def except(other: DataFrame): DataFrame = Except(logicalPlan, other.logicalPlan) + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + * @param seed Seed for sampling. + */ + override def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame = { + Sample(fraction, withReplacement, seed, logicalPlan) + } + + /** + * Returns a new [[DataFrame]] by sampling a fraction of rows, using a random seed. + * + * @param withReplacement Sample with replacement or not. + * @param fraction Fraction of rows to generate. + */ + override def sample(withReplacement: Boolean, fraction: Double): DataFrame = { + sample(withReplacement, fraction, Utils.random.nextLong) + } + + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns a new [[DataFrame]] by adding a column. + */ + override def addColumn(colName: String, col: Column): DataFrame = { + select(Column("*"), col.as(colName)) + } + + /** + * Returns the first `n` rows. + */ + override def head(n: Int): Array[Row] = limit(n).collect() + + /** + * Returns the first row. + */ + override def head(): Row = head(1).head + + /** + * Returns the first row. Alias for head(). + */ + override def first(): Row = head() + + /** + * Returns a new RDD by applying a function to all rows of this DataFrame. + */ + override def map[R: ClassTag](f: Row => R): RDD[R] = { + rdd.map(f) + } + + /** + * Returns a new RDD by first applying a function to all rows of this [[DataFrame]], + * and then flattening the results. + */ + override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = rdd.flatMap(f) + + /** + * Returns a new RDD by applying a function to each partition of this DataFrame. + */ + override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): RDD[R] = { + rdd.mapPartitions(f) + } + + /** + * Applies a function `f` to all rows. + */ + override def foreach(f: Row => Unit): Unit = rdd.foreach(f) + + /** + * Applies a function f to each partition of this [[DataFrame]]. + */ + override def foreachPartition(f: Iterator[Row] => Unit): Unit = rdd.foreachPartition(f) + + /** + * Returns the first `n` rows in the [[DataFrame]]. + */ + override def take(n: Int): Array[Row] = head(n) + + /** + * Returns an array that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collect(): Array[Row] = rdd.collect() + + /** + * Returns a Java list that contains all of [[Row]]s in this [[DataFrame]]. + */ + override def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(rdd.collect() :_*) + + /** + * Returns the number of rows in the [[DataFrame]]. + */ + override def count(): Long = groupBy().count().rdd.collect().head.getLong(0) + + /** + * Returns a new [[DataFrame]] that has exactly `numPartitions` partitions. + */ + override def repartition(numPartitions: Int): DataFrame = { + sqlContext.applySchema(rdd.repartition(numPartitions), schema) + } + + override def persist(): this.type = { + sqlContext.cacheManager.cacheQuery(this) + this + } + + override def persist(newLevel: StorageLevel): this.type = { + sqlContext.cacheManager.cacheQuery(this, None, newLevel) + this + } + + override def unpersist(blocking: Boolean): this.type = { + sqlContext.cacheManager.tryUncacheQuery(this, blocking) + this + } + + ///////////////////////////////////////////////////////////////////////////// + // I/O + ///////////////////////////////////////////////////////////////////////////// + + /** + * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s. + */ + override def rdd: RDD[Row] = { + val schema = this.schema + queryExecution.executedPlan.execute().map(ScalaReflection.convertRowToScala(_, schema)) + } + + /** + * Registers this RDD as a temporary table using the given name. The lifetime of this temporary + * table is tied to the [[SQLContext]] that was used to create this DataFrame. + * + * @group schema + */ + override def registerTempTable(tableName: String): Unit = { + sqlContext.registerRDDAsTable(this, tableName) + } + + /** + * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. + * Files that are written out using this method can be read back in as a [[DataFrame]] + * using the `parquetFile` function in [[SQLContext]]. + */ + override def saveAsParquetFile(path: String): Unit = { + sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd + } + + /** + * :: Experimental :: + * Creates a table from the the contents of this DataFrame. This will fail if the table already + * exists. + * + * Note that this currently only works with DataFrames that are created from a HiveContext as + * there is no notion of a persisted catalog in a standard SQL context. Instead you can write + * an RDD out to a parquet file, and then register that file as a table. This "table" can then + * be the target of an `insertInto`. + */ + @Experimental + override def saveAsTable(tableName: String): Unit = { + sqlContext.executePlan( + CreateTableAsSelect(None, tableName, logicalPlan, allowExisting = false)).toRdd + } + + /** + * :: Experimental :: + * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. + */ + @Experimental + override def insertInto(tableName: String, overwrite: Boolean): Unit = { + sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), + Map.empty, logicalPlan, overwrite)).toRdd + } + + /** + * Returns the content of the [[DataFrame]] as a RDD of JSON strings. + */ + override def toJSON: RDD[String] = { + val rowSchema = this.schema + this.mapPartitions { iter => + val jsonFactory = new JsonFactory() + iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) + } + } + + //////////////////////////////////////////////////////////////////////////// + // for Python API + //////////////////////////////////////////////////////////////////////////// + /** + * A helpful function for Py4j, convert a list of Column to an array + */ + protected[sql] def toColumnArray(cols: JList[Column]): Array[Column] = { + cols.toList.toArray + } + + /** + * Converts a JavaRDD to a PythonRDD. + */ + protected[sql] def javaToPython: JavaRDD[Array[Byte]] = { + val fieldTypes = schema.fields.map(_.dataType) + val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() + SerDeUtil.javaToPython(jrdd) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala new file mode 100644 index 0000000000000..3499956023d11 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dsl.scala @@ -0,0 +1,529 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.{TypeTag, typeTag} + +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.types._ + + +/** + * Domain specific functions available for [[DataFrame]]. + */ +object Dsl { + + /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) + + // /** + // * An implicit conversion that turns a RDD of product into a [[DataFrame]]. + // * + // * This method requires an implicit SQLContext in scope. For example: + // * {{{ + // * implicit val sqlContext: SQLContext = ... + // * val rdd: RDD[(Int, String)] = ... + // * rdd.toDataFrame // triggers the implicit here + // * }}} + // */ + // implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit context: SQLContext) + // : DataFrame = { + // context.createDataFrame(rdd) + // } + + /** Converts $"col name" into an [[Column]]. */ + implicit class StringToColumn(val sc: StringContext) extends AnyVal { + def $(args: Any*): ColumnName = { + new ColumnName(sc.s(args :_*)) + } + } + + private[this] implicit def toColumn(expr: Expression): Column = new Column(expr) + + /** + * Returns a [[Column]] based on the given column name. + */ + def col(colName: String): Column = new Column(colName) + + /** + * Returns a [[Column]] based on the given column name. Alias of [[col]]. + */ + def column(colName: String): Column = new Column(colName) + + /** + * Creates a [[Column]] of literal value. + */ + def lit(literal: Any): Column = { + if (literal.isInstanceOf[Symbol]) { + return new ColumnName(literal.asInstanceOf[Symbol].name) + } + + val literalExpr = literal match { + case v: Boolean => Literal(v, BooleanType) + case v: Byte => Literal(v, ByteType) + case v: Short => Literal(v, ShortType) + case v: Int => Literal(v, IntegerType) + case v: Long => Literal(v, LongType) + case v: Float => Literal(v, FloatType) + case v: Double => Literal(v, DoubleType) + case v: String => Literal(v, StringType) + case v: BigDecimal => Literal(Decimal(v), DecimalType.Unlimited) + case v: java.math.BigDecimal => Literal(Decimal(v), DecimalType.Unlimited) + case v: Decimal => Literal(v, DecimalType.Unlimited) + case v: java.sql.Timestamp => Literal(v, TimestampType) + case v: java.sql.Date => Literal(v, DateType) + case v: Array[Byte] => Literal(v, BinaryType) + case null => Literal(null, NullType) + case _ => + throw new RuntimeException("Unsupported literal type " + literal.getClass + " " + literal) + } + new Column(literalExpr) + } + + def sum(e: Column): Column = Sum(e.expr) + def sumDistinct(e: Column): Column = SumDistinct(e.expr) + def count(e: Column): Column = Count(e.expr) + + @scala.annotation.varargs + def countDistinct(expr: Column, exprs: Column*): Column = + CountDistinct((expr +: exprs).map(_.expr)) + + def approxCountDistinct(e: Column): Column = + ApproxCountDistinct(e.expr) + def approxCountDistinct(e: Column, rsd: Double): Column = + ApproxCountDistinct(e.expr, rsd) + + def avg(e: Column): Column = Average(e.expr) + def first(e: Column): Column = First(e.expr) + def last(e: Column): Column = Last(e.expr) + def min(e: Column): Column = Min(e.expr) + def max(e: Column): Column = Max(e.expr) + + def upper(e: Column): Column = Upper(e.expr) + def lower(e: Column): Column = Lower(e.expr) + def sqrt(e: Column): Column = Sqrt(e.expr) + def abs(e: Column): Column = Abs(e.expr) + + + // scalastyle:off + + /* Use the following code to generate: + (0 to 22).map { x => + val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[$typeTags](f: Function$x[$types]${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq($argsInUdf)) + }""") + } + + (0 to 22).map { x => + val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") + val fTypes = Seq.fill(x + 1)("_").mkString(", ") + val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + println(s""" + /** + * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { + ScalaUdf(f, returnType, Seq($argsInUdf)) + }""") + } + } + */ + /** + * Call a Scala function of 0 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag](f: Function0[RT]): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT], arg1: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT], arg1: Column, arg2: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT], arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](f: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](f: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](f: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](f: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](f: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](f: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](f: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](f: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](f: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](f: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](f: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF), and automatically + * infer the data types based on the function's signature. + */ + def callUDF[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](f: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT], arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + ////////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Call a Scala function of 0 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function0[_], returnType: DataType): Column = { + ScalaUdf(f, returnType, Seq()) + } + + /** + * Call a Scala function of 1 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr)) + } + + /** + * Call a Scala function of 2 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + } + + /** + * Call a Scala function of 3 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + } + + /** + * Call a Scala function of 4 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + } + + /** + * Call a Scala function of 5 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + } + + /** + * Call a Scala function of 6 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + } + + /** + * Call a Scala function of 7 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + } + + /** + * Call a Scala function of 8 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + } + + /** + * Call a Scala function of 9 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + } + + /** + * Call a Scala function of 10 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + } + + /** + * Call a Scala function of 11 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr)) + } + + /** + * Call a Scala function of 12 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr)) + } + + /** + * Call a Scala function of 13 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr)) + } + + /** + * Call a Scala function of 14 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr)) + } + + /** + * Call a Scala function of 15 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr)) + } + + /** + * Call a Scala function of 16 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr)) + } + + /** + * Call a Scala function of 17 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr)) + } + + /** + * Call a Scala function of 18 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr)) + } + + /** + * Call a Scala function of 19 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr)) + } + + /** + * Call a Scala function of 20 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr)) + } + + /** + * Call a Scala function of 21 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr)) + } + + /** + * Call a Scala function of 22 arguments as user-defined function (UDF). This requires + * you to specify the return data type. + */ + def callUDF(f: Function22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column, arg11: Column, arg12: Column, arg13: Column, arg14: Column, arg15: Column, arg16: Column, arg17: Column, arg18: Column, arg19: Column, arg20: Column, arg21: Column, arg22: Column): Column = { + ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr, arg11.expr, arg12.expr, arg13.expr, arg14.expr, arg15.expr, arg16.expr, arg17.expr, arg18.expr, arg19.expr, arg20.expr, arg21.expr, arg22.expr)) + } + + // scalastyle:on +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala new file mode 100644 index 0000000000000..1c948cbbfe58f --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataFrame.scala @@ -0,0 +1,154 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.language.implicitConversions +import scala.collection.JavaConversions._ + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr} +import org.apache.spark.sql.catalyst.plans.logical.Aggregate + + +/** + * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. + */ +class GroupedDataFrame protected[sql](df: DataFrame, groupingExprs: Seq[Expression]) + extends GroupedDataFrameApi { + + private[this] implicit def toDataFrame(aggExprs: Seq[NamedExpression]): DataFrame = { + val namedGroupingExprs = groupingExprs.map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + new DataFrame(df.sqlContext, + Aggregate(groupingExprs, namedGroupingExprs ++ aggExprs, df.logicalPlan)) + } + + private[this] def aggregateNumericColumns(f: Expression => Expression): Seq[NamedExpression] = { + df.numericColumns.map { c => + val a = f(c) + Alias(a, a.toString)() + } + } + + private[this] def strToExpr(expr: String): (Expression => Expression) = { + expr.toLowerCase match { + case "avg" | "average" | "mean" => Average + case "max" => Max + case "min" => Min + case "sum" => Sum + case "count" | "size" => Count + } + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. The resulting + * [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + override def agg(exprs: Map[String, String]): DataFrame = { + exprs.map { case (colName, expr) => + val a = strToExpr(expr)(df(colName).expr) + Alias(a, a.toString)() + }.toSeq + } + + /** + * Compute aggregates by specifying a map from column name to aggregate methods. The resulting + * [[DataFrame]] will also contain the grouping columns. + * + * The available aggregate methods are `avg`, `max`, `min`, `sum`, `count`. + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * df.groupBy("department").agg(Map( + * "age" -> "max" + * "sum" -> "expense" + * )) + * }}} + */ + def agg(exprs: java.util.Map[String, String]): DataFrame = { + agg(exprs.toMap) + } + + /** + * Compute aggregates by specifying a series of aggregate columns. Unlike other methods in this + * class, the resulting [[DataFrame]] won't automatically include the grouping columns. + * + * The available aggregate methods are defined in [[org.apache.spark.sql.Dsl]]. + * + * {{{ + * // Selects the age of the oldest employee and the aggregate expense for each department + * import org.apache.spark.sql.dsl._ + * df.groupBy("department").agg($"department", max($"age"), sum($"expense")) + * }}} + */ + @scala.annotation.varargs + override def agg(expr: Column, exprs: Column*): DataFrame = { + val aggExprs = (expr +: exprs).map(_.expr).map { + case expr: NamedExpression => expr + case expr: Expression => Alias(expr, expr.toString)() + } + + new DataFrame(df.sqlContext, Aggregate(groupingExprs, aggExprs, df.logicalPlan)) + } + + /** + * Count the number of rows for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def count(): DataFrame = Seq(Alias(Count(LiteralExpr(1)), "count")()) + + /** + * Compute the average value for each numeric columns for each group. This is an alias for `avg`. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def mean(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the max value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def max(): DataFrame = aggregateNumericColumns(Max) + + /** + * Compute the mean value for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def avg(): DataFrame = aggregateNumericColumns(Average) + + /** + * Compute the min value for each numeric column for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def min(): DataFrame = aggregateNumericColumns(Min) + + /** + * Compute the sum for each numeric columns for each group. + * The resulting [[DataFrame]] will also contain the grouping columns. + */ + override def sum(): DataFrame = aggregateNumericColumns(Sum) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0a22968cc7807..84933dd944837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -30,7 +30,6 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.dsl.ExpressionConversions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -43,7 +42,7 @@ import org.apache.spark.util.Utils /** * :: AlphaComponent :: - * The entry point for running relational queries using Spark. Allows the creation of [[SchemaRDD]] + * The entry point for running relational queries using Spark. Allows the creation of [[DataFrame]] * objects and the execution of SQL queries. * * @groupname userf Spark SQL Functions @@ -52,8 +51,6 @@ import org.apache.spark.util.Utils @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends org.apache.spark.Logging - with CacheManager - with ExpressionConversions with Serializable { self => @@ -111,37 +108,82 @@ class SQLContext(@transient val sparkContext: SparkContext) } protected[sql] def executeSql(sql: String): this.QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + + protected[sql] def executePlan(plan: LogicalPlan) = new this.QueryExecution(plan) sparkContext.getConf.getAll.foreach { case (key, value) if key.startsWith("spark.sql") => setConf(key, value) case _ => } + protected[sql] val cacheManager = new CacheManager(this) + + /** + * A collection of methods that are considered experimental, but can be used to hook into + * the query planner for advanced functionalities. + */ + val experimental: ExperimentalMethods = new ExperimentalMethods(this) + /** - * Creates a SchemaRDD from an RDD of case classes. + * A collection of methods for registering user-defined functions (UDF). + * + * The following example registers a Scala closure as UDF: + * {{{ + * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sqlContext.udf().register("myUDF", + * new UDF2() { + * @Override + * public String call(Integer arg1, String arg2) { + * return arg2 + arg1; + * } + * }, DataTypes.StringType); + * }}} + * + * Or, to use Java 8 lambda syntax: + * {{{ + * sqlContext.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1), + * DataTypes.StringType); + * }}} + */ + val udf: UDFRegistration = new UDFRegistration(this) + + /** Returns true if the table is currently cached in-memory. */ + def isCached(tableName: String): Boolean = cacheManager.isCached(tableName) + + /** Caches the specified table in-memory. */ + def cacheTable(tableName: String): Unit = cacheManager.cacheTable(tableName) + + /** Removes the specified table from the in-memory cache. */ + def uncacheTable(tableName: String): Unit = cacheManager.uncacheTable(tableName) + + /** + * Creates a DataFrame from an RDD of case classes. * * @group userf */ - implicit def createSchemaRDD[A <: Product: TypeTag](rdd: RDD[A]): SchemaRDD = { + implicit def createDataFrame[A <: Product: TypeTag](rdd: RDD[A]): DataFrame = { SparkPlan.currentContext.set(self) - val attributeSeq = ScalaReflection.attributesFor[A] - val schema = StructType.fromAttributes(attributeSeq) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes val rowRDD = RDDConversions.productToRowRdd(rdd, schema) - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRDD)(self)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRDD)(self)) } /** - * Convert a [[BaseRelation]] created for external data sources into a [[SchemaRDD]]. + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. */ - def baseRelationToSchemaRDD(baseRelation: BaseRelation): SchemaRDD = { - new SchemaRDD(this, LogicalRelation(baseRelation)) + def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { + new DataFrame(this, LogicalRelation(baseRelation)) } /** * :: DeveloperApi :: - * Creates a [[SchemaRDD]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s by applying a schema to this RDD. * It is important to make sure that the structure of every [[Row]] of the provided RDD matches * the provided schema. Otherwise, there will be runtime exception. * Example: @@ -157,24 +199,24 @@ class SQLContext(@transient val sparkContext: SparkContext) * val people = * sc.textFile("examples/src/main/resources/people.txt").map( * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) - * val peopleSchemaRDD = sqlContext. applySchema(people, schema) - * peopleSchemaRDD.printSchema + * val dataFrame = sqlContext. applySchema(people, schema) + * dataFrame.printSchema * // root * // |-- name: string (nullable = false) * // |-- age: integer (nullable = true) * - * peopleSchemaRDD.registerTempTable("people") + * dataFrame.registerTempTable("people") * sqlContext.sql("select name from people").collect.foreach(println) * }}} * * @group userf */ @DeveloperApi - def applySchema(rowRDD: RDD[Row], schema: StructType): SchemaRDD = { - // TODO: use MutableProjection when rowRDD is another SchemaRDD and the applied + def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied // schema differs from the existing schema on any field data type. val logicalPlan = LogicalRDD(schema.toAttributes, rowRDD)(self) - new SchemaRDD(this, logicalPlan) + new DataFrame(this, logicalPlan) } /** @@ -183,7 +225,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: RDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { val attributeSeq = getSchema(beanClass) val className = beanClass.getName val rowRdd = rdd.mapPartitions { iter => @@ -201,7 +243,7 @@ class SQLContext(@transient val sparkContext: SparkContext) ) : Row } } - new SchemaRDD(this, LogicalRDD(attributeSeq, rowRdd)(this)) + new DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this)) } /** @@ -210,35 +252,35 @@ class SQLContext(@transient val sparkContext: SparkContext) * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, * SELECT * queries will return the columns in an undefined order. */ - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): SchemaRDD = { + def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { applySchema(rdd.rdd, beanClass) } /** - * Loads a Parquet file, returning the result as a [[SchemaRDD]]. + * Loads a Parquet file, returning the result as a [[DataFrame]]. * * @group userf */ - def parquetFile(path: String): SchemaRDD = - new SchemaRDD(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) + def parquetFile(path: String): DataFrame = + new DataFrame(this, parquet.ParquetRelation(path, Some(sparkContext.hadoopConfiguration), this)) /** - * Loads a JSON file (one object per line), returning the result as a [[SchemaRDD]]. + * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonFile(path: String): SchemaRDD = jsonFile(path, 1.0) + def jsonFile(path: String): DataFrame = jsonFile(path, 1.0) /** * :: Experimental :: * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * * @group userf */ @Experimental - def jsonFile(path: String, schema: StructType): SchemaRDD = { + def jsonFile(path: String, schema: StructType): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, schema) } @@ -247,29 +289,29 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonFile(path: String, samplingRatio: Double): SchemaRDD = { + def jsonFile(path: String, samplingRatio: Double): DataFrame = { val json = sparkContext.textFile(path) jsonRDD(json, samplingRatio) } /** * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[SchemaRDD]]. + * [[DataFrame]]. * It goes through the entire dataset once to determine the schema. * * @group userf */ - def jsonRDD(json: RDD[String]): SchemaRDD = jsonRDD(json, 1.0) + def jsonRDD(json: RDD[String]): DataFrame = jsonRDD(json, 1.0) /** * :: Experimental :: * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[SchemaRDD]]. + * returning the result as a [[DataFrame]]. * * @group userf */ @Experimental - def jsonRDD(json: RDD[String], schema: StructType): SchemaRDD = { + def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = Option(schema).getOrElse( @@ -283,7 +325,7 @@ class SQLContext(@transient val sparkContext: SparkContext) * :: Experimental :: */ @Experimental - def jsonRDD(json: RDD[String], samplingRatio: Double): SchemaRDD = { + def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord val appliedSchema = JsonRDD.nullTypeToStringType( @@ -298,8 +340,8 @@ class SQLContext(@transient val sparkContext: SparkContext) * * @group userf */ - def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = { - catalog.registerTable(Seq(tableName), rdd.queryExecution.logical) + def registerRDDAsTable(rdd: DataFrame, tableName: String): Unit = { + catalog.registerTable(Seq(tableName), rdd.logicalPlan) } /** @@ -311,61 +353,27 @@ class SQLContext(@transient val sparkContext: SparkContext) * @group userf */ def dropTempTable(tableName: String): Unit = { - tryUncacheQuery(table(tableName)) + cacheManager.tryUncacheQuery(table(tableName)) catalog.unregisterTable(Seq(tableName)) } /** - * Executes a SQL query using Spark, returning the result as a SchemaRDD. The dialect that is + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. The dialect that is * used for SQL parsing can be configured with 'spark.sql.dialect'. * * @group userf */ - def sql(sqlText: String): SchemaRDD = { + def sql(sqlText: String): DataFrame = { if (conf.dialect == "sql") { - new SchemaRDD(this, parseSql(sqlText)) + new DataFrame(this, parseSql(sqlText)) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}") } } - /** Returns the specified table as a SchemaRDD */ - def table(tableName: String): SchemaRDD = - new SchemaRDD(this, catalog.lookupRelation(Seq(tableName))) - - /** - * A collection of methods that are considered experimental, but can be used to hook into - * the query planner for advanced functionalities. - */ - val experimental: ExperimentalMethods = new ExperimentalMethods(this) - - /** - * A collection of methods for registering user-defined functions (UDF). - * - * The following example registers a Scala closure as UDF: - * {{{ - * sqlContext.udf.register("myUdf", (arg1: Int, arg2: String) => arg2 + arg1) - * }}} - * - * The following example registers a UDF in Java: - * {{{ - * sqlContext.udf().register("myUDF", - * new UDF2() { - * @Override - * public String call(Integer arg1, String arg2) { - * return arg2 + arg1; - * } - * }, DataTypes.StringType); - * }}} - * - * Or, to use Java 8 lambda syntax: - * {{{ - * sqlContext.udf().register("myUDF", - * (Integer arg1, String arg2) -> arg2 + arg1), - * DataTypes.StringType); - * }}} - */ - val udf: UDFRegistration = new UDFRegistration(this) + /** Returns the specified table as a [[DataFrame]]. */ + def table(tableName: String): DataFrame = + new DataFrame(this, catalog.lookupRelation(Seq(tableName))) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext: SparkContext = self.sparkContext @@ -454,15 +462,14 @@ class SQLContext(@transient val sparkContext: SparkContext) * access to the intermediate phases of query execution for developers. */ @DeveloperApi - protected abstract class QueryExecution { - def logical: LogicalPlan + protected class QueryExecution(val logical: LogicalPlan) { - lazy val analyzed = ExtractPythonUdfs(analyzer(logical)) - lazy val withCachedData = useCachedData(analyzed) - lazy val optimizedPlan = optimizer(withCachedData) + lazy val analyzed: LogicalPlan = ExtractPythonUdfs(analyzer(logical)) + lazy val withCachedData: LogicalPlan = cacheManager.useCachedData(analyzed) + lazy val optimizedPlan: LogicalPlan = optimizer(withCachedData) // TODO: Don't just pick the first one... - lazy val sparkPlan = { + lazy val sparkPlan: SparkPlan = { SparkPlan.currentContext.set(self) planner(optimizedPlan).next() } @@ -512,7 +519,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schemaString: String): SchemaRDD = { + schemaString: String): DataFrame = { val schema = parseDataType(schemaString).asInstanceOf[StructType] applySchemaToPythonRDD(rdd, schema) } @@ -522,7 +529,7 @@ class SQLContext(@transient val sparkContext: SparkContext) */ protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], - schema: StructType): SchemaRDD = { + schema: StructType): DataFrame = { def needsConversion(dataType: DataType): Boolean = dataType match { case ByteType => true @@ -549,7 +556,7 @@ class SQLContext(@transient val sparkContext: SparkContext) iter.map { m => new GenericRow(m): Row} } - new SchemaRDD(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + new DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala deleted file mode 100644 index d1e21dffeb8c5..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala +++ /dev/null @@ -1,511 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import java.util.{List => JList} - -import scala.collection.JavaConversions._ - -import com.fasterxml.jackson.core.JsonFactory - -import net.razorvine.pickle.Pickler - -import org.apache.spark.{Dependency, OneToOneDependency, Partition, Partitioner, TaskContext} -import org.apache.spark.annotation.{AlphaComponent, Experimental} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.api.python.SerDeUtil -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.ScalaReflection -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{LogicalRDD, EvaluatePython} -import org.apache.spark.sql.json.JsonRDD -import org.apache.spark.sql.types.{BooleanType, StructType} -import org.apache.spark.storage.StorageLevel - -/** - * :: AlphaComponent :: - * An RDD of [[Row]] objects that has an associated schema. In addition to standard RDD functions, - * SchemaRDDs can be used in relational queries, as shown in the examples below. - * - * Importing a SQLContext brings an implicit into scope that automatically converts a standard RDD - * whose elements are scala case classes into a SchemaRDD. This conversion can also be done - * explicitly using the `createSchemaRDD` function on a [[SQLContext]]. - * - * A `SchemaRDD` can also be created by loading data in from external sources. - * Examples are loading data from Parquet files by using the `parquetFile` method on [[SQLContext]] - * and loading JSON datasets by using `jsonFile` and `jsonRDD` methods on [[SQLContext]]. - * - * == SQL Queries == - * A SchemaRDD can be registered as a table in the [[SQLContext]] that was used to create it. Once - * an RDD has been registered as a table, it can be used in the FROM clause of SQL statements. - * - * {{{ - * // One method for defining the schema of an RDD is to make a case class with the desired column - * // names and types. - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i"))) - * // Any RDD containing case classes can be registered as a table. The schema of the table is - * // automatically inferred using scala reflection. - * rdd.registerTempTable("records") - * - * val results: SchemaRDD = sql("SELECT * FROM records") - * }}} - * - * == Language Integrated Queries == - * - * {{{ - * - * case class Record(key: Int, value: String) - * - * val sc: SparkContext // An existing spark context. - * val sqlContext = new SQLContext(sc) - * - * // Importing the SQL context gives access to all the SQL functions and implicit conversions. - * import sqlContext._ - * - * val rdd = sc.parallelize((1 to 100).map(i => Record(i, "val_" + i))) - * - * // Example of language integrated queries. - * rdd.where('key === 1).orderBy('value.asc).select('key).collect() - * }}} - * - * @groupname Query Language Integrated Queries - * @groupdesc Query Functions that create new queries from SchemaRDDs. The - * result of all query functions is also a SchemaRDD, allowing multiple operations to be - * chained using a builder pattern. - * @groupprio Query -2 - * @groupname schema SchemaRDD Functions - * @groupprio schema -1 - * @groupname Ungrouped Base RDD Functions - */ -@AlphaComponent -class SchemaRDD( - @transient val sqlContext: SQLContext, - @transient val baseLogicalPlan: LogicalPlan) - extends RDD[Row](sqlContext.sparkContext, Nil) with SchemaRDDLike { - - def baseSchemaRDD = this - - // ========================================================================================= - // RDD functions: Copy the internal row representation so we present immutable data to users. - // ========================================================================================= - - override def compute(split: Partition, context: TaskContext): Iterator[Row] = - firstParent[Row].compute(split, context).map(ScalaReflection.convertRowToScala(_, this.schema)) - - override def getPartitions: Array[Partition] = firstParent[Row].partitions - - override protected def getDependencies: Seq[Dependency[_]] = { - schema // Force reification of the schema so it is available on executors. - - List(new OneToOneDependency(queryExecution.toRdd)) - } - - /** - * Returns the schema of this SchemaRDD (represented by a [[StructType]]). - * - * @group schema - */ - lazy val schema: StructType = queryExecution.analyzed.schema - - /** - * Returns a new RDD with each row transformed to a JSON string. - * - * @group schema - */ - def toJSON: RDD[String] = { - val rowSchema = this.schema - this.mapPartitions { iter => - val jsonFactory = new JsonFactory() - iter.map(JsonRDD.rowToJSON(rowSchema, jsonFactory)) - } - } - - - // ======================================================================= - // Query DSL - // ======================================================================= - - /** - * Changes the output of this relation to the given expressions, similar to the `SELECT` clause - * in SQL. - * - * {{{ - * schemaRDD.select('a, 'b + 'c, 'd as 'aliasedName) - * }}} - * - * @param exprs a set of logical expression that will be evaluated for each input row. - * - * @group Query - */ - def select(exprs: Expression*): SchemaRDD = { - val aliases = exprs.zipWithIndex.map { - case (ne: NamedExpression, _) => ne - case (e, i) => Alias(e, s"c$i")() - } - new SchemaRDD(sqlContext, Project(aliases, logicalPlan)) - } - - /** - * Filters the output, only returning those rows where `condition` evaluates to true. - * - * {{{ - * schemaRDD.where('a === 'b) - * schemaRDD.where('a === 1) - * schemaRDD.where('a + 'b > 10) - * }}} - * - * @group Query - */ - def where(condition: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Filter(condition, logicalPlan)) - - /** - * Performs a relational join on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be joined with this one. - * @param joinType One of `Inner`, `LeftOuter`, `RightOuter`, or `FullOuter`. Defaults to `Inner.` - * @param on An optional condition for the join operation. This is equivalent to the `ON` - * clause in standard SQL. In the case of `Inner` joins, specifying a - * `condition` is equivalent to adding `where` clauses after the `join`. - * - * @group Query - */ - def join( - otherPlan: SchemaRDD, - joinType: JoinType = Inner, - on: Option[Expression] = None): SchemaRDD = - new SchemaRDD(sqlContext, Join(logicalPlan, otherPlan.logicalPlan, joinType, on)) - - /** - * Sorts the results by the given expressions. - * {{{ - * schemaRDD.orderBy('a) - * schemaRDD.orderBy('a, 'b) - * schemaRDD.orderBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def orderBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, true, logicalPlan)) - - /** - * Sorts the results by the given expressions within partition. - * {{{ - * schemaRDD.sortBy('a) - * schemaRDD.sortBy('a, 'b) - * schemaRDD.sortBy('a.asc, 'b.desc) - * }}} - * - * @group Query - */ - def sortBy(sortExprs: SortOrder*): SchemaRDD = - new SchemaRDD(sqlContext, Sort(sortExprs, false, logicalPlan)) - - @deprecated("use limit with integer argument", "1.1.0") - def limit(limitExpr: Expression): SchemaRDD = - new SchemaRDD(sqlContext, Limit(limitExpr, logicalPlan)) - - /** - * Limits the results by the given integer. - * {{{ - * schemaRDD.limit(10) - * }}} - * @group Query - */ - def limit(limitNum: Int): SchemaRDD = - new SchemaRDD(sqlContext, Limit(Literal(limitNum), logicalPlan)) - - /** - * Performs a grouping followed by an aggregation. - * - * {{{ - * schemaRDD.groupBy('year)(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def groupBy(groupingExprs: Expression*)(aggregateExprs: Expression*): SchemaRDD = { - val aliasedExprs = aggregateExprs.map { - case ne: NamedExpression => ne - case e => Alias(e, e.toString)() - } - new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan)) - } - - /** - * Performs an aggregation over all Rows in this RDD. - * This is equivalent to a groupBy with no grouping expressions. - * - * {{{ - * schemaRDD.aggregate(Sum('sales) as 'totalSales) - * }}} - * - * @group Query - */ - def aggregate(aggregateExprs: Expression*): SchemaRDD = { - groupBy()(aggregateExprs: _*) - } - - /** - * Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes - * with the same name, for example, when performing self-joins. - * - * {{{ - * val x = schemaRDD.where('a === 1).as('x) - * val y = schemaRDD.where('a === 2).as('y) - * x.join(y).where("x.a".attr === "y.a".attr), - * }}} - * - * @group Query - */ - def as(alias: Symbol) = - new SchemaRDD(sqlContext, Subquery(alias.name, logicalPlan)) - - /** - * Combines the tuples of two RDDs with the same schema, keeping duplicates. - * - * @group Query - */ - def unionAll(otherPlan: SchemaRDD) = - new SchemaRDD(sqlContext, Union(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational except on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be excepted from this one. - * - * @group Query - */ - def except(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Except(logicalPlan, otherPlan.logicalPlan)) - - /** - * Performs a relational intersect on two SchemaRDDs - * - * @param otherPlan the [[SchemaRDD]] that should be intersected with this one. - * - * @group Query - */ - def intersect(otherPlan: SchemaRDD): SchemaRDD = - new SchemaRDD(sqlContext, Intersect(logicalPlan, otherPlan.logicalPlan)) - - /** - * Filters tuples using a function over the value of the specified column. - * - * {{{ - * schemaRDD.where('a)((a: Int) => ...) - * }}} - * - * @group Query - */ - def where[T1](arg1: Symbol)(udf: (T1) => Boolean) = - new SchemaRDD( - sqlContext, - Filter(ScalaUdf(udf, BooleanType, Seq(UnresolvedAttribute(arg1.name))), logicalPlan)) - - /** - * :: Experimental :: - * Returns a sampled version of the underlying dataset. - * - * @group Query - */ - @Experimental - override - def sample( - withReplacement: Boolean = true, - fraction: Double, - seed: Long) = - new SchemaRDD(sqlContext, Sample(fraction, withReplacement, seed, logicalPlan)) - - /** - * :: Experimental :: - * Return the number of elements in the RDD. Unlike the base RDD implementation of count, this - * implementation leverages the query optimizer to compute the count on the SchemaRDD, which - * supports features such as filter pushdown. - * - * @group Query - */ - @Experimental - override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0) - - /** - * :: Experimental :: - * Applies the given Generator, or table generating function, to this relation. - * - * @param generator A table generating function. The API for such functions is likely to change - * in future releases - * @param join when set to true, each output row of the generator is joined with the input row - * that produced it. - * @param outer when set to true, at least one row will be produced for each input row, similar to - * an `OUTER JOIN` in SQL. When no output rows are produced by the generator for a - * given row, a single row will be output, with `NULL` values for each of the - * generated columns. - * @param alias an optional alias that can be used as qualifier for the attributes that are - * produced by this generate operation. - * - * @group Query - */ - @Experimental - def generate( - generator: Generator, - join: Boolean = false, - outer: Boolean = false, - alias: Option[String] = None) = - new SchemaRDD(sqlContext, Generate(generator, join, outer, alias, logicalPlan)) - - /** - * Returns this RDD as a SchemaRDD. Intended primarily to force the invocation of the implicit - * conversion from a standard RDD to a SchemaRDD. - * - * @group schema - */ - def toSchemaRDD = this - - /** - * Converts a JavaRDD to a PythonRDD. It is used by pyspark. - */ - private[sql] def javaToPython: JavaRDD[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val jrdd = this.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD() - SerDeUtil.javaToPython(jrdd) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same - * format as javaToPython. It is used by pyspark. - */ - private[sql] def collectToPython: JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(collect().map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Serializes the Array[Row] returned by SchemaRDD's takeSample(), using the same - * format as javaToPython and collectToPython. It is used by pyspark. - */ - private[sql] def takeSampleToPython( - withReplacement: Boolean, - num: Int, - seed: Long): JList[Array[Byte]] = { - val fieldTypes = schema.fields.map(_.dataType) - val pickle = new Pickler - new java.util.ArrayList(this.takeSample(withReplacement, num, seed).map { row => - EvaluatePython.rowToArray(row, fieldTypes) - }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable) - } - - /** - * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value - * of base RDD functions that do not change schema. - * - * @param rdd RDD derived from this one and has same schema - * - * @group schema - */ - private def applySchema(rdd: RDD[Row]): SchemaRDD = { - new SchemaRDD(sqlContext, - LogicalRDD(queryExecution.analyzed.output.map(_.newInstance()), rdd)(sqlContext)) - } - - // ======================================================================= - // Overridden RDD actions - // ======================================================================= - - override def collect(): Array[Row] = queryExecution.executedPlan.executeCollect() - - def collectAsList(): java.util.List[Row] = java.util.Arrays.asList(collect() : _*) - - override def take(num: Int): Array[Row] = limit(num).collect() - - // ======================================================================= - // Base RDD functions that do NOT change schema - // ======================================================================= - - // Transformations (return a new RDD) - - override def coalesce(numPartitions: Int, shuffle: Boolean = false) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.coalesce(numPartitions, shuffle)(ord)) - - override def distinct(): SchemaRDD = applySchema(super.distinct()) - - override def distinct(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.distinct(numPartitions)(ord)) - - def distinct(numPartitions: Int): SchemaRDD = - applySchema(super.distinct(numPartitions)(null)) - - override def filter(f: Row => Boolean): SchemaRDD = - applySchema(super.filter(f)) - - override def intersection(other: RDD[Row]): SchemaRDD = - applySchema(super.intersection(other)) - - override def intersection(other: RDD[Row], partitioner: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.intersection(other, partitioner)(ord)) - - override def intersection(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.intersection(other, numPartitions)) - - override def repartition(numPartitions: Int) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.repartition(numPartitions)(ord)) - - override def subtract(other: RDD[Row]): SchemaRDD = - applySchema(super.subtract(other)) - - override def subtract(other: RDD[Row], numPartitions: Int): SchemaRDD = - applySchema(super.subtract(other, numPartitions)) - - override def subtract(other: RDD[Row], p: Partitioner) - (implicit ord: Ordering[Row] = null): SchemaRDD = - applySchema(super.subtract(other, p)(ord)) - - /** Overridden cache function will always use the in-memory columnar caching. */ - override def cache(): this.type = { - sqlContext.cacheQuery(this) - this - } - - override def persist(newLevel: StorageLevel): this.type = { - sqlContext.cacheQuery(this, None, newLevel) - this - } - - override def unpersist(blocking: Boolean): this.type = { - sqlContext.tryUncacheQuery(this, blocking) - this - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala deleted file mode 100644 index 3cf9209465b76..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDDLike.scala +++ /dev/null @@ -1,139 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql - -import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.LogicalRDD - -/** - * Contains functions that are shared between all SchemaRDD types (i.e., Scala, Java) - */ -private[sql] trait SchemaRDDLike { - @transient def sqlContext: SQLContext - @transient val baseLogicalPlan: LogicalPlan - - private[sql] def baseSchemaRDD: SchemaRDD - - /** - * :: DeveloperApi :: - * A lazily computed query execution workflow. All other RDD operations are passed - * through to the RDD that is produced by this workflow. This workflow is produced lazily because - * invoking the whole query optimization pipeline can be expensive. - * - * The query execution is considered a Developer API as phases may be added or removed in future - * releases. This execution is only exposed to provide an interface for inspecting the various - * phases for debugging purposes. Applications should not depend on particular phases existing - * or producing any specific output, even for exactly the same query. - * - * Additionally, the RDD exposed by this execution is not designed for consumption by end users. - * In particular, it does not contain any schema information, and it reuses Row objects - * internally. This object reuse improves performance, but can make programming against the RDD - * more difficult. Instead end users should perform RDD operations on a SchemaRDD directly. - */ - @transient - @DeveloperApi - lazy val queryExecution = sqlContext.executePlan(baseLogicalPlan) - - @transient protected[spark] val logicalPlan: LogicalPlan = baseLogicalPlan match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. - case _: Command | _: InsertIntoTable | _: CreateTableAsSelect[_] |_: WriteToFile => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - baseLogicalPlan - } - - override def toString = - s"""${super.toString} - |== Query Plan == - |${queryExecution.simpleString}""".stripMargin.trim - - /** - * Saves the contents of this `SchemaRDD` as a parquet file, preserving the schema. Files that - * are written out using this method can be read back in as a SchemaRDD using the `parquetFile` - * function. - * - * @group schema - */ - def saveAsParquetFile(path: String): Unit = { - sqlContext.executePlan(WriteToFile(path, logicalPlan)).toRdd - } - - /** - * Registers this RDD as a temporary table using the given name. The lifetime of this temporary - * table is tied to the [[SQLContext]] that was used to create this SchemaRDD. - * - * @group schema - */ - def registerTempTable(tableName: String): Unit = { - sqlContext.registerRDDAsTable(baseSchemaRDD, tableName) - } - - @deprecated("Use registerTempTable instead of registerAsTable.", "1.1") - def registerAsTable(tableName: String): Unit = registerTempTable(tableName) - - /** - * :: Experimental :: - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String, overwrite: Boolean): Unit = - sqlContext.executePlan(InsertIntoTable(UnresolvedRelation(Seq(tableName)), - Map.empty, logicalPlan, overwrite)).toRdd - - /** - * :: Experimental :: - * Appends the rows from this RDD to the specified table. - * - * @group schema - */ - @Experimental - def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) - - /** - * :: Experimental :: - * Creates a table from the the contents of this SchemaRDD. This will fail if the table already - * exists. - * - * Note that this currently only works with SchemaRDDs that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * @group schema - */ - @Experimental - def saveAsTable(tableName: String): Unit = - sqlContext.executePlan(CreateTableAsSelect(None, tableName, logicalPlan, false)).toRdd - - /** Returns the schema as a string in the tree format. - * - * @group schema - */ - def schemaString: String = baseSchemaRDD.schema.treeString - - /** Prints out the schema. - * - * @group schema - */ - def printSchema(): Unit = println(schemaString) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 2e9d037f93c03..1beb19437a8da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -21,7 +21,7 @@ import java.util.{List => JList, Map => JMap} import scala.reflect.runtime.universe.TypeTag -import org.apache.spark.Accumulator +import org.apache.spark.{Accumulator, Logging} import org.apache.spark.api.python.PythonBroadcast import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.api.java._ @@ -34,7 +34,7 @@ import org.apache.spark.sql.types.DataType /** * Functions for registering user-defined functions. */ -class UDFRegistration (sqlContext: SQLContext) extends org.apache.spark.Logging { +class UDFRegistration(sqlContext: SQLContext) extends Logging { private val functionRegistry = sqlContext.functionRegistry diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala b/sql/core/src/main/scala/org/apache/spark/sql/api.scala new file mode 100644 index 0000000000000..eb0eb3f32560c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala @@ -0,0 +1,299 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import scala.reflect.ClassTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.storage.StorageLevel + + +/** + * An internal interface defining the RDD-like methods for [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +private[sql] trait RDDApi[T] { + + def cache(): this.type = persist() + + def persist(): this.type + + def persist(newLevel: StorageLevel): this.type + + def unpersist(): this.type = unpersist(blocking = false) + + def unpersist(blocking: Boolean): this.type + + def map[R: ClassTag](f: T => R): RDD[R] + + def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R] + + def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R] + + def foreach(f: T => Unit): Unit + + def foreachPartition(f: Iterator[T] => Unit): Unit + + def take(n: Int): Array[T] + + def collect(): Array[T] + + def collectAsList(): java.util.List[T] + + def count(): Long + + def first(): T + + def repartition(numPartitions: Int): DataFrame +} + + +/** + * An internal interface defining data frame related methods in [[DataFrame]]. + * Please use [[DataFrame]] directly, and do NOT use this. + */ +private[sql] trait DataFrameSpecificApi { + + def schema: StructType + + def printSchema(): Unit + + def dtypes: Array[(String, String)] + + def columns: Array[String] + + def head(): Row + + def head(n: Int): Array[Row] + + ///////////////////////////////////////////////////////////////////////////// + // Relational operators + ///////////////////////////////////////////////////////////////////////////// + def apply(colName: String): Column + + def apply(projection: Product): DataFrame + + @scala.annotation.varargs + def select(cols: Column*): DataFrame + + @scala.annotation.varargs + def select(col: String, cols: String*): DataFrame + + def apply(condition: Column): DataFrame + + def as(name: String): DataFrame + + def filter(condition: Column): DataFrame + + def where(condition: Column): DataFrame + + @scala.annotation.varargs + def groupBy(cols: Column*): GroupedDataFrame + + @scala.annotation.varargs + def groupBy(col1: String, cols: String*): GroupedDataFrame + + def agg(exprs: Map[String, String]): DataFrame + + def agg(exprs: java.util.Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def sort(sortCol: String, sortCols: String*): DataFrame + + @scala.annotation.varargs + def orderBy(sortExpr: Column, sortExprs: Column*): DataFrame + + @scala.annotation.varargs + def orderBy(sortCol: String, sortCols: String*): DataFrame + + def join(right: DataFrame): DataFrame + + def join(right: DataFrame, joinExprs: Column): DataFrame + + def join(right: DataFrame, joinExprs: Column, joinType: String): DataFrame + + def limit(n: Int): DataFrame + + def unionAll(other: DataFrame): DataFrame + + def intersect(other: DataFrame): DataFrame + + def except(other: DataFrame): DataFrame + + def sample(withReplacement: Boolean, fraction: Double, seed: Long): DataFrame + + def sample(withReplacement: Boolean, fraction: Double): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // Column mutation + ///////////////////////////////////////////////////////////////////////////// + def addColumn(colName: String, col: Column): DataFrame + + ///////////////////////////////////////////////////////////////////////////// + // I/O and interaction with other frameworks + ///////////////////////////////////////////////////////////////////////////// + + def rdd: RDD[Row] + + def toJavaRDD: JavaRDD[Row] = rdd.toJavaRDD() + + def toJSON: RDD[String] + + def registerTempTable(tableName: String): Unit + + def saveAsParquetFile(path: String): Unit + + @Experimental + def saveAsTable(tableName: String): Unit + + @Experimental + def insertInto(tableName: String, overwrite: Boolean): Unit + + @Experimental + def insertInto(tableName: String): Unit = insertInto(tableName, overwrite = false) + + ///////////////////////////////////////////////////////////////////////////// + // Stat functions + ///////////////////////////////////////////////////////////////////////////// +// def describe(): Unit +// +// def mean(): Unit +// +// def max(): Unit +// +// def min(): Unit +} + + +/** + * An internal interface defining expression APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[Column]] directly, and do NOT use this. + */ +private[sql] trait ExpressionApi { + + def isComputable: Boolean + + def unary_- : Column + def unary_! : Column + def unary_~ : Column + + def + (other: Column): Column + def + (other: Any): Column + def - (other: Column): Column + def - (other: Any): Column + def * (other: Column): Column + def * (other: Any): Column + def / (other: Column): Column + def / (other: Any): Column + def % (other: Column): Column + def % (other: Any): Column + def & (other: Column): Column + def & (other: Any): Column + def | (other: Column): Column + def | (other: Any): Column + def ^ (other: Column): Column + def ^ (other: Any): Column + + def && (other: Column): Column + def && (other: Boolean): Column + def || (other: Column): Column + def || (other: Boolean): Column + + def < (other: Column): Column + def < (other: Any): Column + def <= (other: Column): Column + def <= (other: Any): Column + def > (other: Column): Column + def > (other: Any): Column + def >= (other: Column): Column + def >= (other: Any): Column + def === (other: Column): Column + def === (other: Any): Column + def equalTo(other: Column): Column + def equalTo(other: Any): Column + def <=> (other: Column): Column + def <=> (other: Any): Column + def !== (other: Column): Column + def !== (other: Any): Column + + @scala.annotation.varargs + def in(list: Column*): Column + + def like(other: String): Column + def rlike(other: String): Column + + def contains(other: Column): Column + def contains(other: Any): Column + def startsWith(other: Column): Column + def startsWith(other: String): Column + def endsWith(other: Column): Column + def endsWith(other: String): Column + + def substr(startPos: Column, len: Column): Column + def substr(startPos: Int, len: Int): Column + + def isNull: Column + def isNotNull: Column + + def getItem(ordinal: Int): Column + def getField(fieldName: String): Column + + def cast(to: DataType): Column + def cast(to: String): Column + + def asc: Column + def desc: Column + + def as(alias: String): Column +} + + +/** + * An internal interface defining aggregation APIs for [[DataFrame]]. + * Please use [[DataFrame]] and [[GroupedDataFrame]] directly, and do NOT use this. + */ +private[sql] trait GroupedDataFrameApi { + + def agg(exprs: Map[String, String]): DataFrame + + @scala.annotation.varargs + def agg(expr: Column, exprs: Column*): DataFrame + + def avg(): DataFrame + + def mean(): DataFrame + + def min(): DataFrame + + def max(): DataFrame + + def sum(): DataFrame + + def count(): DataFrame + + // TODO: Add var, std +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala old mode 100644 new mode 100755 index be9f155253d77..ad44a01d0e164 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -56,10 +56,6 @@ case class Aggregate( } } - // HACK: Generators don't correctly preserve their output through serializations so we grab - // out child's output attributes statically here. - private[this] val childOutput = child.output - override def output = aggregateExpressions.map(_.toAttribute) /** @@ -81,7 +77,7 @@ case class Aggregate( case a: AggregateExpression => ComputedAggregate( a, - BindReferences.bindReference(a, childOutput), + BindReferences.bindReference(a, child.output), AttributeReference(s"aggResult:$a", a.dataType, a.nullable)()) } }.toArray @@ -150,7 +146,7 @@ case class Aggregate( } else { child.execute().mapPartitions { iter => val hashTable = new HashMap[Row, Array[AggregateFunction]] - val groupingProjection = new InterpretedMutableProjection(groupingExpressions, childOutput) + val groupingProjection = new InterpretedMutableProjection(groupingExpressions, child.output) var currentRow: Row = null while (iter.hasNext) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 52a31f01a4358..6fba76c52171b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Row, Attribute} import org.apache.spark.sql.catalyst.plans.logical @@ -137,7 +137,9 @@ case class CacheTableCommand( isLazy: Boolean) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - plan.foreach(p => new SchemaRDD(sqlContext, p).registerTempTable(tableName)) + plan.foreach { logicalPlan => + sqlContext.registerRDDAsTable(new DataFrame(sqlContext, logicalPlan), tableName) + } sqlContext.cacheTable(tableName) if (!isLazy) { @@ -159,7 +161,7 @@ case class CacheTableCommand( case class UncacheTableCommand(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext) = { - sqlContext.table(tableName).unpersist() + sqlContext.table(tableName).unpersist(blocking = false) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 4d7e338e8ed13..5cc67cdd13944 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -22,7 +22,7 @@ import scala.collection.mutable.HashSet import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext._ -import org.apache.spark.sql.{SchemaRDD, Row} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.types._ @@ -39,10 +39,10 @@ package object debug { /** * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi - implicit class DebugQuery(query: SchemaRDD) { + implicit class DebugQuery(query: DataFrame) { def debug(): Unit = { val plan = query.queryExecution.executedPlan val visited = new collection.mutable.HashSet[TreeNodeRef]() @@ -166,7 +166,7 @@ package object debug { /** * :: DeveloperApi :: - * Augments SchemaRDDs with debug methods. + * Augments [[DataFrame]]s with debug methods. */ @DeveloperApi private[sql] case class TypeCheck(child: SparkPlan) extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index 6dd39be807037..7c49b5220d607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -37,5 +37,5 @@ package object sql { * Converts a logical plan into zero or more SparkPlans. */ @DeveloperApi - type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] + protected[sql] type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala index 9d9150246c8d4..10df8c3310092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.parquet import scala.collection.mutable.{Buffer, ArrayBuffer, HashMap} +import parquet.column.Dictionary import parquet.io.api.{PrimitiveConverter, GroupConverter, Binary, Converter} import parquet.schema.MessageType @@ -102,12 +103,8 @@ private[sql] object CatalystConverter { } // Strings, Shorts and Bytes do not have a corresponding type in Parquet // so we need to treat them separately - case StringType => { - new CatalystPrimitiveConverter(parent, fieldIndex) { - override def addBinary(value: Binary): Unit = - parent.updateString(fieldIndex, value) - } - } + case StringType => + new CatalystPrimitiveStringConverter(parent, fieldIndex) case ShortType => { new CatalystPrimitiveConverter(parent, fieldIndex) { override def addInt(value: Int): Unit = @@ -197,8 +194,8 @@ private[parquet] abstract class CatalystConverter extends GroupConverter { protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = updateField(fieldIndex, value.getBytes) - protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - updateField(fieldIndex, value.toStringUsingUTF8) + protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + updateField(fieldIndex, value) protected[parquet] def updateDecimal(fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { updateField(fieldIndex, readDecimal(new Decimal(), value, ctype)) @@ -384,8 +381,8 @@ private[parquet] class CatalystPrimitiveRowConverter( override protected[parquet] def updateBinary(fieldIndex: Int, value: Binary): Unit = current.update(fieldIndex, value.getBytes) - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = - current.setString(fieldIndex, value.toStringUsingUTF8) + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = + current.setString(fieldIndex, value) override protected[parquet] def updateDecimal( fieldIndex: Int, value: Binary, ctype: DecimalType): Unit = { @@ -426,6 +423,33 @@ private[parquet] class CatalystPrimitiveConverter( parent.updateLong(fieldIndex, value) } +/** + * A `parquet.io.api.PrimitiveConverter` that converts Parquet Binary to Catalyst String. + * Supports dictionaries to reduce Binary to String conversion overhead. + * + * Follows pattern in Parquet of using dictionaries, where supported, for String conversion. + * + * @param parent The parent group converter. + * @param fieldIndex The index inside the record. + */ +private[parquet] class CatalystPrimitiveStringConverter(parent: CatalystConverter, fieldIndex: Int) + extends CatalystPrimitiveConverter(parent, fieldIndex) { + + private[this] var dict: Array[String] = null + + override def hasDictionarySupport: Boolean = true + + override def setDictionary(dictionary: Dictionary):Unit = + dict = Array.tabulate(dictionary.getMaxId + 1) {dictionary.decodeToBinary(_).toStringUsingUTF8} + + + override def addValueFromDictionary(dictionaryId: Int): Unit = + parent.updateString(fieldIndex, dict(dictionaryId)) + + override def addBinary(value: Binary): Unit = + parent.updateString(fieldIndex, value.toStringUsingUTF8) +} + private[parquet] object CatalystArrayConverter { val INITIAL_ARRAY_SIZE = 20 } @@ -583,9 +607,9 @@ private[parquet] class CatalystNativeArrayConverter( elements += 1 } - override protected[parquet] def updateString(fieldIndex: Int, value: Binary): Unit = { + override protected[parquet] def updateString(fieldIndex: Int, value: String): Unit = { checkGrowBuffer() - buffer(elements) = value.toStringUsingUTF8.asInstanceOf[NativeType] + buffer(elements) = value.asInstanceOf[NativeType] elements += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala index f08350878f239..0357dcc4688be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetFilters.scala @@ -164,33 +164,57 @@ private[sql] object ParquetFilters { case EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeEq.lift(dataType).map(_(name, value)) + case EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeEq.lift(dataType).map(_(name, value)) case EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeEq.lift(dataType).map(_(name, value)) - + case EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NamedExpression(name, _), NonNullLiteral(value, dataType))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _))) => + makeNotEq.lift(dataType).map(_(name, value)) case Not(EqualTo(NonNullLiteral(value, dataType), NamedExpression(name, _))) => makeNotEq.lift(dataType).map(_(name, value)) + case Not(EqualTo(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType))) => + makeNotEq.lift(dataType).map(_(name, value)) case LessThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLt.lift(dataType).map(_(name, value)) + case LessThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLt.lift(dataType).map(_(name, value)) case LessThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGt.lift(dataType).map(_(name, value)) + case LessThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGt.lift(dataType).map(_(name, value)) case LessThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeLtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeLtEq.lift(dataType).map(_(name, value)) case LessThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeGtEq.lift(dataType).map(_(name, value)) + case LessThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThan(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGt.lift(dataType).map(_(name, value)) + case GreaterThan(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGt.lift(dataType).map(_(name, value)) case GreaterThan(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLt.lift(dataType).map(_(name, value)) + case GreaterThan(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLt.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NamedExpression(name, _), NonNullLiteral(value, dataType)) => makeGtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(Cast(NamedExpression(name, _), dataType), NonNullLiteral(value, _)) => + makeGtEq.lift(dataType).map(_(name, value)) case GreaterThanOrEqual(NonNullLiteral(value, dataType), NamedExpression(name, _)) => makeLtEq.lift(dataType).map(_(name, value)) + case GreaterThanOrEqual(NonNullLiteral(value, _), Cast(NamedExpression(name, _), dataType)) => + makeLtEq.lift(dataType).map(_(name, value)) case And(lhs, rhs) => (createFilter(lhs) ++ createFilter(rhs)).reduceOption(FilterApi.and) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala index cde5160149e9c..a54485e719dad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala @@ -26,7 +26,7 @@ import parquet.hadoop.ParquetOutputFormat import parquet.hadoop.metadata.CompressionCodecName import parquet.schema.MessageType -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, UnresolvedException} import org.apache.spark.sql.catalyst.expressions.{AttributeMap, Attribute} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -34,8 +34,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati /** * Relation that consists of data stored in a Parquet columnar format. * - * Users should interact with parquet files though a SchemaRDD, created by a [[SQLContext]] instead - * of using this class directly. + * Users should interact with parquet files though a [[DataFrame]], created by a [[SQLContext]] + * instead of using this class directly. * * {{{ * val parquetRDD = sqlContext.parquetFile("path/to/parquet.file") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala index 02ce1b3e6d811..9d6c529574da0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTest.scala @@ -23,7 +23,7 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import scala.util.Try -import org.apache.spark.sql.{SQLContext, SchemaRDD} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.util import org.apache.spark.util.Utils @@ -95,12 +95,12 @@ trait ParquetTest { } /** - * Writes `data` to a Parquet file and reads it back as a SchemaRDD, which is then passed to `f`. - * The Parquet file will be deleted after `f` returns. + * Writes `data` to a Parquet file and reads it back as a [[DataFrame]], + * which is then passed to `f`. The Parquet file will be deleted after `f` returns. */ protected def withParquetRDD[T <: Product: ClassTag: TypeTag] (data: Seq[T]) - (f: SchemaRDD => Unit): Unit = { + (f: DataFrame => Unit): Unit = { withParquetFile(data)(path => f(parquetFile(path))) } @@ -112,15 +112,15 @@ trait ParquetTest { } /** - * Writes `data` to a Parquet file, reads it back as a SchemaRDD and registers it as a temporary - * table named `tableName`, then call `f`. The temporary table together with the Parquet file will - * be dropped/deleted after `f` returns. + * Writes `data` to a Parquet file, reads it back as a [[DataFrame]] and registers it as a + * temporary table named `tableName`, then call `f`. The temporary table together with the + * Parquet file will be dropped/deleted after `f` returns. */ protected def withParquetTable[T <: Product: ClassTag: TypeTag] (data: Seq[T], tableName: String) (f: => Unit): Unit = { withParquetRDD(data) { rdd => - rdd.registerTempTable(tableName) + sqlContext.registerRDDAsTable(rdd, tableName) withTempTable(tableName)(f) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala index 1b50afbbabcb0..1e794cad73936 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/newParquet.scala @@ -20,26 +20,26 @@ import java.util.{List => JList} import scala.collection.JavaConversions._ -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.conf.{Configurable, Configuration} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce.{JobContext, InputSplit, Job} - +import org.apache.hadoop.mapreduce.{InputSplit, Job, JobContext} +import parquet.filter2.predicate.FilterApi import parquet.hadoop.ParquetInputFormat import parquet.hadoop.util.ContextUtil import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.{Partition => SparkPartition, Logging} import org.apache.spark.rdd.{NewHadoopPartition, RDD} -import org.apache.spark.sql.{SQLConf, Row, SQLContext} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} +import org.apache.spark.sql.{Row, SQLConf, SQLContext} +import org.apache.spark.{Logging, Partition => SparkPartition} /** * Allows creation of parquet based tables using the syntax - * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option + * `CREATE TEMPORARY TABLE ... USING org.apache.spark.sql.parquet`. Currently the only option * required is `path`, which should be the location of a collection of, optionally partitioned, * parquet files. */ @@ -193,10 +193,12 @@ case class ParquetRelation2(path: String)(@transient val sqlContext: SQLContext) org.apache.hadoop.mapreduce.lib.input.FileInputFormat.setInputPaths(job, selectedFiles: _*) } - // Push down filters when possible + // Push down filters when possible. Notice that not all filters can be converted to Parquet + // filter predicate. Here we try to convert each individual predicate and only collect those + // convertible ones. predicates - .reduceOption(And) .flatMap(ParquetFilters.createFilter) + .reduceOption(FilterApi.and) .filter(_ => sqlContext.conf.parquetFilterPushDown) .foreach(ParquetInputFormat.setFilterPredicate(jobConf, _)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala index 37853d4d03019..d13f2ce2a5e1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/DataSourceStrategy.scala @@ -18,19 +18,18 @@ package org.apache.spark.sql.sources import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row -import org.apache.spark.sql._ +import org.apache.spark.sql.{Row, Strategy} import org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeReference, AttributeSet, Expression, NamedExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution /** * A Strategy for planning scans over data sources defined using the sources API. */ private[sql] object DataSourceStrategy extends Strategy { - def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projectList, filters, l @ LogicalRelation(t: CatalystScan)) => pruneFilterProjectRaw( l, @@ -112,23 +111,26 @@ private[sql] object DataSourceStrategy extends Strategy { } } + /** Turn Catalyst [[Expression]]s into data source [[Filter]]s. */ protected[sql] def selectFilters(filters: Seq[Expression]): Seq[Filter] = filters.collect { - case expressions.EqualTo(a: Attribute, Literal(v, _)) => EqualTo(a.name, v) - case expressions.EqualTo(Literal(v, _), a: Attribute) => EqualTo(a.name, v) + case expressions.EqualTo(a: Attribute, expressions.Literal(v, _)) => EqualTo(a.name, v) + case expressions.EqualTo(expressions.Literal(v, _), a: Attribute) => EqualTo(a.name, v) - case expressions.GreaterThan(a: Attribute, Literal(v, _)) => GreaterThan(a.name, v) - case expressions.GreaterThan(Literal(v, _), a: Attribute) => LessThan(a.name, v) + case expressions.GreaterThan(a: Attribute, expressions.Literal(v, _)) => GreaterThan(a.name, v) + case expressions.GreaterThan(expressions.Literal(v, _), a: Attribute) => LessThan(a.name, v) - case expressions.LessThan(a: Attribute, Literal(v, _)) => LessThan(a.name, v) - case expressions.LessThan(Literal(v, _), a: Attribute) => GreaterThan(a.name, v) + case expressions.LessThan(a: Attribute, expressions.Literal(v, _)) => LessThan(a.name, v) + case expressions.LessThan(expressions.Literal(v, _), a: Attribute) => GreaterThan(a.name, v) - case expressions.GreaterThanOrEqual(a: Attribute, Literal(v, _)) => + case expressions.GreaterThanOrEqual(a: Attribute, expressions.Literal(v, _)) => GreaterThanOrEqual(a.name, v) - case expressions.GreaterThanOrEqual(Literal(v, _), a: Attribute) => + case expressions.GreaterThanOrEqual(expressions.Literal(v, _), a: Attribute) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(a: Attribute, Literal(v, _)) => LessThanOrEqual(a.name, v) - case expressions.LessThanOrEqual(Literal(v, _), a: Attribute) => GreaterThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(a: Attribute, expressions.Literal(v, _)) => + LessThanOrEqual(a.name, v) + case expressions.LessThanOrEqual(expressions.Literal(v, _), a: Attribute) => + GreaterThanOrEqual(a.name, v) case expressions.InSet(a: Attribute, set) => In(a.name, set.toArray) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala index 171b816a26332..b7c721f8c0691 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/ddl.scala @@ -20,14 +20,13 @@ package org.apache.spark.sql.sources import scala.language.implicitConversions import org.apache.spark.Logging -import org.apache.spark.sql.{SchemaRDD, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ import org.apache.spark.util.Utils - /** * A parser for foreign DDL commands. */ @@ -59,6 +58,7 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected val TABLE = Keyword("TABLE") protected val USING = Keyword("USING") protected val OPTIONS = Keyword("OPTIONS") + protected val COMMENT = Keyword("COMMENT") // Data types. protected val STRING = Keyword("STRING") @@ -111,8 +111,13 @@ private[sql] class DDLParser extends AbstractSparkSQLParser with Logging { protected lazy val pair: Parser[(String, String)] = ident ~ stringLit ^^ { case k ~ v => (k,v) } protected lazy val column: Parser[StructField] = - ident ~ dataType ^^ { case columnName ~ typ => - StructField(columnName, typ) + ident ~ dataType ~ (COMMENT ~> stringLit).? ^^ { case columnName ~ typ ~ cm => + val meta = cm match { + case Some(comment) => + new MetadataBuilder().putString(COMMENT.str.toLowerCase(), comment).build() + case None => Metadata.empty + } + StructField(columnName, typ, true, meta) } protected lazy val primitiveType: Parser[DataType] = @@ -225,7 +230,8 @@ private [sql] case class CreateTempTableUsing( def run(sqlContext: SQLContext) = { val resolved = ResolvedDataSource(sqlContext, userSpecifiedSchema, provider, options) - new SchemaRDD(sqlContext, LogicalRelation(resolved.relation)).registerTempTable(tableName) + sqlContext.registerRDDAsTable( + new DataFrame(sqlContext, LogicalRelation(resolved.relation)), tableName) Seq.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala index f9c082216085d..906455dd40c0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import scala.language.implicitConversions import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.{SchemaRDD, SQLConf, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** A SQLContext that can be used for local testing. */ @@ -37,11 +37,11 @@ object TestSQLContext } /** - * Turn a logical plan into a SchemaRDD. This should be removed once we have an easier way to - * construct SchemaRDD directly out of local data without relying on implicits. + * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to + * construct [[DataFrame]] directly out of local data without relying on implicits. */ - protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): SchemaRDD = { - new SchemaRDD(this, plan) + protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + new DataFrame(this, plan) } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java index 9ff40471a00af..e5588938ea162 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaAPISuite.java @@ -61,7 +61,7 @@ public Integer call(String str) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test')").head(); assert(result.getInt(0) == 4); } @@ -81,7 +81,7 @@ public Integer call(String str1, String str2) throws Exception { } }, DataTypes.IntegerType); - Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").first(); + Row result = sqlContext.sql("SELECT stringLengthTest('test', 'test2')").head(); assert(result.getInt(0) == 9); } } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java index 9e96738ac095a..badd00d34b9b1 100644 --- a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaApplySchemaSuite.java @@ -98,8 +98,8 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - SchemaRDD schemaRDD = javaSqlCtx.applySchema(rowRDD.rdd(), schema); - schemaRDD.registerTempTable("people"); + DataFrame df = javaSqlCtx.applySchema(rowRDD.rdd(), schema); + df.registerTempTable("people"); Row[] actual = javaSqlCtx.sql("SELECT * FROM people").collect(); List expected = new ArrayList(2); @@ -147,17 +147,17 @@ public void applySchemaToJSON() { null, "this is another simple string.")); - SchemaRDD schemaRDD1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); - StructType actualSchema1 = schemaRDD1.schema(); + DataFrame df1 = javaSqlCtx.jsonRDD(jsonRDD.rdd()); + StructType actualSchema1 = df1.schema(); Assert.assertEquals(expectedSchema, actualSchema1); - schemaRDD1.registerTempTable("jsonTable1"); + df1.registerTempTable("jsonTable1"); List actual1 = javaSqlCtx.sql("select * from jsonTable1").collectAsList(); Assert.assertEquals(expectedResult, actual1); - SchemaRDD schemaRDD2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); - StructType actualSchema2 = schemaRDD2.schema(); + DataFrame df2 = javaSqlCtx.jsonRDD(jsonRDD.rdd(), expectedSchema); + StructType actualSchema2 = df2.schema(); Assert.assertEquals(expectedSchema, actualSchema2); - schemaRDD2.registerTempTable("jsonTable2"); + df2.registerTempTable("jsonTable2"); List actual2 = javaSqlCtx.sql("select * from jsonTable2").collectAsList(); Assert.assertEquals(expectedResult, actual2); } diff --git a/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java new file mode 100644 index 0000000000000..639436368c4a3 --- /dev/null +++ b/sql/core/src/test/java/org/apache/spark/sql/api/java/JavaDsl.java @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +import com.google.common.collect.ImmutableMap; + +import org.apache.spark.sql.Column; +import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.types.DataTypes; + +import static org.apache.spark.sql.Dsl.*; + +/** + * This test doesn't actually run anything. It is here to check the API compatibility for Java. + */ +public class JavaDsl { + + public static void testDataFrame(final DataFrame df) { + DataFrame df1 = df.select("colA"); + df1 = df.select("colA", "colB"); + + df1 = df.select(col("colA"), col("colB"), lit("literal value").$plus(1)); + + df1 = df.filter(col("colA")); + + java.util.Map aggExprs = ImmutableMap.builder() + .put("colA", "sum") + .put("colB", "avg") + .build(); + + df1 = df.agg(aggExprs); + + df1 = df.groupBy("groupCol").agg(aggExprs); + + df1 = df.join(df1, col("key1").$eq$eq$eq(col("key2")), "outer"); + + df.orderBy("colA"); + df.orderBy("colA", "colB", "colC"); + df.orderBy(col("colA").desc()); + df.orderBy(col("colA").desc(), col("colB").asc()); + + df.sort("colA"); + df.sort("colA", "colB", "colC"); + df.sort(col("colA").desc()); + df.sort(col("colA").desc(), col("colB").asc()); + + df.as("b"); + + df.limit(5); + + df.unionAll(df1); + df.intersect(df1); + df.except(df1); + + df.sample(true, 0.1, 234); + + df.head(); + df.head(5); + df.first(); + df.count(); + } + + public static void testColumn(final Column c) { + c.asc(); + c.desc(); + + c.endsWith("abcd"); + c.startsWith("afgasdf"); + + c.like("asdf%"); + c.rlike("wef%asdf"); + + c.as("newcol"); + + c.cast("int"); + c.cast(DataTypes.IntegerType); + } + + public static void testDsl() { + // Creating a column. + Column c = col("abcd"); + Column c1 = column("abcd"); + + // Literals + Column l1 = lit(1); + Column l2 = lit(1.0); + Column l3 = lit("abcd"); + + // Functions + Column a = upper(c); + a = lower(c); + a = sqrt(c); + a = abs(c); + + // Aggregates + a = min(c); + a = max(c); + a = sum(c); + a = sumDistinct(c); + a = countDistinct(c, a); + a = avg(c); + a = first(c); + a = last(c); + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index cfc037caff2a9..c9221f8f934ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.storage.{StorageLevel, RDDBlockId} @@ -50,17 +51,17 @@ class CachedTableSuite extends QueryTest { } test("unpersist an uncached table will not raise exception") { - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(true) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != lookupCachedData(testData)) + assert(None != cacheManager.lookupCachedData(testData)) testData.unpersist(true) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) testData.unpersist(false) - assert(None == lookupCachedData(testData)) + assert(None == cacheManager.lookupCachedData(testData)) } test("cache table as select") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala new file mode 100644 index 0000000000000..2d464c2b53d79 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -0,0 +1,306 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.Dsl._ +import org.apache.spark.sql.test.TestSQLContext +import org.apache.spark.sql.types.{BooleanType, IntegerType, StructField, StructType} + + +class ColumnExpressionSuite extends QueryTest { + import org.apache.spark.sql.TestData._ + + // TODO: Add test cases for bitwise operations. + + test("star") { + checkAnswer(testData.select($"*"), testData.collect().toSeq) + } + + test("star qualified by data frame object") { + // This is not yet supported. + val df = testData.toDataFrame + val goldAnswer = df.collect().toSeq + checkAnswer(df.select(df("*")), goldAnswer) + + val df1 = df.select(df("*"), lit("abcd").as("litCol")) + checkAnswer(df1.select(df("*")), goldAnswer) + } + + test("star qualified by table name") { + checkAnswer(testData.as("testData").select($"testData.*"), testData.collect().toSeq) + } + + test("+") { + checkAnswer( + testData2.select($"a" + 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + 1))) + + checkAnswer( + testData2.select($"a" + $"b" + 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) + r.getInt(1) + 2))) + } + + test("-") { + checkAnswer( + testData2.select($"a" - 1), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - 1))) + + checkAnswer( + testData2.select($"a" - $"b" - 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) - r.getInt(1) - 2))) + } + + test("*") { + checkAnswer( + testData2.select($"a" * 10), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * 10))) + + checkAnswer( + testData2.select($"a" * $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) * r.getInt(1)))) + } + + test("/") { + checkAnswer( + testData2.select($"a" / 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / 2))) + + checkAnswer( + testData2.select($"a" / $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0).toDouble / r.getInt(1)))) + } + + + test("%") { + checkAnswer( + testData2.select($"a" % 2), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % 2))) + + checkAnswer( + testData2.select($"a" % $"b"), + testData2.collect().toSeq.map(r => Row(r.getInt(0) % r.getInt(1)))) + } + + test("unary -") { + checkAnswer( + testData2.select(-$"a"), + testData2.collect().toSeq.map(r => Row(-r.getInt(0)))) + } + + test("unary !") { + checkAnswer( + complexData.select(!$"b"), + complexData.collect().toSeq.map(r => Row(!r.getBoolean(3)))) + } + + test("isNull") { + checkAnswer( + nullStrings.toDataFrame.where($"s".isNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + } + + test("isNotNull") { + checkAnswer( + nullStrings.toDataFrame.where($"s".isNotNull), + nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + } + + test("===") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("<=>") { + checkAnswer( + testData2.filter($"a" === 1), + testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) + + checkAnswer( + testData2.filter($"a" === $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) + } + + test("!==") { + val nullData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(1, 1) :: + Row(1, 2) :: + Row(1, null) :: + Row(null, null) :: Nil), + StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) + + checkAnswer( + nullData.filter($"b" <=> 1), + Row(1, 1) :: Nil) + + checkAnswer( + nullData.filter($"b" <=> null), + Row(1, null) :: Row(null, null) :: Nil) + + checkAnswer( + nullData.filter($"a" <=> $"b"), + Row(1, 1) :: Row(null, null) :: Nil) + } + + test(">") { + checkAnswer( + testData2.filter($"a" > 1), + testData2.collect().toSeq.filter(r => r.getInt(0) > 1)) + + checkAnswer( + testData2.filter($"a" > $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) > r.getInt(1))) + } + + test(">=") { + checkAnswer( + testData2.filter($"a" >= 1), + testData2.collect().toSeq.filter(r => r.getInt(0) >= 1)) + + checkAnswer( + testData2.filter($"a" >= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) >= r.getInt(1))) + } + + test("<") { + checkAnswer( + testData2.filter($"a" < 2), + testData2.collect().toSeq.filter(r => r.getInt(0) < 2)) + + checkAnswer( + testData2.filter($"a" < $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) < r.getInt(1))) + } + + test("<=") { + checkAnswer( + testData2.filter($"a" <= 2), + testData2.collect().toSeq.filter(r => r.getInt(0) <= 2)) + + checkAnswer( + testData2.filter($"a" <= $"b"), + testData2.collect().toSeq.filter(r => r.getInt(0) <= r.getInt(1))) + } + + val booleanData = TestSQLContext.applySchema(TestSQLContext.sparkContext.parallelize( + Row(false, false) :: + Row(false, true) :: + Row(true, false) :: + Row(true, true) :: Nil), + StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) + + test("&&") { + checkAnswer( + booleanData.filter($"a" && true), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" && false), + Nil) + + checkAnswer( + booleanData.filter($"a" && $"b"), + Row(true, true) :: Nil) + } + + test("||") { + checkAnswer( + booleanData.filter($"a" || true), + booleanData.collect()) + + checkAnswer( + booleanData.filter($"a" || false), + Row(true, false) :: Row(true, true) :: Nil) + + checkAnswer( + booleanData.filter($"a" || $"b"), + Row(false, true) :: Row(true, false) :: Row(true, true) :: Nil) + } + + test("sqrt") { + checkAnswer( + testData.select(sqrt('key)).orderBy('key.asc), + (1 to 100).map(n => Row(math.sqrt(n))) + ) + + checkAnswer( + testData.select(sqrt('value), 'key).orderBy('key.asc, 'value.asc), + (1 to 100).map(n => Row(math.sqrt(n), n)) + ) + + checkAnswer( + testData.select(sqrt(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("abs") { + checkAnswer( + testData.select(abs('key)).orderBy('key.asc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + negativeData.select(abs('key)).orderBy('key.desc), + (1 to 100).map(n => Row(n)) + ) + + checkAnswer( + testData.select(abs(lit(null))), + (1 to 100).map(_ => Row(null)) + ) + } + + test("upper") { + checkAnswer( + lowerCaseData.select(upper('l)), + ('a' to 'd').map(c => Row(c.toString.toUpperCase)) + ) + + checkAnswer( + testData.select(upper('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(upper(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } + + test("lower") { + checkAnswer( + upperCaseData.select(lower('L)), + ('A' to 'F').map(c => Row(c.toString.toLowerCase)) + ) + + checkAnswer( + testData.select(lower('value), 'key), + (1 to 100).map(n => Row(n.toString, n)) + ) + + checkAnswer( + testData.select(lower(lit(null))), + (1 to 100).map(n => Row(null)) + ) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala similarity index 52% rename from sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index afbfe214f1ce4..df343adc793bd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.types._ /* Implicits */ -import org.apache.spark.sql.catalyst.dsl._ import org.apache.spark.sql.test.TestSQLContext._ import scala.language.postfixOps -class DslQuerySuite extends QueryTest { +class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ test("table scan") { @@ -44,46 +42,46 @@ class DslQuerySuite extends QueryTest { test("agg") { checkAnswer( - testData2.groupBy('a)('a, sum('b)), + testData2.groupBy("a").agg($"a", sum($"b")), Seq(Row(1,3), Row(2,3), Row(3,3)) ) checkAnswer( - testData2.groupBy('a)('a, sum('b) as 'totB).aggregate(sum('totB)), + testData2.groupBy("a").agg($"a", sum($"b").as("totB")).agg(sum('totB)), Row(9) ) checkAnswer( - testData2.aggregate(sum('b)), + testData2.agg(sum('b)), Row(9) ) } test("convert $\"attribute name\" into unresolved attribute") { checkAnswer( - testData.where($"key" === 1).select($"value"), + testData.where($"key" === lit(1)).select($"value"), Row("1")) } test("convert Scala Symbol 'attrname into unresolved attribute") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === lit(1)).select('value), Row("1")) } test("select *") { checkAnswer( - testData.select(Star(None)), + testData.select($"*"), testData.collect().toSeq) } test("simple select") { checkAnswer( - testData.where('key === 1).select('value), + testData.where('key === lit(1)).select('value), Row("1")) } test("select with functions") { checkAnswer( - testData.select(sum('value), avg('value), count(1)), + testData.select(sum('value), avg('value), count(lit(1))), Row(5050.0, 50.5, 100)) checkAnswer( @@ -120,46 +118,19 @@ class DslQuerySuite extends QueryTest { checkAnswer( arrayData.orderBy('data.getItem(0).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(0).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(0)).reverse.toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).asc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).toSeq) checkAnswer( arrayData.orderBy('data.getItem(1).desc), - arrayData.toSchemaRDD.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) - } - - test("partition wide sorting") { - // 2 partitions totally, and - // Partition #1 with values: - // (1, 1) - // (1, 2) - // (2, 1) - // Partition #2 with values: - // (2, 2) - // (3, 1) - // (3, 2) - checkAnswer( - testData2.sortBy('a.asc, 'b.asc), - Seq(Row(1,1), Row(1,2), Row(2,1), Row(2,2), Row(3,1), Row(3,2))) - - checkAnswer( - testData2.sortBy('a.asc, 'b.desc), - Seq(Row(1,2), Row(1,1), Row(2,1), Row(2,2), Row(3,2), Row(3,1))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.desc), - Seq(Row(2,1), Row(1,2), Row(1,1), Row(3,2), Row(3,1), Row(2,2))) - - checkAnswer( - testData2.sortBy('a.desc, 'b.asc), - Seq(Row(2,1), Row(1,1), Row(1,2), Row(3,1), Row(3,2), Row(2,2))) + arrayData.toDataFrame.collect().sortBy(_.getAs[Seq[Int]](0)(1)).reverse.toSeq) } test("limit") { @@ -176,71 +147,51 @@ class DslQuerySuite extends QueryTest { mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) } - test("SPARK-3395 limit distinct") { - val filtered = TestData.testData2 - .distinct() - .orderBy(SortOrder('a, Ascending), SortOrder('b, Ascending)) - .limit(1) - .registerTempTable("onerow") - checkAnswer( - sql("select * from onerow inner join testData2 on onerow.a = testData2.a"), - Row(1, 1, 1, 1) :: - Row(1, 1, 1, 2) :: Nil) - } - - test("SPARK-3858 generator qualifiers are discarded") { - checkAnswer( - arrayData.as('ad) - .generate(Explode("data" :: Nil, 'data), alias = Some("ex")) - .select("ex.data".attr), - Seq(1, 2, 3, 2, 3, 4).map(Row(_))) - } - test("average") { checkAnswer( - testData2.aggregate(avg('a)), + testData2.agg(avg('a)), Row(2.0)) checkAnswer( - testData2.aggregate(avg('a), sumDistinct('a)), // non-partial + testData2.agg(avg('a), sumDistinct('a)), // non-partial Row(2.0, 6.0) :: Nil) checkAnswer( - decimalData.aggregate(avg('a)), + decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial + decimalData.agg(avg('a), sumDistinct('a)), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2))), + decimalData.agg(avg('a cast DecimalType(10, 2))), Row(new java.math.BigDecimal(2.0))) checkAnswer( - decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + decimalData.agg(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( - testData3.aggregate(avg('b)), + testData3.agg(avg('b)), Row(2.0)) checkAnswer( - testData3.aggregate(avg('b), countDistinct('b)), + testData3.agg(avg('b), countDistinct('b)), Row(2.0, 1)) checkAnswer( - testData3.aggregate(avg('b), sumDistinct('b)), // non-partial + testData3.agg(avg('b), sumDistinct('b)), // non-partial Row(2.0, 2.0)) } test("zero average") { checkAnswer( - emptyTableData.aggregate(avg('a)), + emptyTableData.agg(avg('a)), Row(null)) checkAnswer( - emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial + emptyTableData.agg(avg('a), sumDistinct('b)), // non-partial Row(null, null)) } @@ -248,28 +199,28 @@ class DslQuerySuite extends QueryTest { assert(testData2.count() === testData2.map(_ => 1).count()) checkAnswer( - testData2.aggregate(count('a), sumDistinct('a)), // non-partial + testData2.agg(count('a), sumDistinct('a)), // non-partial Row(6, 6.0)) } test("null count") { checkAnswer( - testData3.groupBy('a)('a, count('b)), + testData3.groupBy('a).agg('a, count('b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.groupBy('a)('a, count('a + 'b)), + testData3.groupBy('a).agg('a, count('a + 'b)), Seq(Row(1,0), Row(2, 1)) ) checkAnswer( - testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), + testData3.agg(count('a), count('b), count(lit(1)), countDistinct('a), countDistinct('b)), Row(2, 1, 2, 2, 1) ) checkAnswer( - testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial + testData3.agg(count('b), countDistinct('b), sumDistinct('b)), // non-partial Row(1, 1, 2) ) } @@ -278,19 +229,19 @@ class DslQuerySuite extends QueryTest { assert(emptyTableData.count() === 0) checkAnswer( - emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial + emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } test("zero sum") { checkAnswer( - emptyTableData.aggregate(sum('a)), + emptyTableData.agg(sum('a)), Row(null)) } test("zero sum distinct") { checkAnswer( - emptyTableData.aggregate(sumDistinct('a)), + emptyTableData.agg(sumDistinct('a)), Row(null)) } @@ -320,76 +271,14 @@ class DslQuerySuite extends QueryTest { checkAnswer( // SELECT *, foo(key, value) FROM testData - testData.select(Star(None), foo.call('key, 'value)).limit(3), + testData.select($"*", callUDF(foo, 'key, 'value)).limit(3), Row(1, "1", "11") :: Row(2, "2", "22") :: Row(3, "3", "33") :: Nil ) } - test("sqrt") { - checkAnswer( - testData.select(sqrt('key)).orderBy('key asc), - (1 to 100).map(n => Row(math.sqrt(n))) - ) - - checkAnswer( - testData.select(sqrt('value), 'key).orderBy('key asc, 'value asc), - (1 to 100).map(n => Row(math.sqrt(n), n)) - ) - - checkAnswer( - testData.select(sqrt(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("abs") { - checkAnswer( - testData.select(abs('key)).orderBy('key asc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - negativeData.select(abs('key)).orderBy('key desc), - (1 to 100).map(n => Row(n)) - ) - - checkAnswer( - testData.select(abs(Literal(null))), - (1 to 100).map(_ => Row(null)) - ) - } - - test("upper") { - checkAnswer( - lowerCaseData.select(upper('l)), - ('a' to 'd').map(c => Row(c.toString.toUpperCase())) - ) - - checkAnswer( - testData.select(upper('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(upper(Literal(null))), - (1 to 100).map(n => Row(null)) - ) + test("apply on query results (SPARK-5462)") { + val df = testData.sqlContext.sql("select key from testData") + checkAnswer(df("key"), testData.select('key).collect().toSeq) } - test("lower") { - checkAnswer( - upperCaseData.select(lower('L)), - ('A' to 'F').map(c => Row(c.toString.toLowerCase())) - ) - - checkAnswer( - testData.select(lower('value), 'key), - (1 to 100).map(n => Row(n.toString, n)) - ) - - checkAnswer( - testData.select(lower(Literal(null))), - (1 to 100).map(n => Row(null)) - ) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index cd36da7751e83..f0c939dbb195f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,19 +20,20 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.test.TestSQLContext._ + class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. TestData test("equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } @@ -58,7 +59,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - clearCache() + cacheManager.clearCache() Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), @@ -92,7 +93,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted hash join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") Seq( @@ -105,17 +106,16 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("multiple-key equi-join is hash-join") { - val x = testData2.as('x) - val y = testData2.as('y) - val join = x.join(y, Inner, - Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed + val x = testData2.as("x") + val y = testData2.as("y") + val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.analyzed val planned = planner.HashJoin(join) assert(planned.size === 1) } test("inner join where, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner).where('n === 'N), + upperCaseData.join(lowerCaseData).where('n === 'N), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -126,7 +126,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("inner join ON, one match per row") { checkAnswer( - upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N"), Seq( Row(1, "A", 1, "a"), Row(2, "B", 2, "b"), @@ -136,10 +136,10 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, where, multiple matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 1).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 1).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Row(1,1,1,1) :: Row(1,1,1,2) :: Row(1,2,1,1) :: @@ -148,22 +148,21 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("inner join, no matches") { - val x = testData2.where('a === 1).as('x) - val y = testData2.where('a === 2).as('y) + val x = testData2.where($"a" === 1).as("x") + val y = testData2.where($"a" === 2).as("y") checkAnswer( - x.join(y).where("x.a".attr === "y.a".attr), + x.join(y).where($"x.a" === $"y.a"), Nil) } test("big inner join, 4 matches per row") { val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData) - val bigDataX = bigData.as('x) - val bigDataY = bigData.as('y) + val bigDataX = bigData.as("x") + val bigDataY = bigData.as("y") checkAnswer( - bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr), - testData.flatMap( - row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) + bigDataX.join(bigDataY).where($"x.key" === $"y.key"), + testData.rdd.flatMap(row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } test("cartisian product join") { @@ -177,7 +176,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("left outer join") { checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)), + upperCaseData.join(lowerCaseData, $"n" === $"N", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -186,7 +185,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -195,7 +194,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left"), Row(1, "A", null, null) :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -204,7 +203,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, "F", null, null) :: Nil) checkAnswer( - upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)), + upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left"), Row(1, "A", 1, "a") :: Row(2, "B", 2, "b") :: Row(3, "C", 3, "c") :: @@ -240,7 +239,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { test("right outer join") { checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)), + lowerCaseData.join(upperCaseData, $"n" === $"N", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -248,7 +247,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'n > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -256,7 +255,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'N > 1)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right"), Row(null, null, 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -264,7 +263,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 5, "E") :: Row(null, null, 6, "F") :: Nil) checkAnswer( - lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N && 'l > 'L)), + lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right"), Row(1, "a", 1, "A") :: Row(2, "b", 2, "B") :: Row(3, "c", 3, "C") :: @@ -306,7 +305,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { val right = UnresolvedRelation(Seq("right"), None) checkAnswer( - left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)), + left.join(right, $"left.N" === $"right.N", "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", 3, "C") :: @@ -315,7 +314,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"left.N" !== 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -325,7 +324,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, null, 6, "F") :: Nil) checkAnswer( - left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))), + left.join(right, ($"left.N" === $"right.N") && ($"right.N" !== 3), "full"), Row(1, "A", null, null) :: Row(2, "B", null, null) :: Row(3, "C", null, null) :: @@ -385,7 +384,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - clearCache() + cacheManager.clearCache() sql("CACHE TABLE testData") val tmp = conf.autoBroadcastJoinThreshold diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 42a21c148df53..a7f2faa3ecf75 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -26,12 +26,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -44,10 +44,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -91,7 +91,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } @@ -101,8 +101,10 @@ class QueryTest extends PlanTest { } } - /** Asserts that a given SchemaRDD will be executed using the given number of cached results. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + /** + * Asserts that a given [[DataFrame]] will be executed using the given number of cached results. + */ + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 03b44ca1d6695..d82c34316cefa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -21,6 +21,7 @@ import java.util.TimeZone import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ + class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. TestData @@ -86,6 +88,18 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } + test("Add Parser of SQL COALESCE()") { + checkAnswer( + sql("""SELECT COALESCE(1, 2)"""), + Row(1)) + checkAnswer( + sql("SELECT COALESCE(null, 1, 1.5)"), + Row(1.toDouble)) + checkAnswer( + sql("SELECT COALESCE(null, null, null)"), + Row(null)) + } + test("SPARK-3176 Added Parser of SQL LAST()") { checkAnswer( sql("SELECT LAST(n) FROM lowerCaseData"), @@ -184,6 +198,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Seq(Row(1,3), Row(2,3), Row(3,3))) } + test("literal in agg grouping expressions") { + checkAnswer( + sql("SELECT a, count(1) FROM testData2 GROUP BY a, 1"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + checkAnswer( + sql("SELECT a, count(2) FROM testData2 GROUP BY a, 2"), + Seq(Row(1,2), Row(2,2), Row(3,2))) + } + test("aggregates with nulls") { checkAnswer( sql("SELECT MIN(a), MAX(a), AVG(a), SUM(a), COUNT(a) FROM nullInts"), @@ -381,8 +404,6 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("big inner join, 4 matches per row") { - - checkAnswer( sql( """ @@ -396,7 +417,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { | SELECT * FROM testData UNION ALL | SELECT * FROM testData) y |WHERE x.key = y.key""".stripMargin), - testData.flatMap( + testData.rdd.flatMap( row => Seq.fill(16)(Row.merge(row, row))).collect().toSeq) } @@ -651,8 +672,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") + val df1 = applySchema(rowRDD1, schema1) + df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), Row(1, "A1", true, null) :: @@ -681,8 +702,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD2 = applySchema(rowRDD2, schema2) - schemaRDD2.registerTempTable("applySchema2") + val df2 = applySchema(rowRDD2, schema2) + df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), Row(Row(1, true), Map("A1" -> null)) :: @@ -706,8 +727,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val schemaRDD3 = applySchema(rowRDD3, schema2) - schemaRDD3.registerTempTable("applySchema3") + val df3 = applySchema(rowRDD3, schema2) + df3.registerTempTable("applySchema3") checkAnswer( sql("SELECT f1.f11, f2['D4'] FROM applySchema3"), @@ -742,7 +763,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { } test("metadata is propagated correctly") { - val person = sql("SELECT * FROM person") + val person: DataFrame = sql("SELECT * FROM person") val schema = person.schema val docKey = "doc" val docValue = "first name" @@ -751,14 +772,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = applySchema(person, schemaWithMeta) - def validateMetadata(rdd: SchemaRDD): Unit = { + val personWithMeta = applySchema(person.rdd, schemaWithMeta) + def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } personWithMeta.registerTempTable("personWithMeta") - validateMetadata(personWithMeta.select('name)) - validateMetadata(personWithMeta.select("name".attr)) - validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"name")) + validateMetadata(personWithMeta.select($"id", $"name")) validateMetadata(sql("SELECT * FROM personWithMeta")) validateMetadata(sql("SELECT id, name FROM personWithMeta")) validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 808ed5288cfb8..dd781169ca57f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.sql.Timestamp import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test._ /* Implicits */ @@ -29,11 +30,11 @@ case class TestData(key: Int, value: String) object TestData { val testData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(i, i.toString))).toSchemaRDD + (1 to 100).map(i => TestData(i, i.toString))).toDataFrame testData.registerTempTable("testData") val negativeData = TestSQLContext.sparkContext.parallelize( - (1 to 100).map(i => TestData(-i, (-i).toString))).toSchemaRDD + (1 to 100).map(i => TestData(-i, (-i).toString))).toDataFrame negativeData.registerTempTable("negativeData") case class LargeAndSmallInts(a: Int, b: Int) @@ -44,7 +45,7 @@ object TestData { LargeAndSmallInts(2147483645, 1) :: LargeAndSmallInts(2, 2) :: LargeAndSmallInts(2147483646, 1) :: - LargeAndSmallInts(3, 2) :: Nil).toSchemaRDD + LargeAndSmallInts(3, 2) :: Nil).toDataFrame largeAndSmallInts.registerTempTable("largeAndSmallInts") case class TestData2(a: Int, b: Int) @@ -55,7 +56,7 @@ object TestData { TestData2(2, 1) :: TestData2(2, 2) :: TestData2(3, 1) :: - TestData2(3, 2) :: Nil, 2).toSchemaRDD + TestData2(3, 2) :: Nil, 2).toDataFrame testData2.registerTempTable("testData2") case class DecimalData(a: BigDecimal, b: BigDecimal) @@ -67,7 +68,7 @@ object TestData { DecimalData(2, 1) :: DecimalData(2, 2) :: DecimalData(3, 1) :: - DecimalData(3, 2) :: Nil).toSchemaRDD + DecimalData(3, 2) :: Nil).toDataFrame decimalData.registerTempTable("decimalData") case class BinaryData(a: Array[Byte], b: Int) @@ -77,17 +78,17 @@ object TestData { BinaryData("22".getBytes(), 5) :: BinaryData("122".getBytes(), 3) :: BinaryData("121".getBytes(), 2) :: - BinaryData("123".getBytes(), 4) :: Nil).toSchemaRDD + BinaryData("123".getBytes(), 4) :: Nil).toDataFrame binaryData.registerTempTable("binaryData") case class TestData3(a: Int, b: Option[Int]) val testData3 = TestSQLContext.sparkContext.parallelize( TestData3(1, None) :: - TestData3(2, Some(2)) :: Nil).toSchemaRDD + TestData3(2, Some(2)) :: Nil).toDataFrame testData3.registerTempTable("testData3") - val emptyTableData = logical.LocalRelation('a.int, 'b.int) + val emptyTableData = logical.LocalRelation($"a".int, $"b".int) case class UpperCaseData(N: Int, L: String) val upperCaseData = @@ -97,7 +98,7 @@ object TestData { UpperCaseData(3, "C") :: UpperCaseData(4, "D") :: UpperCaseData(5, "E") :: - UpperCaseData(6, "F") :: Nil).toSchemaRDD + UpperCaseData(6, "F") :: Nil).toDataFrame upperCaseData.registerTempTable("upperCaseData") case class LowerCaseData(n: Int, l: String) @@ -106,7 +107,7 @@ object TestData { LowerCaseData(1, "a") :: LowerCaseData(2, "b") :: LowerCaseData(3, "c") :: - LowerCaseData(4, "d") :: Nil).toSchemaRDD + LowerCaseData(4, "d") :: Nil).toDataFrame lowerCaseData.registerTempTable("lowerCaseData") case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]]) @@ -160,7 +161,7 @@ object TestData { TestSQLContext.sparkContext.parallelize( NullStrings(1, "abc") :: NullStrings(2, "ABC") :: - NullStrings(3, null) :: Nil) + NullStrings(3, null) :: Nil).toDataFrame nullStrings.registerTempTable("nullStrings") case class TableName(tableName: String) @@ -200,6 +201,6 @@ object TestData { TestSQLContext.sparkContext.parallelize( ComplexData(Map(1 -> "1"), TestData(1, "1"), Seq(1), true) :: ComplexData(Map(2 -> "2"), TestData(2, "2"), Seq(2), false) - :: Nil).toSchemaRDD + :: Nil).toDataFrame complexData.registerTempTable("complexData") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 0c98120031242..95923f9aad931 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import org.apache.spark.sql.Dsl.StringToColumn import org.apache.spark.sql.test._ /* Implicits */ @@ -28,25 +29,25 @@ class UDFSuite extends QueryTest { test("Simple UDF") { udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) + assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { udf.register("strLenScala", (_: String).length + (_:Int)) - assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) + assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) - val result= + val result = sql("SELECT returnStruct('test', 'test2') as ret") - .select("ret.f1".attr).first().getString(0) - assert(result == "test") + .select($"ret.f1").head().getString(0) + assert(result === "test") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index fbc8704f7837b..0696a2335e63f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types._ + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -66,14 +68,14 @@ class UserDefinedTypeSuite extends QueryTest { test("register user type: MyDenseVector for MyLabeledPoint") { - val labels: RDD[Double] = pointsRDD.select('label).map { case Row(v: Double) => v } + val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } val labelsArrays: Array[Double] = labels.collect() assert(labelsArrays.size === 2) assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) val features: RDD[MyDenseVector] = - pointsRDD.select('features).map { case Row(v: MyDenseVector) => v } + pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } val featuresArrays: Array[MyDenseVector] = features.collect() assert(featuresArrays.size === 2) assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala index e61f3c39631da..3d33484ab0eb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.columnar +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala index c3a3f8ddc3ebf..fe9a69edbb920 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala @@ -104,14 +104,14 @@ class PartitionBatchPruningSuite extends FunSuite with BeforeAndAfterAll with Be expectedQueryResult: => Seq[Int]): Unit = { test(query) { - val schemaRdd = sql(query) - val queryExecution = schemaRdd.queryExecution + val df = sql(query) + val queryExecution = df.queryExecution assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") { - schemaRdd.collect().map(_(0)).toArray + df.collect().map(_(0)).toArray } - val (readPartitions, readBatches) = schemaRdd.queryExecution.executedPlan.collect { + val (readPartitions, readBatches) = df.queryExecution.executedPlan.collect { case in: InMemoryColumnarTableScan => (in.readPartitions.value, in.readBatches.value) }.head diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 67007b8c093ca..df108a9d262bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.scalatest.FunSuite import org.apache.spark.sql.{SQLConf, execution} +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ @@ -28,6 +29,7 @@ import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ + class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan @@ -40,7 +42,7 @@ class PlannerSuite extends FunSuite { } test("count is partially aggregated") { - val query = testData.groupBy('value)(Count('key)).queryExecution.analyzed + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed val planned = HashAggregation(query).head val aggregations = planned.collect { case n if n.nodeName contains "Aggregate" => n } @@ -48,14 +50,14 @@ class PlannerSuite extends FunSuite { } test("count distinct is partially aggregated") { - val query = testData.groupBy('value)(CountDistinct('key :: Nil)).queryExecution.analyzed + val query = testData.groupBy('value).agg(countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } test("mixed aggregates are partially aggregated") { val query = - testData.groupBy('value)(Count('value), CountDistinct('key :: Nil)).queryExecution.analyzed + testData.groupBy('value).agg(count('value), countDistinct('key)).queryExecution.analyzed val planned = HashAggregation(query) assert(planned.nonEmpty) } @@ -128,9 +130,9 @@ class PlannerSuite extends FunSuite { testData.limit(3).registerTempTable("tiny") sql("CACHE TABLE tiny") - val a = testData.as('a) - val b = table("tiny").as('b) - val planned = a.join(b, Inner, Some("a.key".attr === "b.key".attr)).queryExecution.executedPlan + val a = testData.as("a") + val b = table("tiny").as("b") + val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join } val shuffledHashJoins = planned.collect { case join: ShuffledHashJoin => join } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala deleted file mode 100644 index 272c0d4cb2335..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TgfSuite.scala +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ - -/* Implicit conversions */ -import org.apache.spark.sql.test.TestSQLContext._ - -/** - * This is an example TGF that uses UnresolvedAttributes 'name and 'age to access specific columns - * from the input data. These will be replaced during analysis with specific AttributeReferences - * and then bound to specific ordinals during query planning. While TGFs could also access specific - * columns using hand-coded ordinals, doing so violates data independence. - * - * Note: this is only a rough example of how TGFs can be expressed, the final version will likely - * involve a lot more sugar for cleaner use in Scala/Java/etc. - */ -case class ExampleTGF(input: Seq[Expression] = Seq('name, 'age)) extends Generator { - def children = input - protected def makeOutput() = 'nameAndAge.string :: Nil - - val Seq(nameAttr, ageAttr) = input - - override def eval(input: Row): TraversableOnce[Row] = { - val name = nameAttr.eval(input) - val age = ageAttr.eval(input).asInstanceOf[Int] - - Iterator( - new GenericRow(Array[Any](s"$name is $age years old")), - new GenericRow(Array[Any](s"Next year, $name will be ${age + 1} years old"))) - } -} - -class TgfSuite extends QueryTest { - val inputData = - logical.LocalRelation('name.string, 'age.int).loadData( - ("michael", 29) :: Nil - ) - - test("simple tgf example") { - checkAnswer( - inputData.generate(ExampleTGF()), - Seq( - Row("michael is 29 years old"), - Row("Next year, michael will be 30 years old"))) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 87c28c334d228..4e9472c60249e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -23,11 +23,11 @@ import org.apache.spark.sql.TestData._ import org.apache.spark.sql.test.TestSQLContext._ class DebuggingSuite extends FunSuite { - test("SchemaRDD.debug()") { + test("DataFrame.debug()") { testData.debug() } - test("SchemaRDD.typeCheck()") { + test("DataFrame.typeCheck()") { testData.typeCheck() } } \ No newline at end of file diff --git a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala index 94d14acccbb18..cb615388da0c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/json/JsonSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.json.JsonRDD.{compatibleType, enforceCorrectType} import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ @@ -29,6 +30,7 @@ import org.apache.spark.sql.{QueryTest, Row, SQLConf} class JsonSuite extends QueryTest { import org.apache.spark.sql.json.TestJsonData._ + TestJsonData test("Type promotion") { @@ -193,7 +195,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring with null in sampling") { - val jsonSchemaRDD = jsonRDD(jsonNullStruct) + val jsonDF = jsonRDD(jsonNullStruct) val expectedSchema = StructType( StructField("headers", StructType( StructField("Charset", StringType, true) :: @@ -202,8 +204,8 @@ class JsonSuite extends QueryTest { StructField("ip", StringType, true) :: StructField("nullstr", StringType, true):: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + assert(expectedSchema === jsonDF.schema) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select nullstr, headers.Host from jsonTable"), @@ -212,7 +214,7 @@ class JsonSuite extends QueryTest { } test("Primitive field and type inferring") { - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) + val jsonDF = jsonRDD(primitiveFieldAndType) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -223,9 +225,9 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -240,7 +242,7 @@ class JsonSuite extends QueryTest { } test("Complex field and type inferring") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) + val jsonDF = jsonRDD(complexFieldAndType1) val expectedSchema = StructType( StructField("arrayOfArray1", ArrayType(ArrayType(StringType, false), false), true) :: @@ -264,9 +266,9 @@ class JsonSuite extends QueryTest { StructField("field1", ArrayType(IntegerType, false), true) :: StructField("field2", ArrayType(StringType, false), true) :: Nil), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") // Access elements of a primitive array. checkAnswer( @@ -339,8 +341,8 @@ class JsonSuite extends QueryTest { } ignore("Complex field and type inferring (Ignored)") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType1) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType1) + jsonDF.registerTempTable("jsonTable") // Right now, "field1" and "field2" are treated as aliases. We should fix it. checkAnswer( @@ -357,7 +359,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in primitive field values") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) val expectedSchema = StructType( StructField("num_bool", StringType, true) :: @@ -367,9 +369,9 @@ class JsonSuite extends QueryTest { StructField("num_str", StringType, true) :: StructField("str_bool", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -428,8 +430,8 @@ class JsonSuite extends QueryTest { } ignore("Type conflict in primitive field values (Ignored)") { - val jsonSchemaRDD = jsonRDD(primitiveFieldValueTypeConflict) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(primitiveFieldValueTypeConflict) + jsonDF.registerTempTable("jsonTable") // Right now, the analyzer does not promote strings in a boolean expreesion. // Number and Boolean conflict: resolve the type as boolean in this query. @@ -462,9 +464,9 @@ class JsonSuite extends QueryTest { // We should directly cast num_str to DecimalType and also need to do the right type promotion // in the Project. checkAnswer( - jsonSchemaRDD. + jsonDF. where('num_str > BigDecimal("92233720368547758060")). - select('num_str + 1.2 as Symbol("num")), + select(('num_str + 1.2).as("num")), Row(new java.math.BigDecimal("92233720368547758061.2")) ) @@ -481,7 +483,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in complex field values") { - val jsonSchemaRDD = jsonRDD(complexFieldValueTypeConflict) + val jsonDF = jsonRDD(complexFieldValueTypeConflict) val expectedSchema = StructType( StructField("array", ArrayType(IntegerType, false), true) :: @@ -491,9 +493,9 @@ class JsonSuite extends QueryTest { StructField("field", StringType, true) :: Nil), true) :: StructField("struct_array", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -505,7 +507,7 @@ class JsonSuite extends QueryTest { } test("Type conflict in array elements") { - val jsonSchemaRDD = jsonRDD(arrayElementTypeConflict) + val jsonDF = jsonRDD(arrayElementTypeConflict) val expectedSchema = StructType( StructField("array1", ArrayType(StringType, true), true) :: @@ -513,9 +515,9 @@ class JsonSuite extends QueryTest { StructField("field", LongType, true) :: Nil), false), true) :: StructField("array3", ArrayType(StringType, false), true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -533,7 +535,7 @@ class JsonSuite extends QueryTest { } test("Handling missing fields") { - val jsonSchemaRDD = jsonRDD(missingFields) + val jsonDF = jsonRDD(missingFields) val expectedSchema = StructType( StructField("a", BooleanType, true) :: @@ -543,16 +545,16 @@ class JsonSuite extends QueryTest { StructField("field", BooleanType, true) :: Nil), true) :: StructField("e", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") } test("Loading a JSON dataset from a text file") { val file = getTempFilePath("json") val path = file.toString primitiveFieldAndType.map(record => record.replaceAll("\n", " ")).saveAsTextFile(path) - val jsonSchemaRDD = jsonFile(path) + val jsonDF = jsonFile(path) val expectedSchema = StructType( StructField("bigInteger", DecimalType.Unlimited, true) :: @@ -563,9 +565,9 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - assert(expectedSchema === jsonSchemaRDD.schema) + assert(expectedSchema === jsonDF.schema) - jsonSchemaRDD.registerTempTable("jsonTable") + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select * from jsonTable"), @@ -619,11 +621,11 @@ class JsonSuite extends QueryTest { StructField("null", StringType, true) :: StructField("string", StringType, true) :: Nil) - val jsonSchemaRDD1 = jsonFile(path, schema) + val jsonDF1 = jsonFile(path, schema) - assert(schema === jsonSchemaRDD1.schema) + assert(schema === jsonDF1.schema) - jsonSchemaRDD1.registerTempTable("jsonTable1") + jsonDF1.registerTempTable("jsonTable1") checkAnswer( sql("select * from jsonTable1"), @@ -636,11 +638,11 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val jsonSchemaRDD2 = jsonRDD(primitiveFieldAndType, schema) + val jsonDF2 = jsonRDD(primitiveFieldAndType, schema) - assert(schema === jsonSchemaRDD2.schema) + assert(schema === jsonDF2.schema) - jsonSchemaRDD2.registerTempTable("jsonTable2") + jsonDF2.registerTempTable("jsonTable2") checkAnswer( sql("select * from jsonTable2"), @@ -655,8 +657,8 @@ class JsonSuite extends QueryTest { } test("SPARK-2096 Correctly parse dot notations") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql("select arrayOfStruct[0].field1, arrayOfStruct[0].field2 from jsonTable"), @@ -673,8 +675,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3390 Complex arrays") { - val jsonSchemaRDD = jsonRDD(complexFieldAndType2) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(complexFieldAndType2) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -696,8 +698,8 @@ class JsonSuite extends QueryTest { } test("SPARK-3308 Read top level JSON arrays") { - val jsonSchemaRDD = jsonRDD(jsonArray) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(jsonArray) + jsonDF.registerTempTable("jsonTable") checkAnswer( sql( @@ -717,8 +719,8 @@ class JsonSuite extends QueryTest { val oldColumnNameOfCorruptRecord = TestSQLContext.conf.columnNameOfCorruptRecord TestSQLContext.setConf(SQLConf.COLUMN_NAME_OF_CORRUPT_RECORD, "_unparsed") - val jsonSchemaRDD = jsonRDD(corruptRecords) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(corruptRecords) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("_unparsed", StringType, true) :: @@ -726,7 +728,7 @@ class JsonSuite extends QueryTest { StructField("b", StringType, true) :: StructField("c", StringType, true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) // In HiveContext, backticks should be used to access columns starting with a underscore. checkAnswer( @@ -771,8 +773,8 @@ class JsonSuite extends QueryTest { } test("SPARK-4068: nulls in arrays") { - val jsonSchemaRDD = jsonRDD(nullsInArrays) - jsonSchemaRDD.registerTempTable("jsonTable") + val jsonDF = jsonRDD(nullsInArrays) + jsonDF.registerTempTable("jsonTable") val schema = StructType( StructField("field1", @@ -786,7 +788,7 @@ class JsonSuite extends QueryTest { StructField("field4", ArrayType(ArrayType(ArrayType(IntegerType, false), true), false), true) :: Nil) - assert(schema === jsonSchemaRDD.schema) + assert(schema === jsonDF.schema) checkAnswer( sql( @@ -801,7 +803,7 @@ class JsonSuite extends QueryTest { ) } - test("SPARK-4228 SchemaRDD to JSON") + test("SPARK-4228 DataFrame to JSON") { val schema1 = StructType( StructField("f1", IntegerType, false) :: @@ -818,10 +820,10 @@ class JsonSuite extends QueryTest { Row(values(0).toInt, values(1), values(2).toBoolean, r.split(",").toList, v5) } - val schemaRDD1 = applySchema(rowRDD1, schema1) - schemaRDD1.registerTempTable("applySchema1") - val schemaRDD2 = schemaRDD1.toSchemaRDD - val result = schemaRDD2.toJSON.collect() + val df1 = applySchema(rowRDD1, schema1) + df1.registerTempTable("applySchema1") + val df2 = df1.toDataFrame + val result = df2.toJSON.collect() assert(result(0) == "{\"f1\":1,\"f2\":\"A1\",\"f3\":true,\"f4\":[\"1\",\" A1\",\" true\",\" null\"]}") assert(result(3) == "{\"f1\":4,\"f2\":\"D4\",\"f3\":true,\"f4\":[\"4\",\" D4\",\" true\",\" 2147483644\"],\"f5\":2147483644}") @@ -839,16 +841,16 @@ class JsonSuite extends QueryTest { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val schemaRDD3 = applySchema(rowRDD2, schema2) - schemaRDD3.registerTempTable("applySchema2") - val schemaRDD4 = schemaRDD3.toSchemaRDD - val result2 = schemaRDD4.toJSON.collect() + val df3 = applySchema(rowRDD2, schema2) + df3.registerTempTable("applySchema2") + val df4 = df3.toDataFrame + val result2 = df4.toJSON.collect() assert(result2(1) == "{\"f1\":{\"f11\":2,\"f12\":false},\"f2\":{\"B2\":null}}") assert(result2(3) == "{\"f1\":{\"f11\":4,\"f12\":true},\"f2\":{\"D4\":2147483644}}") - val jsonSchemaRDD = jsonRDD(primitiveFieldAndType) - val primTable = jsonRDD(jsonSchemaRDD.toJSON) + val jsonDF = jsonRDD(primitiveFieldAndType) + val primTable = jsonRDD(jsonDF.toJSON) primTable.registerTempTable("primativeTable") checkAnswer( sql("select * from primativeTable"), @@ -860,8 +862,8 @@ class JsonSuite extends QueryTest { "this is a simple string.") ) - val complexJsonSchemaRDD = jsonRDD(complexFieldAndType1) - val compTable = jsonRDD(complexJsonSchemaRDD.toJSON) + val complexJsonDF = jsonRDD(complexFieldAndType1) + val compTable = jsonRDD(complexJsonDF.toJSON) compTable.registerTempTable("complexTable") // Access elements of a primitive array. checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala index 1e7d3e06fc196..e78145f4dda5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetFilterSuite.scala @@ -21,9 +21,10 @@ import parquet.filter2.predicate.Operators._ import parquet.filter2.predicate.{FilterPredicate, Operators} import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal, Predicate, Row} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, Predicate, Row} +import org.apache.spark.sql.types._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -41,15 +42,17 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { val sqlContext = TestSQLContext private def checkFilterPredicate( - rdd: SchemaRDD, + rdd: DataFrame, predicate: Predicate, filterClass: Class[_ <: FilterPredicate], - checker: (SchemaRDD, Seq[Row]) => Unit, + checker: (DataFrame, Seq[Row]) => Unit, expected: Seq[Row]): Unit = { val output = predicate.collect { case a: Attribute => a }.distinct withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED -> "true") { - val query = rdd.select(output: _*).where(predicate) + val query = rdd + .select(output.map(e => new org.apache.spark.sql.Column(e)): _*) + .where(new org.apache.spark.sql.Column(predicate)) val maybeAnalyzedPredicate = query.queryExecution.executedPlan.collect { case plan: ParquetTableScan => plan.columnPruningPred @@ -71,13 +74,13 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { private def checkFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(rdd, predicate, filterClass, checkAnswer(_, _: Seq[Row]), expected) } private def checkFilterPredicate[T] (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: T) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -91,26 +94,50 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { } } + test("filter pushdown - short") { + withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toShort)))) { implicit rdd => + checkFilterPredicate(Cast('_1, IntegerType) === 1, classOf[Eq [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) + + checkFilterPredicate(Cast('_1, IntegerType) < 2, classOf[Lt [_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) > 3, classOf[Gt [_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) <= 1, classOf[LtEq[_]], 1) + checkFilterPredicate(Cast('_1, IntegerType) >= 4, classOf[GtEq[_]], 4) + + checkFilterPredicate(Literal(1) === Cast('_1, IntegerType), classOf[Eq [_]], 1) + checkFilterPredicate(Literal(2) > Cast('_1, IntegerType), classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < Cast('_1, IntegerType), classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= Cast('_1, IntegerType), classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= Cast('_1, IntegerType), classOf[GtEq[_]], 4) + + checkFilterPredicate(!(Cast('_1, IntegerType) < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(Cast('_1, IntegerType) > 2 && Cast('_1, IntegerType) < 4, + classOf[Operators.And], 3) + checkFilterPredicate(Cast('_1, IntegerType) < 2 || Cast('_1, IntegerType) > 3, + classOf[Operators.Or], Seq(Row(1), Row(4))) + } + } + test("filter pushdown - integer") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -118,24 +145,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - long") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -143,24 +170,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - float") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq [_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) - checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(1) === '_1, classOf[Eq[_]], 1) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt[_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt[_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -168,24 +195,24 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { test("filter pushdown - double") { withParquetRDD((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit rdd => - checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) + checkFilterPredicate('_1.isNull, classOf[Eq[_]], Seq.empty[Row]) checkFilterPredicate('_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 === 1, classOf[Eq [_]], 1) + checkFilterPredicate('_1 === 1, classOf[Eq[_]], 1) checkFilterPredicate('_1 !== 1, classOf[NotEq[_]], (2 to 4).map(Row.apply(_))) - checkFilterPredicate('_1 < 2, classOf[Lt [_]], 1) - checkFilterPredicate('_1 > 3, classOf[Gt [_]], 4) + checkFilterPredicate('_1 < 2, classOf[Lt[_]], 1) + checkFilterPredicate('_1 > 3, classOf[Gt[_]], 4) checkFilterPredicate('_1 <= 1, classOf[LtEq[_]], 1) checkFilterPredicate('_1 >= 4, classOf[GtEq[_]], 4) checkFilterPredicate(Literal(1) === '_1, classOf[Eq [_]], 1) - checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) - checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) - checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) - checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) + checkFilterPredicate(Literal(2) > '_1, classOf[Lt [_]], 1) + checkFilterPredicate(Literal(3) < '_1, classOf[Gt [_]], 4) + checkFilterPredicate(Literal(1) >= '_1, classOf[LtEq[_]], 1) + checkFilterPredicate(Literal(4) <= '_1, classOf[GtEq[_]], 4) - checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) + checkFilterPredicate(!('_1 < 4), classOf[GtEq[_]], 4) checkFilterPredicate('_1 > 2 && '_1 < 4, classOf[Operators.And], 3) checkFilterPredicate('_1 < 2 || '_1 > 3, classOf[Operators.Or], Seq(Row(1), Row(4))) } @@ -197,30 +224,30 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkFilterPredicate( '_1.isNotNull, classOf[NotEq[_]], (1 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 === "1", classOf[Eq [_]], "1") + checkFilterPredicate('_1 === "1", classOf[Eq[_]], "1") checkFilterPredicate('_1 !== "1", classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.toString))) - checkFilterPredicate('_1 < "2", classOf[Lt [_]], "1") - checkFilterPredicate('_1 > "3", classOf[Gt [_]], "4") + checkFilterPredicate('_1 < "2", classOf[Lt[_]], "1") + checkFilterPredicate('_1 > "3", classOf[Gt[_]], "4") checkFilterPredicate('_1 <= "1", classOf[LtEq[_]], "1") checkFilterPredicate('_1 >= "4", classOf[GtEq[_]], "4") - checkFilterPredicate(Literal("1") === '_1, classOf[Eq [_]], "1") - checkFilterPredicate(Literal("2") > '_1, classOf[Lt [_]], "1") - checkFilterPredicate(Literal("3") < '_1, classOf[Gt [_]], "4") - checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") - checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") + checkFilterPredicate(Literal("1") === '_1, classOf[Eq[_]], "1") + checkFilterPredicate(Literal("2") > '_1, classOf[Lt[_]], "1") + checkFilterPredicate(Literal("3") < '_1, classOf[Gt[_]], "4") + checkFilterPredicate(Literal("1") >= '_1, classOf[LtEq[_]], "1") + checkFilterPredicate(Literal("4") <= '_1, classOf[GtEq[_]], "4") - checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") + checkFilterPredicate(!('_1 < "4"), classOf[GtEq[_]], "4") checkFilterPredicate('_1 > "2" && '_1 < "4", classOf[Operators.And], "3") - checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) + checkFilterPredicate('_1 < "2" || '_1 > "3", classOf[Operators.Or], Seq(Row("1"), Row("4"))) } } def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Seq[Row]) - (implicit rdd: SchemaRDD): Unit = { - def checkBinaryAnswer(rdd: SchemaRDD, expected: Seq[Row]) = { + (implicit rdd: DataFrame): Unit = { + def checkBinaryAnswer(rdd: DataFrame, expected: Seq[Row]) = { assertResult(expected.map(_.getAs[Array[Byte]](0).mkString(",")).toSeq.sorted) { rdd.map(_.getAs[Array[Byte]](0).mkString(",")).collect().toSeq.sorted } @@ -231,7 +258,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { def checkBinaryFilterPredicate (predicate: Predicate, filterClass: Class[_ <: FilterPredicate], expected: Array[Byte]) - (implicit rdd: SchemaRDD): Unit = { + (implicit rdd: DataFrame): Unit = { checkBinaryFilterPredicate(predicate, filterClass, Seq(Row(expected)))(rdd) } @@ -249,16 +276,16 @@ class ParquetFilterSuite extends QueryTest with ParquetTest { checkBinaryFilterPredicate( '_1 !== 1.b, classOf[NotEq[_]], (2 to 4).map(i => Row.apply(i.b)).toSeq) - checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt [_]], 4.b) + checkBinaryFilterPredicate('_1 < 2.b, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate('_1 > 3.b, classOf[Gt[_]], 4.b) checkBinaryFilterPredicate('_1 <= 1.b, classOf[LtEq[_]], 1.b) checkBinaryFilterPredicate('_1 >= 4.b, classOf[GtEq[_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq [_]], 1.b) - checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt [_]], 1.b) - checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt [_]], 4.b) - checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) - checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) === '_1, classOf[Eq[_]], 1.b) + checkBinaryFilterPredicate(Literal(2.b) > '_1, classOf[Lt[_]], 1.b) + checkBinaryFilterPredicate(Literal(3.b) < '_1, classOf[Gt[_]], 4.b) + checkBinaryFilterPredicate(Literal(1.b) >= '_1, classOf[LtEq[_]], 1.b) + checkBinaryFilterPredicate(Literal(4.b) <= '_1, classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate(!('_1 < 4.b), classOf[GtEq[_]], 4.b) checkBinaryFilterPredicate('_1 > 2.b && '_1 < 4.b, classOf[Operators.And], 3.b) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala index a57e4e85a35ef..d9ab16baf9a66 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetIOSuite.scala @@ -32,12 +32,13 @@ import parquet.schema.{MessageType, MessageTypeParser} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.sql.{DataFrame, QueryTest, SQLConf} +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.expressions.Row import org.apache.spark.sql.test.TestSQLContext import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.DecimalType -import org.apache.spark.sql.{QueryTest, SQLConf, SchemaRDD} // Write support class for nested groups: ParquetWriter initializes GroupWriteSupport // with an empty configuration (it is after all not intended to be used in this way?) @@ -97,11 +98,11 @@ class ParquetIOSuite extends QueryTest with ParquetTest { } test("fixed-length decimals") { - def makeDecimalRDD(decimal: DecimalType): SchemaRDD = + def makeDecimalRDD(decimal: DecimalType): DataFrame = sparkContext .parallelize(0 to 1000) .map(i => Tuple1(i / 100.0)) - .select('_1 cast decimal) + .select($"_1" cast decimal as "abcd") for ((precision, scale) <- Seq((5, 2), (1, 0), (1, 1), (18, 10), (18, 17))) { withTempPath { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala index 1263ff818ea19..3d82f4bce7778 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala @@ -85,4 +85,15 @@ class ParquetQuerySuite extends QueryTest with ParquetTest { checkAnswer(sql(s"SELECT _1 FROM t WHERE _1 < 10"), (1 to 9).map(Row.apply(_))) } } + + test("SPARK-5309 strings stored using dictionary compression in parquet") { + withParquetTable((0 until 1000).map(i => ("same", "run_" + i /100, 1)), "t") { + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t GROUP BY _1, _2"), + (0 until 10).map(i => Row("same", "run_" + i, 100))) + + checkAnswer(sql(s"SELECT _1, _2, SUM(_3) FROM t WHERE _2 = 'run_5' GROUP BY _1, _2"), + List(Row("same", "run_5", 100))) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala index 7900b3e8948d9..a33cf1172cac9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.sources +import scala.language.existentials + import org.apache.spark.sql._ import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala index b1e0919b7aed1..0a4d4b6342d4f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala @@ -344,4 +344,24 @@ class TableScanSuite extends DataSourceTest { } assert(schemaNeeded.getMessage.contains("A schema needs to be specified when using")) } + + test("SPARK-5196 schema field with comment") { + sql( + """ + |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int) + |USING org.apache.spark.sql.sources.AllDataTypesScanSource + |OPTIONS ( + | from '1', + | to '10' + |) + """.stripMargin) + + val planned = sql("SELECT * FROM student").queryExecution.executedPlan + val comments = planned.schema.fields.map { field => + if (field.metadata.contains("comment")) field.metadata.getString("comment") + else "NO_COMMENT" + }.mkString(",") + + assert(comments === "SN,SA,NO_COMMENT") + } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 7385952861ee5..bb19ac232fcbe 100755 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -23,6 +23,7 @@ import java.io._ import java.util.{ArrayList => JArrayList} import jline.{ConsoleReader, History} + import org.apache.commons.lang.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -39,7 +40,6 @@ import org.apache.thrift.transport.TSocket import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveShim -import org.apache.spark.sql.hive.thriftserver.HiveThriftServerShim private[hive] object SparkSQLCLIDriver { private var prompt = "spark-sql" diff --git a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala index 166c56b9dfe20..ea9d61d8d0f5e 100644 --- a/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala +++ b/sql/hive-thriftserver/v0.12.0/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim12.scala @@ -32,7 +32,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, SchemaRDD, Row => SparkRow} +import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -71,7 +71,7 @@ private[hive] class SparkExecuteStatementOperation( sessionToActivePool: SMap[SessionHandle, String]) extends ExecuteStatementOperation(parentSession, statement, confOverlay) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -202,7 +202,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala index eaf7a1ddd4996..71e3954b2c7ac 100644 --- a/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala +++ b/sql/hive-thriftserver/v0.13.1/src/main/scala/org/apache/spark/sql/hive/thriftserver/Shim13.scala @@ -30,7 +30,7 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging -import org.apache.spark.sql.{Row => SparkRow, SQLConf, SchemaRDD} +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} @@ -72,7 +72,7 @@ private[hive] class SparkExecuteStatementOperation( // NOTE: `runInBackground` is set to `false` intentionally to disable asynchronous execution extends ExecuteStatementOperation(parentSession, statement, confOverlay, false) with Logging { - private var result: SchemaRDD = _ + private var result: DataFrame = _ private var iter: Iterator[SparkRow] = _ private var dataTypes: Array[DataType] = _ @@ -173,7 +173,7 @@ private[hive] class SparkExecuteStatementOperation( val useIncrementalCollect = hiveContext.getConf("spark.sql.thriftServer.incrementalCollect", "false").toBoolean if (useIncrementalCollect) { - result.toLocalIterator + result.rdd.toLocalIterator } else { result.collect().iterator } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 9d2cfd8e0d669..b746942cb1067 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -64,15 +64,15 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { getConf("spark.sql.hive.convertMetastoreParquet", "true") == "true" override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) - override def sql(sqlText: String): SchemaRDD = { + override def sql(sqlText: String): DataFrame = { val substituted = new VariableSubstitution().substitute(hiveconf, sqlText) // TODO: Create a framework for registering parsers instead of just hardcoding if statements. if (conf.dialect == "sql") { super.sql(substituted) } else if (conf.dialect == "hiveql") { - new SchemaRDD(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) + new DataFrame(this, ddlParser(sqlText, false).getOrElse(HiveQl.parseSql(substituted))) } else { sys.error(s"Unsupported SQL dialect: ${conf.dialect}. Try 'sql' or 'hiveql'") } @@ -352,7 +352,8 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override protected[sql] val planner = hivePlanner /** Extends QueryExecution with hive specific features. */ - protected[sql] abstract class QueryExecution extends super.QueryExecution { + protected[sql] class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { /** * Returns the result as a hive compatible sequence of strings. For native commands, the diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 5e29e57d93585..30a64b48d7951 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -965,6 +965,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C /* Case insensitive matches */ val ARRAY = "(?i)ARRAY".r + val COALESCE = "(?i)COALESCE".r val COUNT = "(?i)COUNT".r val AVG = "(?i)AVG".r val SUM = "(?i)SUM".r @@ -1002,11 +1003,11 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C } /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => Star(None) + case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only // has a single child which is tableName. case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - Star(Some(name)) + UnresolvedStar(Some(name)) /* Aggregate Functions */ case Token("TOK_FUNCTION", Token(AVG(), Nil) :: arg :: Nil) => Average(nodeToExpr(arg)) @@ -1140,12 +1141,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Substring(nodeToExpr(string), nodeToExpr(pos), Literal(Integer.MAX_VALUE, IntegerType)) case Token("TOK_FUNCTION", Token(SUBSTR(), Nil) :: string :: pos :: length :: Nil) => Substring(nodeToExpr(string), nodeToExpr(pos), nodeToExpr(length)) + case Token("TOK_FUNCTION", Token(COALESCE(), Nil) :: list) => Coalesce(list.map(nodeToExpr)) /* UDFs - Must be last otherwise will preempt built in functions */ case Token("TOK_FUNCTION", Token(name, Nil) :: args) => UnresolvedFunction(name, args.map(nodeToExpr)) case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, Star(None) :: Nil) + UnresolvedFunction(name, UnresolvedStar(None) :: Nil) /* Literals */ case Token("TOK_NULL", Nil) => Literal(null, NullType) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 6952b126cf894..ace9329cd5821 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import scala.collection.JavaConversions._ import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{SQLContext, SchemaRDD, Strategy} +import org.apache.spark.sql.{Column, DataFrame, SQLContext, Strategy} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate @@ -55,16 +55,15 @@ private[hive] trait HiveStrategies { */ @Experimental object ParquetConversion extends Strategy { - implicit class LogicalPlanHacks(s: SchemaRDD) { - def lowerCase = - new SchemaRDD(s.sqlContext, s.logicalPlan) + implicit class LogicalPlanHacks(s: DataFrame) { + def lowerCase = new DataFrame(s.sqlContext, s.logicalPlan) def addPartitioningAttributes(attrs: Seq[Attribute]) = { // Don't add the partitioning key if its already present in the data. if (attrs.map(_.name).toSet.subsetOf(s.logicalPlan.output.map(_.name).toSet)) { s } else { - new SchemaRDD( + new DataFrame( s.sqlContext, s.logicalPlan transform { case p: ParquetRelation => p.copy(partitioningAttributes = attrs) @@ -97,13 +96,13 @@ private[hive] trait HiveStrategies { // We are going to throw the predicates and projection back at the whole optimization // sequence so lets unresolve all the attributes, allowing them to be rebound to the // matching parquet attributes. - val unresolvedOtherPredicates = otherPredicates.map(_ transform { + val unresolvedOtherPredicates = new Column(otherPredicates.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }).reduceOption(And).getOrElse(Literal(true)) + }).reduceOption(And).getOrElse(Literal(true))) - val unresolvedProjection = projectList.map(_ transform { + val unresolvedProjection: Seq[Column] = projectList.map(_ transform { case a: AttributeReference => UnresolvedAttribute(a.name) - }) + }).map(new Column(_)) try { if (relation.hiveQlTable.isPartitioned) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala index 47431cef03e13..7c1d1133c3425 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TestHive.scala @@ -68,6 +68,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { System.clearProperty("spark.hostPort") CommandProcessorFactory.clean(hiveconf) + hiveconf.set("hive.plan.serialization.format", "javaXML") + lazy val warehousePath = getTempFilePath("sparkHiveWarehouse").getCanonicalPath lazy val metastorePath = getTempFilePath("sparkHiveMetastore").getCanonicalPath @@ -99,7 +101,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { override def runSqlHive(sql: String): Seq[String] = super.runSqlHive(rewritePaths(sql)) override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution { val logical = plan } + new this.QueryExecution(plan) /** Fewer partitions to speed up testing. */ protected[sql] override lazy val conf: SQLConf = new SQLConf { @@ -150,8 +152,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { val describedTable = "DESCRIBE (\\w+)".r - protected[hive] class HiveQLQueryExecution(hql: String) extends this.QueryExecution { - lazy val logical = HiveQl.parseSql(hql) + protected[hive] class HiveQLQueryExecution(hql: String) + extends this.QueryExecution(HiveQl.parseSql(hql)) { def hiveExec() = runSqlHive(hql) override def toString = hql + "\n" + super.toString } @@ -159,7 +161,8 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { /** * Override QueryExecution with special debug workflow. */ - abstract class QueryExecution extends super.QueryExecution { + class QueryExecution(logicalPlan: LogicalPlan) + extends super.QueryExecution(logicalPlan) { override lazy val analyzed = { val describedTables = logical match { case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil @@ -395,7 +398,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) } - clearCache() + cacheManager.clearCache() loadedTables.clear() catalog.cachedDataSourceTables.invalidateAll() catalog.client.getAllTables("default").foreach { t => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 42bc8a0b67933..91af35f0965c0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -239,7 +239,7 @@ case class InsertIntoHiveTable( } // Invalidate the cache. - sqlContext.invalidateCache(table) + sqlContext.cacheManager.invalidateCache(table) // It would be nice to just return the childRdd unchanged so insert operations could be chained, // however for now we return an empty list to simplify compatibility checks with hive, which diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 91f9da35abeee..4814cb7ebfe51 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -54,7 +54,7 @@ case class DropTable( val hiveContext = sqlContext.asInstanceOf[HiveContext] val ifExistsClause = if (ifExists) "IF EXISTS " else "" try { - hiveContext.tryUncacheQuery(hiveContext.table(tableName)) + hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) } catch { // This table's metadata is not in case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java index 8b29fa7d1a8f7..4b23fbf6e7362 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/package-info.java @@ -15,4 +15,4 @@ * limitations under the License. */ -package org.apache.spark.sql.hive; \ No newline at end of file +package org.apache.spark.sql.hive; diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala index f320d732fb77a..ba391293884bd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -36,12 +36,12 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output * @param keywords keyword in string array */ - def checkExistence(rdd: SchemaRDD, exists: Boolean, keywords: String*) { + def checkExistence(rdd: DataFrame, exists: Boolean, keywords: String*) { val outputs = rdd.collect().map(_.mkString).mkString for (key <- keywords) { if (exists) { @@ -54,10 +54,10 @@ class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. - * @param rdd the [[SchemaRDD]] to be executed + * @param rdd the [[DataFrame]] to be executed * @param expectedAnswer the expected result, can either be an Any, Seq[Product], or Seq[ Seq[Any] ]. */ - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Seq[Row]): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Seq[Row]): Unit = { val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty def prepareAnswer(answer: Seq[Row]): Seq[Row] = { // Converts data to types that we can do equality comparison using Scala collections. @@ -101,7 +101,7 @@ class QueryTest extends PlanTest { } } - protected def checkAnswer(rdd: SchemaRDD, expectedAnswer: Row): Unit = { + protected def checkAnswer(rdd: DataFrame, expectedAnswer: Row): Unit = { checkAnswer(rdd, Seq(expectedAnswer)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index f95a6b43af357..61e5117feab10 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{QueryTest, SchemaRDD} +import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.storage.RDDBlockId class CachedTableSuite extends QueryTest { @@ -28,7 +28,7 @@ class CachedTableSuite extends QueryTest { * Throws a test failed exception when the number of cached tables differs from the expected * number. */ - def assertCached(query: SchemaRDD, numCachedTables: Int = 1): Unit = { + def assertCached(query: DataFrame, numCachedTables: Int = 1): Unit = { val planWithCaching = query.queryExecution.withCachedData val cachedData = planWithCaching collect { case cached: InMemoryRelation => cached diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 0e6636d38ed3c..4dd96bd5a1b77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -52,7 +52,7 @@ class InsertIntoHiveTableSuite extends QueryTest { // Make sure the table has been updated. checkAnswer( sql("SELECT * FROM createAndInsertTest"), - testData.toSchemaRDD.collect().toSeq ++ testData.toSchemaRDD.collect().toSeq + testData.toDataFrame.collect().toSeq ++ testData.toDataFrame.collect().toSeq ) // Now overwrite. @@ -82,8 +82,8 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(StructField("m", MapType(StringType, StringType), true) :: Nil) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(scala.collection.mutable.HashMap(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m MAP )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -127,8 +127,8 @@ class InsertIntoHiveTableSuite extends QueryTest { val schema = StructType(Seq( StructField("a", ArrayType(StringType, containsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize((1 to 100).map(i => Row(Seq(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithArrayValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithArrayValue") sql("CREATE TABLE hiveTableWithArrayValue(a Array )") sql("INSERT OVERWRITE TABLE hiveTableWithArrayValue SELECT a FROM tableWithArrayValue") @@ -144,8 +144,8 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("m", MapType(StringType, StringType, valueContainsNull = false)))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Map(s"key$i" -> s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithMapValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithMapValue") sql("CREATE TABLE hiveTableWithMapValue(m Map )") sql("INSERT OVERWRITE TABLE hiveTableWithMapValue SELECT m FROM tableWithMapValue") @@ -161,8 +161,8 @@ class InsertIntoHiveTableSuite extends QueryTest { StructField("s", StructType(Seq(StructField("f", StringType, nullable = false)))))) val rowRDD = TestHive.sparkContext.parallelize( (1 to 100).map(i => Row(Row(s"value$i")))) - val schemaRDD = applySchema(rowRDD, schema) - schemaRDD.registerTempTable("tableWithStructValue") + val df = applySchema(rowRDD, schema) + df.registerTempTable("tableWithStructValue") sql("CREATE TABLE hiveTableWithStructValue(s Struct )") sql("INSERT OVERWRITE TABLE hiveTableWithStructValue SELECT s FROM tableWithStructValue") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index df72be7746ac6..60619f5d99578 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,11 +27,12 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.{SQLConf, Row, SchemaRDD} case class TestData(a: Int, b: String) @@ -367,7 +368,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") } - test("SchemaRDD toString") { + test("DataFrame toString") { sql("SHOW TABLES").toString sql("SELECT * FROM src").toString } @@ -473,12 +474,12 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } - def isExplanation(result: SchemaRDD) = { + def isExplanation(result: DataFrame) = { val explanation = result.select('plan).collect().map { case Row(plan: String) => plan } explanation.contains("== Physical Plan ==") } - test("SPARK-1704: Explain commands as a SchemaRDD") { + test("SPARK-1704: Explain commands as a DataFrame") { sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") val rdd = sql("explain select key, count(value) from src group by key") @@ -508,6 +509,11 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("SPARK-5367: resolve star expression in udf") { + assert(sql("select concat(*) from src limit 5").collect().size == 5) + assert(sql("select array(*) from src limit 5").collect().size == 5) + } + test("Query Hive native command execution result") { val tableName = "test_native_commands" @@ -842,7 +848,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val testVal = "test.val.0" val nonexistentKey = "nonexistent" val KV = "([^=]+)=([^=]*)".r - def collectResults(rdd: SchemaRDD): Set[(String, String)] = + def collectResults(rdd: DataFrame): Set[(String, String)] = rdd.collect().map { case Row(key: String, value: String) => key -> value case Row(KV(key, value)) => key -> value diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 16f77a438e1ae..8fb5e050a237a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row +import org.apache.spark.sql.Dsl._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ -import org.apache.spark.sql.Row import org.apache.spark.util.Utils @@ -82,10 +83,10 @@ class HiveTableScanSuite extends HiveComparisonTest { sql("create table spark_4959 (col1 string)") sql("""insert into table spark_4959 select "hi" from src limit 1""") table("spark_4959").select( - 'col1.as('CaseSensitiveColName), - 'col1.as('CaseSensitiveColName2)).registerTempTable("spark_4959_2") + 'col1.as("CaseSensitiveColName"), + 'col1.as("CaseSensitiveColName2")).registerTempTable("spark_4959_2") - assert(sql("select CaseSensitiveColName from spark_4959_2").first() === Row("hi")) - assert(sql("select casesensitivecolname from spark_4959_2").first() === Row("hi")) + assert(sql("select CaseSensitiveColName from spark_4959_2").head() === Row("hi")) + assert(sql("select casesensitivecolname from spark_4959_2").head() === Row("hi")) } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala index 48fffe53cf2ff..ab0e0443c7faa 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTypeCoercionSuite.scala @@ -57,4 +57,10 @@ class HiveTypeCoercionSuite extends HiveComparisonTest { } assert(numEquals === 1) } + + test("COALESCE with different types") { + intercept[RuntimeException] { + TestHive.sql("""SELECT COALESCE(1, true, "abc") FROM src limit 1""").collect() + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index f2374a215291b..dd0df1a9f6320 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -58,7 +58,7 @@ class HiveUdfSuite extends QueryTest { | getStruct(1).f3, | getStruct(1).f4, | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).first() === Row(1, 2, 3, 4, 5)) + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) } test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7f9f1ac7cd80d..eb7a7750af02d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -222,7 +222,7 @@ class SQLQuerySuite extends QueryTest { sql("SELECT distinct key FROM src order by key").collect().toSeq) } - test("SPARK-4963 SchemaRDD sample on mutable row return wrong result") { + test("SPARK-4963 DataFrame sample on mutable row return wrong result") { sql("SELECT * FROM src WHERE key % 2 = 0") .sample(withReplacement = false, fraction = 0.3) .registerTempTable("sampled") @@ -267,4 +267,19 @@ class SQLQuerySuite extends QueryTest { sql("DROP TABLE nullValuesInInnerComplexTypes") dropTempTable("testTable") } + + test("SPARK-4296 Grouping field with Hive UDF as sub expression") { + val rdd = sparkContext.makeRDD( """{"a": "str", "b":"1", "c":"1970-01-01 00:00:00"}""" :: Nil) + jsonRDD(rdd).registerTempTable("data") + checkAnswer( + sql("SELECT concat(a, '-', b), year(c) FROM data GROUP BY concat(a, '-', b), year(c)"), + Row("str-1", 1970)) + + dropTempTable("data") + + jsonRDD(rdd).registerTempTable("data") + checkAnswer(sql("SELECT year(c) + 1 FROM data GROUP BY year(c) + 1"), Row(1971)) + + dropTempTable("data") + } } diff --git a/streaming/pom.xml b/streaming/pom.xml index 22b0d714b57f6..d032491e2ff83 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -40,6 +40,10 @@ spark-core_${scala.binary.version} ${project.version} + + com.google.guava + guava + org.eclipse.jetty jetty-server @@ -95,6 +99,14 @@ + + + org.apache.maven.plugins + maven-shade-plugin + + true + + diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index ed1aa114e19d9..74dbba453f026 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -50,9 +50,6 @@ case class StreamingListenerReceiverError(receiverInfo: ReceiverInfo) case class StreamingListenerReceiverStopped(receiverInfo: ReceiverInfo) extends StreamingListenerEvent -/** An event used in the listener to shutdown the listener daemon thread. */ -private[scheduler] case object StreamingListenerShutdown extends StreamingListenerEvent - /** * :: DeveloperApi :: * A listener interface for receiving information about an ongoing streaming diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala index 398724d9e8130..b07d6cf347ca7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListenerBus.scala @@ -17,83 +17,42 @@ package org.apache.spark.streaming.scheduler +import java.util.concurrent.atomic.AtomicBoolean + import org.apache.spark.Logging -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} -import java.util.concurrent.LinkedBlockingQueue +import org.apache.spark.util.AsynchronousListenerBus /** Asynchronously passes StreamingListenerEvents to registered StreamingListeners. */ -private[spark] class StreamingListenerBus() extends Logging { - private val listeners = new ArrayBuffer[StreamingListener]() - with SynchronizedBuffer[StreamingListener] - - /* Cap the capacity of the SparkListenerEvent queue so we get an explicit error (rather than - * an OOM exception) if it's perpetually being added to more quickly than it's being drained. */ - private val EVENT_QUEUE_CAPACITY = 10000 - private val eventQueue = new LinkedBlockingQueue[StreamingListenerEvent](EVENT_QUEUE_CAPACITY) - private var queueFullErrorMessageLogged = false - - val listenerThread = new Thread("StreamingListenerBus") { - setDaemon(true) - override def run() { - while (true) { - val event = eventQueue.take - event match { - case receiverStarted: StreamingListenerReceiverStarted => - listeners.foreach(_.onReceiverStarted(receiverStarted)) - case receiverError: StreamingListenerReceiverError => - listeners.foreach(_.onReceiverError(receiverError)) - case receiverStopped: StreamingListenerReceiverStopped => - listeners.foreach(_.onReceiverStopped(receiverStopped)) - case batchSubmitted: StreamingListenerBatchSubmitted => - listeners.foreach(_.onBatchSubmitted(batchSubmitted)) - case batchStarted: StreamingListenerBatchStarted => - listeners.foreach(_.onBatchStarted(batchStarted)) - case batchCompleted: StreamingListenerBatchCompleted => - listeners.foreach(_.onBatchCompleted(batchCompleted)) - case StreamingListenerShutdown => - // Get out of the while loop and shutdown the daemon thread - return - case _ => - } - } +private[spark] class StreamingListenerBus + extends AsynchronousListenerBus[StreamingListener, StreamingListenerEvent]("StreamingListenerBus") + with Logging { + + private val logDroppedEvent = new AtomicBoolean(false) + + override def onPostEvent(listener: StreamingListener, event: StreamingListenerEvent): Unit = { + event match { + case receiverStarted: StreamingListenerReceiverStarted => + listener.onReceiverStarted(receiverStarted) + case receiverError: StreamingListenerReceiverError => + listener.onReceiverError(receiverError) + case receiverStopped: StreamingListenerReceiverStopped => + listener.onReceiverStopped(receiverStopped) + case batchSubmitted: StreamingListenerBatchSubmitted => + listener.onBatchSubmitted(batchSubmitted) + case batchStarted: StreamingListenerBatchStarted => + listener.onBatchStarted(batchStarted) + case batchCompleted: StreamingListenerBatchCompleted => + listener.onBatchCompleted(batchCompleted) + case _ => } } - def start() { - listenerThread.start() - } - - def addListener(listener: StreamingListener) { - listeners += listener - } - - def post(event: StreamingListenerEvent) { - val eventAdded = eventQueue.offer(event) - if (!eventAdded && !queueFullErrorMessageLogged) { + override def onDropEvent(event: StreamingListenerEvent): Unit = { + if (logDroppedEvent.compareAndSet(false, true)) { + // Only log the following message once to avoid duplicated annoying logs. logError("Dropping StreamingListenerEvent because no remaining room in event queue. " + "This likely means one of the StreamingListeners is too slow and cannot keep up with the " + "rate at which events are being started by the scheduler.") - queueFullErrorMessageLogged = true } } - - /** - * Waits until there are no more events in the queue, or until the specified time has elapsed. - * Used for testing only. Returns true if the queue has emptied and false is the specified time - * elapsed before the queue emptied. - */ - def waitUntilEmpty(timeoutMillis: Int): Boolean = { - val finishTime = System.currentTimeMillis + timeoutMillis - while (!eventQueue.isEmpty) { - if (System.currentTimeMillis > finishTime) { - return false - } - /* Sleep rather than using wait/notify, because this is used only for testing and wait/notify - * add overhead in the general case. */ - Thread.sleep(10) - } - true - } - - def stop(): Unit = post(StreamingListenerShutdown) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 902bdda59860e..d3e327b2497b7 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -43,8 +43,11 @@ import org.apache.spark.util.{AkkaUtils, SignalLogger, Utils} /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster(args: ApplicationMasterArguments, - client: YarnRMClient) extends Logging { +private[spark] class ApplicationMaster( + args: ApplicationMasterArguments, + client: YarnRMClient) + extends Logging { + // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. @@ -231,6 +234,24 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, reporterThread = launchReporterThread() } + /** + * Create an actor that communicates with the driver. + * + * In cluster mode, the AM and the driver belong to same process + * so the AM actor need not monitor lifecycle of the driver. + */ + private def runAMActor( + host: String, + port: String, + isDriver: Boolean): Unit = { + val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( + SparkEnv.driverActorSystemName, + host, + port, + YarnSchedulerBackend.ACTOR_NAME) + actor = actorSystem.actorOf(Props(new AMActor(driverUrl, isDriver)), name = "YarnAM") + } + private def runDriver(securityMgr: SecurityManager): Unit = { addAmIpFilter() userClassThread = startUserClass() @@ -245,6 +266,11 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, ApplicationMaster.EXIT_SC_NOT_INITED, "Timed out waiting for SparkContext.") } else { + actorSystem = sc.env.actorSystem + runAMActor( + sc.getConf.get("spark.driver.host"), + sc.getConf.get("spark.driver.port"), + isDriver = true) registerAM(sc.ui.map(_.appUIAddress).getOrElse(""), securityMgr) userClassThread.join() } @@ -253,7 +279,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, private def runExecutorLauncher(securityMgr: SecurityManager): Unit = { actorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0, conf = sparkConf, securityManager = securityMgr)._1 - actor = waitForSparkDriver() + waitForSparkDriver() addAmIpFilter() registerAM(sparkConf.get("spark.driver.appUIAddress", ""), securityMgr) @@ -367,7 +393,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } } - private def waitForSparkDriver(): ActorRef = { + private def waitForSparkDriver(): Unit = { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false val hostport = args.userArgs(0) @@ -399,12 +425,7 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, sparkConf.set("spark.driver.host", driverHost) sparkConf.set("spark.driver.port", driverPort.toString) - val driverUrl = "akka.tcp://%s@%s:%s/user/%s".format( - SparkEnv.driverActorSystemName, - driverHost, - driverPort.toString, - YarnSchedulerBackend.ACTOR_NAME) - actorSystem.actorOf(Props(new AMActor(driverUrl)), name = "YarnAM") + runAMActor(driverHost, driverPort.toString, isDriver = false) } /** Add the Yarn IP filter that is required for properly securing the UI. */ @@ -462,9 +483,9 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, } /** - * Actor that communicates with the driver in client deploy mode. + * An actor that communicates with the driver's scheduler backend. */ - private class AMActor(driverUrl: String) extends Actor { + private class AMActor(driverUrl: String, isDriver: Boolean) extends Actor { var driver: ActorSelection = _ override def preStart() = { @@ -474,13 +495,21 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments, // we can monitor Lifecycle Events. driver ! "Hello" driver ! RegisterClusterManager - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + // In cluster mode, the AM can directly monitor the driver status instead + // of trying to deduce it from the lifecycle of the driver's actor + if (!isDriver) { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } } override def receive = { case x: DisassociatedEvent => logInfo(s"Driver terminated or disconnected! Shutting down. $x") - finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + // In cluster mode, do not rely on the disassociated event to exit + // This avoids potentially reporting incorrect exit codes if the driver fails + if (!isDriver) { + finish(FinalApplicationStatus.SUCCEEDED, ApplicationMaster.EXIT_SUCCESS) + } case x: AddWebUIFilter => logInfo(s"Add WebUI Filter. $x") diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index d4eeccf64275f..1a18e6509ef26 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -400,7 +400,10 @@ private[spark] class Client( // Add Xmx for AM memory javaOpts += "-Xmx" + args.amMemory + "m" - val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + val tmpDir = new Path( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR + ) javaOpts += "-Djava.io.tmpdir=" + tmpDir // TODO: Remove once cpuset version is pushed out. @@ -491,7 +494,9 @@ private[spark] class Client( "--num-executors ", args.numExecutors.toString) // Command for the ApplicationMaster - val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", "-server") ++ + val commands = prefixEnv ++ Seq( + YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server" + ) ++ javaOpts ++ amArgs ++ Seq( "1>", ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout", @@ -769,7 +774,9 @@ object Client extends Logging { env: HashMap[String, String], extraClassPath: Option[String] = None): Unit = { extraClassPath.foreach(addClasspathEntry(_, env)) - addClasspathEntry(Environment.PWD.$(), env) + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), env + ) // Normally the users app.jar is last in case conflicts with spark jars if (sparkConf.getBoolean("spark.yarn.user.classpath.first", false)) { @@ -783,7 +790,9 @@ object Client extends Logging { } // Append all jar files under the working directory to the classpath. - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + "*", env) + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + "*", env + ) } /** @@ -838,7 +847,9 @@ object Client extends Logging { } } if (fileName != null) { - addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + fileName, env) + addClasspathEntry( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD) + Path.SEPARATOR + fileName, env + ) } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala index c537da9f67552..ee2002a35f523 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala @@ -142,7 +142,10 @@ class ExecutorRunnable( } javaOpts += "-Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + new Path( + YarnSparkHadoopUtil.expandEnvironment(Environment.PWD), + YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR + ) // Certain configs need to be passed here because they are needed before the Executor // registers with the Scheduler and transfers the spark configs. Since the Executor backend @@ -181,7 +184,8 @@ class ExecutorRunnable( // For log4j configuration to reference javaOpts += ("-Dspark.yarn.app.container.log.dir=" + ApplicationConstants.LOG_DIR_EXPANSION_VAR) - val commands = prefixEnv ++ Seq(Environment.JAVA_HOME.$() + "/bin/java", + val commands = prefixEnv ++ Seq( + YarnSparkHadoopUtil.expandEnvironment(Environment.JAVA_HOME) + "/bin/java", "-server", // Kill if OOM is raised - leverage yarn's failure handling to cause rescheduling. // Not killing the task leaves various aspects of the executor and (to some extent) the jvm in diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index d00f29665a58f..3849586c6111e 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -32,6 +32,8 @@ import org.apache.hadoop.yarn.client.api.AMRMClient import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -60,6 +62,11 @@ private[yarn] class YarnAllocator( import YarnAllocator._ + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } + // Visible for testing. val allocatedHostToContainersMap = new HashMap[String, collection.mutable.Set[ContainerId]] diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 4bff846123619..146b2c0f1a302 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -17,22 +17,21 @@ package org.apache.spark.deploy.yarn -import java.lang.{Boolean => JBoolean} import java.io.File -import java.util.{Collections, Set => JSet} import java.util.regex.Matcher import java.util.regex.Pattern -import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.HashMap +import scala.util.Try import org.apache.hadoop.io.Text import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.Credentials import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records.{Priority, ApplicationAccessType} -import org.apache.hadoop.yarn.util.RackResolver import org.apache.hadoop.conf.Configuration import org.apache.spark.{SecurityManager, SparkConf} @@ -106,7 +105,7 @@ object YarnSparkHadoopUtil { * If the map already contains this key, append the value to the existing value instead. */ def addPathToEnvironment(env: HashMap[String, String], key: String, value: String): Unit = { - val newValue = if (env.contains(key)) { env(key) + File.pathSeparator + value } else value + val newValue = if (env.contains(key)) { env(key) + getClassPathSeparator + value } else value env.put(key, newValue) } @@ -186,4 +185,30 @@ object YarnSparkHadoopUtil { ) } + /** + * Expand environment variable using Yarn API. + * If environment.$$() is implemented, return the result of it. + * Otherwise, return the result of environment.$() + * Note: $$() is added in Hadoop 2.4. + */ + private lazy val expandMethod = + Try(classOf[Environment].getMethod("$$")) + .getOrElse(classOf[Environment].getMethod("$")) + + def expandEnvironment(environment: Environment): String = + expandMethod.invoke(environment).asInstanceOf[String] + + /** + * Get class path separator using Yarn API. + * If ApplicationConstants.CLASS_PATH_SEPARATOR is implemented, return it. + * Otherwise, return File.pathSeparator + * Note: CLASS_PATH_SEPARATOR is added in Hadoop 2.4. + */ + private lazy val classPathSeparatorField = + Try(classOf[ApplicationConstants].getField("CLASS_PATH_SEPARATOR")) + .getOrElse(classOf[File].getField("pathSeparator")) + + def getClassPathSeparator(): String = { + classPathSeparatorField.get(null).asInstanceOf[String] + } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala index be55d26f1cf61..72ec4d6b34af6 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala @@ -17,33 +17,17 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.yarn.util.RackResolver - import org.apache.spark._ import org.apache.spark.deploy.yarn.ApplicationMaster -import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.util.Utils /** * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of * ApplicationMaster, etc is done */ -private[spark] class YarnClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnClusterScheduler(sc: SparkContext) extends YarnScheduler(sc) { logInfo("Created YarnClusterScheduler") - // Nothing else for now ... initialize application master : which needs a SparkContext to - // determine how to allocate. - // Note that only the first creation of a SparkContext influences (and ideally, there must be - // only one SparkContext, right ?). Subsequent creations are ignored since executors are already - // allocated by then. - - // By default, rack is unknown - override def getRackForHost(hostPort: String): Option[String] = { - val host = Utils.parseHostPort(hostPort)._1 - Option(RackResolver.resolve(sc.hadoopConfiguration, host).getNetworkLocation) - } - override def postStartHook() { ApplicationMaster.sparkContextInitialized(sc) super.postStartHook() diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala similarity index 77% rename from yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala rename to yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala index 2fa24cc43325e..4ebf3af12b381 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala @@ -19,14 +19,18 @@ package org.apache.spark.scheduler.cluster import org.apache.hadoop.yarn.util.RackResolver +import org.apache.log4j.{Level, Logger} + import org.apache.spark._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.util.Utils -/** - * This scheduler launches executors through Yarn - by calling into Client to launch the Spark AM. - */ -private[spark] class YarnClientClusterScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { +private[spark] class YarnScheduler(sc: SparkContext) extends TaskSchedulerImpl(sc) { + + // RackResolver logs an INFO message whenever it resolves a rack, which is way too often. + if (Logger.getLogger(classOf[RackResolver]).getLevel == null) { + Logger.getLogger(classOf[RackResolver]).setLevel(Level.WARN) + } // By default, rack is unknown override def getRackForHost(hostPort: String): Option[String] = { diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala index aad50015b717f..2bb3dcffd61d9 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala @@ -28,8 +28,6 @@ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.mockito.Matchers._ import org.mockito.Mockito._ - - import org.scalatest.FunSuite import org.scalatest.Matchers @@ -89,7 +87,7 @@ class ClientSuite extends FunSuite with Matchers { Client.populateClasspath(args, conf, sparkConf, env) - val cp = env("CLASSPATH").split(File.pathSeparator) + val cp = env("CLASSPATH").split(":|;|") s"$SPARK,$USER,$ADDED".split(",").foreach({ entry => val uri = new URI(entry) if (Client.LOCAL_SCHEME.equals(uri.getScheme())) { @@ -98,8 +96,16 @@ class ClientSuite extends FunSuite with Matchers { cp should not contain (uri.getPath()) } }) - cp should contain (Environment.PWD.$()) - cp should contain (s"${Environment.PWD.$()}${File.separator}*") + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + cp should contain("{{PWD}}") + cp should contain(s"{{PWD}}${Path.SEPARATOR}*") + } else if (Utils.isWindows) { + cp should contain("%PWD%") + cp should contain(s"%PWD%${Path.SEPARATOR}*") + } else { + cp should contain(Environment.PWD.$()) + cp should contain(s"${Environment.PWD.$()}${File.separator}*") + } cp should not contain (Client.SPARK_JAR) cp should not contain (Client.APP_JAR) } @@ -223,7 +229,7 @@ class ClientSuite extends FunSuite with Matchers { def newEnv = MutableHashMap[String, String]() - def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;") + def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;|") def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index 2cc5abb3a890c..b5a2db8f6225c 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -20,12 +20,15 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.yarn.api.ApplicationConstants +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.conf.YarnConfiguration import org.scalatest.{FunSuite, Matchers} import org.apache.hadoop.yarn.api.records.ApplicationAccessType import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.util.Utils class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { @@ -148,4 +151,26 @@ class YarnSparkHadoopUtilSuite extends FunSuite with Matchers with Logging { } } + + test("test expandEnvironment result") { + val target = Environment.PWD + if (classOf[Environment].getMethods().exists(_.getName == "$$")) { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("{{" + target + "}}") + } else if (Utils.isWindows) { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("%" + target + "%") + } else { + YarnSparkHadoopUtil.expandEnvironment(target) should be ("$" + target) + } + + } + + test("test getClassPathSeparator result") { + if (classOf[ApplicationConstants].getFields().exists(_.getName == "CLASS_PATH_SEPARATOR")) { + YarnSparkHadoopUtil.getClassPathSeparator() should be ("") + } else if (Utils.isWindows) { + YarnSparkHadoopUtil.getClassPathSeparator() should be (";") + } else { + YarnSparkHadoopUtil.getClassPathSeparator() should be (":") + } + } }