diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
index 110bd0a9a0c41..55241d33cd3f0 100644
--- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
+++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala
@@ -80,7 +80,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
test("large number of iterations") {
// This tests whether jobs with a large number of iterations finish in a reasonable time,
// because non-memoized recursion in RDD or DAGScheduler used to cause them to hang
- failAfter(10 seconds) {
+ failAfter(30 seconds) {
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
@@ -101,7 +101,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo
sc = new SparkContext("local", "test")
val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0))))
val msgs = sc.parallelize(Array[(String, TestMessage)]())
- val numSupersteps = 50
+ val numSupersteps = 20
val result =
Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) {
(self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) =>
diff --git a/bin/pyspark b/bin/pyspark
index d0fa56f31913f..114cbbc3a8a8e 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -86,6 +86,10 @@ else
if [[ "$IPYTHON" = "1" ]]; then
exec ipython $IPYTHON_OPTS
else
- exec "$PYSPARK_PYTHON"
+ if [[ -n $SPARK_TESTING ]]; then
+ exec "$PYSPARK_PYTHON" -m doctest
+ else
+ exec "$PYSPARK_PYTHON"
+ fi
fi
fi
diff --git a/bin/run-example b/bin/run-example
index 7caab31daef39..e7a5fe3914fbd 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -51,7 +51,7 @@ if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then
EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS"
fi
-./bin/spark-submit \
+"$FWDIR"/bin/spark-submit \
--master $EXAMPLE_MASTER \
--class $EXAMPLE_CLASS \
"$SPARK_EXAMPLES_JAR" \
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 9155159cf6aeb..e7f75481939a8 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -83,11 +83,17 @@ class HashPartitioner(partitions: Int) extends Partitioner {
case _ =>
false
}
+
+ override def hashCode: Int = numPartitions
}
/**
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly
* equal ranges. The ranges are determined by sampling the content of the RDD passed in.
+ *
+ * Note that the actual number of partitions created by the RangePartitioner might not be the same
+ * as the `partitions` parameter, in the case where the number of sampled records is less than
+ * the value of `partitions`.
*/
class RangePartitioner[K : Ordering : ClassTag, V](
partitions: Int,
@@ -119,7 +125,7 @@ class RangePartitioner[K : Ordering : ClassTag, V](
}
}
- def numPartitions = partitions
+ def numPartitions = rangeBounds.length + 1
private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K]
@@ -155,4 +161,16 @@ class RangePartitioner[K : Ordering : ClassTag, V](
case _ =>
false
}
+
+ override def hashCode(): Int = {
+ val prime = 31
+ var result = 1
+ var i = 0
+ while (i < rangeBounds.length) {
+ result = prime * result + rangeBounds(i).hashCode
+ i += 1
+ }
+ result = prime * result + ascending.hashCode
+ result
+ }
}
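
Review note: the added `hashCode` overrides keep `equals`/`hashCode` consistent, so partitioners behave correctly when compared or placed in hash-based collections. A minimal sketch of the restored contract (assumes spark-core on the classpath; class names are from this patch):

```scala
import org.apache.spark.{HashPartitioner, Partitioner}

object PartitionerContractCheck {
  def main(args: Array[String]): Unit = {
    val a = new HashPartitioner(8)
    val b = new HashPartitioner(8)
    // equals already compared numPartitions; the added hashCode makes equal
    // partitioners hash identically, as the equals/hashCode contract requires.
    assert(a == b && a.hashCode == b.hashCode)

    // With the contract in place, partitioners work as keys in hash-based collections.
    val set = scala.collection.mutable.HashSet[Partitioner](a)
    assert(set.contains(b))
  }
}
```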
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index d941aea9d7eb2..d721aba709600 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -455,7 +455,7 @@ class SparkContext(config: SparkConf) extends Logging {
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
- minPartitions).map(pair => pair._2.toString)
+ minPartitions).map(pair => pair._2.toString).setName(path)
}
/**
@@ -496,7 +496,7 @@ class SparkContext(config: SparkConf) extends Logging {
classOf[String],
classOf[String],
updateConf,
- minPartitions)
+ minPartitions).setName(path)
}
/**
@@ -551,7 +551,7 @@ class SparkContext(config: SparkConf) extends Logging {
inputFormatClass,
keyClass,
valueClass,
- minPartitions)
+ minPartitions).setName(path)
}
/**
@@ -623,7 +623,7 @@ class SparkContext(config: SparkConf) extends Logging {
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
- new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf)
+ new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path)
}
/**
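
Review note: chaining `.setName(path)` works here because `RDD.setName` returns the RDD itself, so the named RDD is what `textFile` and friends hand back. A quick spark-shell sketch (the path is hypothetical):

```scala
// Assumes a SparkContext `sc` (e.g. in spark-shell); "/tmp/input.txt" is a made-up path.
val lines = sc.textFile("/tmp/input.txt")
println(lines.name)   // after this patch: prints "/tmp/input.txt"
lines.cache()         // once cached, the web UI's Storage tab lists the RDD under that name
```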
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
new file mode 100644
index 0000000000000..adaa1ef6cf9ff
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io._
+import scala.util.{Failure, Success, Try}
+import org.apache.spark.annotation.Experimental
+
+
+/**
+ * :: Experimental ::
+ * A trait for use with reading custom classes in PySpark. Implement this trait and add custom
+ * transformation code by overriding the convert method.
+ */
+@Experimental
+trait Converter[T, U] extends Serializable {
+ def convert(obj: T): U
+}
+
+private[python] object Converter extends Logging {
+
+ def getInstance(converterClass: Option[String]): Converter[Any, Any] = {
+ converterClass.map { cc =>
+ Try {
+ val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]]
+ logInfo(s"Loaded converter: $cc")
+ c
+ } match {
+ case Success(c) => c
+ case Failure(err) =>
+ logError(s"Failed to load converter: $cc")
+ throw err
+ }
+ }.getOrElse { new DefaultConverter }
+ }
+}
+
+/**
+ * A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects.
+ * Other objects are passed through without conversion.
+ */
+private[python] class DefaultConverter extends Converter[Any, Any] {
+
+ /**
+ * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
+ * object representation
+ */
+ private def convertWritable(writable: Writable): Any = {
+ import collection.JavaConversions._
+ writable match {
+ case iw: IntWritable => iw.get()
+ case dw: DoubleWritable => dw.get()
+ case lw: LongWritable => lw.get()
+ case fw: FloatWritable => fw.get()
+ case t: Text => t.toString
+ case bw: BooleanWritable => bw.get()
+ case byw: BytesWritable => byw.getBytes
+ case n: NullWritable => null
+ case aw: ArrayWritable => aw.get().map(convertWritable(_))
+ case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) =>
+ (convertWritable(k), convertWritable(v))
+ }.toMap)
+ case other => other
+ }
+ }
+
+ def convert(obj: Any): Any = {
+ obj match {
+ case writable: Writable =>
+ convertWritable(writable)
+ case _ =>
+ obj
+ }
+ }
+}
+
+/** Utilities for working with Python objects <-> Hadoop-related objects */
+private[python] object PythonHadoopUtil {
+
+ /**
+ * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]]
+ */
+ def mapToConf(map: java.util.Map[String, String]): Configuration = {
+ import collection.JavaConversions._
+ val conf = new Configuration()
+ map.foreach{ case (k, v) => conf.set(k, v) }
+ conf
+ }
+
+ /**
+ * Merges two configurations, returns a copy of left with keys from right overwriting
+ * any matching keys in left
+ */
+ def mergeConfs(left: Configuration, right: Configuration): Configuration = {
+ import collection.JavaConversions._
+ val copy = new Configuration(left)
+ right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue))
+ copy
+ }
+
+ /**
+ * Converts an RDD of key-value pairs, where key and/or value could be instances of
+ * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)]
+ */
+ def convertRDD[K, V](rdd: RDD[(K, V)],
+ keyConverter: Converter[Any, Any],
+ valueConverter: Converter[Any, Any]): RDD[(Any, Any)] = {
+ rdd.map { case (k, v) => (keyConverter.convert(k), valueConverter.convert(v)) }
+ }
+
+}
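
For reviewers: a user-facing `Converter` would look roughly like the sketch below (hypothetical class, not part of this patch); its fully qualified name is what PySpark passes as `keyConverterClass`/`valueConverterClass`.

```scala
import org.apache.hadoop.io.LongWritable
import org.apache.spark.api.python.Converter

// Hypothetical example: represent LongWritable keys as Strings on the Python side.
class LongWritableToStringConverter extends Converter[Any, String] {
  override def convert(obj: Any): String = obj match {
    case lw: LongWritable => lw.get().toString
    case null             => ""
    case other            => other.toString
  }
}
```

From Python it would be referenced by name, e.g. `keyConverter="com.example.LongWritableToStringConverter"` (hypothetical package).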
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
index 95bec5030bfdd..e230d222b8604 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala
@@ -50,4 +50,6 @@ private[spark] class PythonPartitioner(
case _ =>
false
}
+
+ override def hashCode: Int = 31 * numPartitions + pyPartitionFunctionId.hashCode
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index d1df99300c5b1..f6570d335757a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -28,6 +28,9 @@ import scala.util.Try
import net.razorvine.pickle.{Pickler, Unpickler}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapred.{InputFormat, JobConf}
+import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
import org.apache.spark._
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
@@ -266,7 +269,7 @@ private object SpecialLengths {
val TIMING_DATA = -3
}
-private[spark] object PythonRDD {
+private[spark] object PythonRDD extends Logging {
val UTF8 = Charset.forName("UTF-8")
/**
@@ -346,6 +349,180 @@ private[spark] object PythonRDD {
}
}
+ /**
+ * Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]],
+ * key and value class.
+ * A key and/or value converter class can optionally be passed in
+ * (see [[org.apache.spark.api.python.Converter]])
+ */
+ def sequenceFile[K, V](
+ sc: JavaSparkContext,
+ path: String,
+ keyClassMaybeNull: String,
+ valueClassMaybeNull: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ minSplits: Int) = {
+ val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
+ val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text")
+ implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
+ implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
+ val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
+ val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
+
+ val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
+ val keyConverter = Converter.getInstance(Option(keyConverterClass))
+ val valueConverter = Converter.getInstance(Option(valueConverterClass))
+ val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
+ JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ }
+
+ /**
+ * Create an RDD from a file path, using an arbitrary [[org.apache.hadoop.mapreduce.InputFormat]],
+ * key and value class.
+ * A key and/or value converter class can optionally be passed in
+ * (see [[org.apache.spark.api.python.Converter]])
+ */
+ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](
+ sc: JavaSparkContext,
+ path: String,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String]) = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ val baseConf = sc.hadoopConfiguration()
+ val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf)
+ val rdd =
+ newAPIHadoopRDDFromClassNames[K, V, F](sc,
+ Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
+ val keyConverter = Converter.getInstance(Option(keyConverterClass))
+ val valueConverter = Converter.getInstance(Option(valueConverterClass))
+ val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
+ JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ }
+
+ /**
+ * Create an RDD from a [[org.apache.hadoop.conf.Configuration]] converted from a map that is
+ * passed in from Python, using an arbitrary [[org.apache.hadoop.mapreduce.InputFormat]],
+ * key and value class.
+ * A key and/or value converter class can optionally be passed in
+ * (see [[org.apache.spark.api.python.Converter]])
+ */
+ def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]](
+ sc: JavaSparkContext,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String]) = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ val rdd =
+ newAPIHadoopRDDFromClassNames[K, V, F](sc,
+ None, inputFormatClass, keyClass, valueClass, conf)
+ val keyConverter = Converter.getInstance(Option(keyConverterClass))
+ val valueConverter = Converter.getInstance(Option(valueConverterClass))
+ val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
+ JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ }
+
+ private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]](
+ sc: JavaSparkContext,
+ path: Option[String] = None,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ conf: Configuration) = {
+ implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
+ implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
+ implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]]
+ val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
+ val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
+ val fc = fcm.runtimeClass.asInstanceOf[Class[F]]
+ val rdd = if (path.isDefined) {
+ sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf)
+ } else {
+ sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc)
+ }
+ rdd
+ }
+
+ /**
+ * Create an RDD from a file path, using an arbitrary [[org.apache.hadoop.mapred.InputFormat]],
+ * key and value class.
+ * A key and/or value converter class can optionally be passed in
+ * (see [[org.apache.spark.api.python.Converter]])
+ */
+ def hadoopFile[K, V, F <: InputFormat[K, V]](
+ sc: JavaSparkContext,
+ path: String,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String]) = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ val baseConf = sc.hadoopConfiguration()
+ val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf)
+ val rdd =
+ hadoopRDDFromClassNames[K, V, F](sc,
+ Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
+ val keyConverter = Converter.getInstance(Option(keyConverterClass))
+ val valueConverter = Converter.getInstance(Option(valueConverterClass))
+ val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
+ JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ }
+
+ /**
+ * Create an RDD from a [[org.apache.hadoop.conf.Configuration]] converted from a map
+ * that is passed in from Python, using an arbitrary [[org.apache.hadoop.mapred.InputFormat]],
+ * key and value class
+ * A key and/or value converter class can optionally be passed in
+ * (see [[org.apache.spark.api.python.Converter]])
+ */
+ def hadoopRDD[K, V, F <: InputFormat[K, V]](
+ sc: JavaSparkContext,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ keyConverterClass: String,
+ valueConverterClass: String,
+ confAsMap: java.util.HashMap[String, String]) = {
+ val conf = PythonHadoopUtil.mapToConf(confAsMap)
+ val rdd =
+ hadoopRDDFromClassNames[K, V, F](sc,
+ None, inputFormatClass, keyClass, valueClass, conf)
+ val keyConverter = Converter.getInstance(Option(keyConverterClass))
+ val valueConverter = Converter.getInstance(Option(valueConverterClass))
+ val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter)
+ JavaRDD.fromRDD(SerDeUtil.rddToPython(converted))
+ }
+
+ private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]](
+ sc: JavaSparkContext,
+ path: Option[String] = None,
+ inputFormatClass: String,
+ keyClass: String,
+ valueClass: String,
+ conf: Configuration) = {
+ implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]]
+ implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]]
+ implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]]
+ val kc = kcm.runtimeClass.asInstanceOf[Class[K]]
+ val vc = vcm.runtimeClass.asInstanceOf[Class[V]]
+ val fc = fcm.runtimeClass.asInstanceOf[Class[F]]
+ val rdd = if (path.isDefined) {
+ sc.sc.hadoopFile(path.get, fc, kc, vc)
+ } else {
+ sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc)
+ }
+ rdd
+ }
+
def writeUTF(str: String, dataOut: DataOutputStream) {
val bytes = str.getBytes(UTF8)
dataOut.writeInt(bytes.length)
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
new file mode 100644
index 0000000000000..9a012e7254901
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import scala.util.Try
+import org.apache.spark.rdd.RDD
+import org.apache.spark.Logging
+import scala.util.Success
+import scala.util.Failure
+import net.razorvine.pickle.Pickler
+
+
+/** Utilities for serialization / deserialization between Python and Java, using Pickle. */
+private[python] object SerDeUtil extends Logging {
+
+ private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
+ val pickle = new Pickler
+ val kt = Try {
+ pickle.dumps(t._1)
+ }
+ val vt = Try {
+ pickle.dumps(t._2)
+ }
+ (kt, vt) match {
+ case (Failure(kf), Failure(vf)) =>
+ logWarning(s"""
+ |Failed to pickle Java object as key: ${t._1.getClass.getSimpleName}, falling back
+ |to 'toString'. Error: ${kf.getMessage}""".stripMargin)
+ logWarning(s"""
+ |Failed to pickle Java object as value: ${t._2.getClass.getSimpleName}, falling back
+ |to 'toString'. Error: ${vf.getMessage}""".stripMargin)
+ (true, true)
+ case (Failure(kf), _) =>
+ logWarning(s"""
+ |Failed to pickle Java object as key: ${t._1.getClass.getSimpleName}, falling back
+ |to 'toString'. Error: ${kf.getMessage}""".stripMargin)
+ (true, false)
+ case (_, Failure(vf)) =>
+ logWarning(s"""
+ |Failed to pickle Java object as value: ${t._2.getClass.getSimpleName}, falling back
+ |to 'toString'. Error: ${vf.getMessage}""".stripMargin)
+ (false, true)
+ case _ =>
+ (false, false)
+ }
+ }
+
+ /**
+ * Convert an RDD of key-value pairs to an RDD of serialized Python objects, that is usable
+ * by PySpark. By default, if serialization fails, toString is called and the string
+ * representation is serialized
+ */
+ def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = {
+ val (keyFailed, valueFailed) = checkPickle(rdd.first())
+ rdd.mapPartitions { iter =>
+ val pickle = new Pickler
+ iter.map { case (k, v) =>
+ if (keyFailed && valueFailed) {
+ pickle.dumps(Array(k.toString, v.toString))
+ } else if (keyFailed) {
+ pickle.dumps(Array(k.toString, v))
+ } else if (!keyFailed && valueFailed) {
+ pickle.dumps(Array(k, v.toString))
+ } else {
+ pickle.dumps(Array(k, v))
+ }
+ }
+ }
+ }
+
+}
+
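
Note that the fallback decision above is made once, against `rdd.first()`, and then applied uniformly to every record. A per-record variant of the same rule, shown only for illustration (`Pickler.dumps` is the same Pyrolite call used above):

```scala
import net.razorvine.pickle.Pickler
import scala.util.Try

// Illustration only: pickle the object if Pyrolite accepts it, otherwise pickle its
// toString form. SerDeUtil.rddToPython makes this choice once (on the first record)
// rather than per element, so the failed attempt is not repeated for every record.
object PickleWithFallback {
  def dumps(pickle: Pickler, value: Any): Array[Byte] =
    Try(pickle.dumps(value)).getOrElse(pickle.dumps(value.toString))
}
```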
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
new file mode 100644
index 0000000000000..f0e3fb9aff5a0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.apache.spark.SparkContext
+import org.apache.hadoop.io._
+import scala.Array
+import java.io.{DataOutput, DataInput}
+import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat
+import org.apache.spark.api.java.JavaSparkContext
+
+/**
+ * A custom Writable class used to test serialization of arbitrary classes on the Scala
+ * side; instances are written out and later deserialized in Python via Pyrolite.
+ * @param str a String field
+ * @param int an Int field
+ * @param double a Double field
+ */
+case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable {
+ def this() = this("", 0, 0.0)
+
+ def getStr = str
+ def setStr(str: String) { this.str = str }
+ def getInt = int
+ def setInt(int: Int) { this.int = int }
+ def getDouble = double
+ def setDouble(double: Double) { this.double = double }
+
+ def write(out: DataOutput) = {
+ out.writeUTF(str)
+ out.writeInt(int)
+ out.writeDouble(double)
+ }
+
+ def readFields(in: DataInput) = {
+ str = in.readUTF()
+ int = in.readInt()
+ double = in.readDouble()
+ }
+}
+
+class TestConverter extends Converter[Any, Any] {
+ import collection.JavaConversions._
+ override def convert(obj: Any) = {
+ val m = obj.asInstanceOf[MapWritable]
+ seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq)
+ }
+}
+
+/**
+ * This object contains methods to generate SequenceFile test data and write it to a
+ * given directory (probably a temp directory)
+ */
+object WriteInputFormatTestDataGenerator {
+ import SparkContext._
+
+ def main(args: Array[String]) {
+ val path = args(0)
+ val sc = new JavaSparkContext("local[4]", "test-writables")
+ generateData(path, sc)
+ }
+
+ def generateData(path: String, jsc: JavaSparkContext) {
+ val sc = jsc.sc
+
+ val basePath = s"$path/sftestdata/"
+ val textPath = s"$basePath/sftext/"
+ val intPath = s"$basePath/sfint/"
+ val doublePath = s"$basePath/sfdouble/"
+ val arrPath = s"$basePath/sfarray/"
+ val mapPath = s"$basePath/sfmap/"
+ val classPath = s"$basePath/sfclass/"
+ val bytesPath = s"$basePath/sfbytes/"
+ val boolPath = s"$basePath/sfbool/"
+ val nullPath = s"$basePath/sfnull/"
+
+ /*
+ * Create test data for IntWritable, DoubleWritable, Text, BytesWritable,
+ * BooleanWritable and NullWritable
+ */
+ val intKeys = Seq((1, "aa"), (2, "bb"), (2, "aa"), (3, "cc"), (2, "bb"), (1, "aa"))
+ sc.parallelize(intKeys).saveAsSequenceFile(intPath)
+ sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath)
+ sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath)
+ sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath)
+ val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false))
+ sc.parallelize(bools).saveAsSequenceFile(boolPath)
+ sc.parallelize(intKeys).map{ case (k, v) =>
+ (new IntWritable(k), NullWritable.get())
+ }.saveAsSequenceFile(nullPath)
+
+ // Create test data for ArrayWritable
+ val data = Seq(
+ (1, Array(1.0, 2.0, 3.0)),
+ (2, Array(3.0, 4.0, 5.0)),
+ (3, Array(4.0, 5.0, 6.0))
+ )
+ sc.parallelize(data, numSlices = 2)
+ .map{ case (k, v) =>
+ (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_))))
+ }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath)
+
+ // Create test data for MapWritable, with keys DoubleWritable and values Text
+ val mapData = Seq(
+ (1, Map(2.0 -> "aa")),
+ (2, Map(3.0 -> "bb")),
+ (2, Map(1.0 -> "cc")),
+ (3, Map(2.0 -> "dd")),
+ (2, Map(1.0 -> "aa")),
+ (1, Map(3.0 -> "bb"))
+ )
+ sc.parallelize(mapData, numSlices = 2).map{ case (i, m) =>
+ val mw = new MapWritable()
+ val k = m.keys.head
+ val v = m.values.head
+ mw.put(new DoubleWritable(k), new Text(v))
+ (new IntWritable(i), mw)
+ }.saveAsSequenceFile(mapPath)
+
+ // Create test data for arbitrary custom writable TestWritable
+ val testClass = Seq(
+ ("1", TestWritable("test1", 123, 54.0)),
+ ("2", TestWritable("test2", 456, 8762.3)),
+ ("1", TestWritable("test3", 123, 423.1)),
+ ("3", TestWritable("test56", 456, 423.5)),
+ ("2", TestWritable("test2", 123, 5435.2))
+ )
+ val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
+ rdd.saveAsNewAPIHadoopFile(classPath,
+ classOf[Text], classOf[TestWritable],
+ classOf[SequenceFileOutputFormat[Text, TestWritable]])
+ }
+
+
+}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index 153eee3bc5889..f1032ea8dbada 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -360,6 +360,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) {
|
| --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G).
|
+ | --help, -h Show this help message and exit
+ | --verbose, -v Print additional debug output
+ |
| Spark standalone with cluster deploy mode only:
| --driver-cores NUM Cores for driver (Default: 1).
| --supervise If given, restarts the driver on failure.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index a41286d3e4a00..9cd79d262ea53 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -1024,7 +1024,7 @@ private[spark] class BlockManager(
if (blockId.isShuffle) {
// Reducer may need to read many local shuffle blocks and will wrap them into Iterators
// at the beginning. The wrapping will cost some memory (compression instance
- // initialization, etc.). Reducer read shuffle blocks one by one so we could do the
+ // initialization, etc.). Reducer reads shuffle blocks one by one so we could do the
// wrapping lazily to save memory.
class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] {
lazy val proxy = f
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index a43314f48112f..1b104253d545d 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -168,7 +168,7 @@ private[spark] object UIUtils extends Logging {
-
+
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index c92b6dc96c8eb..6f1fd25764544 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -141,6 +141,22 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
assert(sched.finishedManagers.contains(manager))
}
+ test("skip unsatisfiable locality levels") {
+ sc = new SparkContext("local", "test")
+ val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execC", "host2"))
+ val taskSet = FakeTask.createTaskSet(1, Seq(TaskLocation("host1", "execB")))
+ val clock = new FakeClock
+ val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
+
+ // An executor that is not NODE_LOCAL should be rejected.
+ assert(manager.resourceOffer("execC", "host2", ANY) === None)
+
+ // Because there are no alive PROCESS_LOCAL executors, the base locality level should be
+ // NODE_LOCAL. So, we should schedule the task on this offered NODE_LOCAL executor before
+ // any of the locality wait timers expire.
+ assert(manager.resourceOffer("execA", "host1", ANY).get.index === 0)
+ }
+
test("basic delay scheduling") {
sc = new SparkContext("local", "test")
val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
diff --git a/dev/merge_spark_pr.py b/dev/merge_spark_pr.py
index e3ac32ef1a12e..ffb70096d6014 100755
--- a/dev/merge_spark_pr.py
+++ b/dev/merge_spark_pr.py
@@ -128,8 +128,9 @@ def merge_pr(pr_num, target_ref):
merge_message_flags = []
- for p in [title, body]:
- merge_message_flags += ["-m", p]
+ merge_message_flags += ["-m", title]
+ if body != None:
+ merge_message_flags += ["-m", body]
authors = "\n".join(["Author: %s" % a for a in distinct_authors])
diff --git a/dev/run-tests b/dev/run-tests
index 93d6692f83ca8..c82a47ebb618b 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -73,9 +73,6 @@ fi
echo "========================================================================="
echo "Running PySpark tests"
echo "========================================================================="
-if [ -z "$PYSPARK_PYTHON" ]; then
- export PYSPARK_PYTHON=/usr/local/bin/python2.7
-fi
./python/run-tests
echo "========================================================================="
diff --git a/docs/programming-guide.md b/docs/programming-guide.md
index 7d77e640d0e4b..7989e02dfb732 100644
--- a/docs/programming-guide.md
+++ b/docs/programming-guide.md
@@ -359,8 +359,7 @@ Apart from text files, Spark's Java API also supports several other data formats
-PySpark can create distributed datasets from any file system supported by Hadoop, including your local file system, HDFS, KFS, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc.
-The current API is limited to text files, but support for binary Hadoop InputFormats is expected in future versions.
+PySpark can create distributed datasets from any storage source supported by Hadoop, including your local file system, HDFS, Cassandra, HBase, [Amazon S3](http://wiki.apache.org/hadoop/AmazonS3), etc. Spark supports text files, [SequenceFiles](http://hadoop.apache.org/common/docs/current/api/org/apache/hadoop/mapred/SequenceFileInputFormat.html), and any other Hadoop [InputFormat](http://hadoop.apache.org/docs/stable/api/org/apache/hadoop/mapred/InputFormat.html).
Text file RDDs can be created using `SparkContext`'s `textFile` method. This method takes an URI for the file (either a local path on the machine, or a `hdfs://`, `s3n://`, etc URI) and reads it as a collection of lines. Here is an example invocation:
@@ -378,11 +377,90 @@ Some notes on reading files with Spark:
* The `textFile` method also takes an optional second argument for controlling the number of slices of the file. By default, Spark creates one slice for each block of the file (blocks being 64MB by default in HDFS), but you can also ask for a higher number of slices by passing a larger value. Note that you cannot have fewer slices than blocks.
-Apart reading files as a collection of lines,
+Apart from reading files as a collection of lines,
`SparkContext.wholeTextFiles` lets you read a directory containing multiple small text files, and returns each of them as (filename, content) pairs. This is in contrast with `textFile`, which would return one record per line in each file.
-
+### SequenceFile and Hadoop InputFormats
+
+In addition to reading text files, PySpark supports reading ```SequenceFile```
+and any arbitrary ```InputFormat```.
+
+**Note** this feature is currently marked ```Experimental``` and is intended for advanced users. It may be replaced in the future by read/write support based on SparkSQL, in which case SparkSQL will be the preferred approach.
+
+#### Writable Support
+
+PySpark SequenceFile support loads an RDD within Java, and pickles the resulting Java objects using
+[Pyrolite](https://github.com/irmen/Pyrolite/). The following Writables are automatically converted:
+
+
+| Writable Type | Python Type |
+| ------------- | ----------- |
+| Text | unicode str |
+| IntWritable | int |
+| FloatWritable | float |
+| DoubleWritable | float |
+| BooleanWritable | bool |
+| BytesWritable | bytearray |
+| NullWritable | None |
+| ArrayWritable | list of primitives, or tuple of objects |
+| MapWritable | dict |
+| Custom Class conforming to Java Bean conventions | dict of public properties (via JavaBean getters and setters) + `__class__` for the class type |
+
+
+#### Loading SequenceFiles
+Similarly to text files, SequenceFiles can be loaded by specifying the path. The key and value
+classes can be specified, but for standard Writables this is not required.
+
+{% highlight python %}
+>>> rdd = sc.sequenceFile("path/to/sequencefile/of/doubles")
+>>> rdd.collect() # this example has DoubleWritable keys and Text values
+[(1.0, u'aa'),
+ (2.0, u'bb'),
+ (2.0, u'aa'),
+ (3.0, u'cc'),
+ (2.0, u'bb'),
+ (1.0, u'aa')]
+{% endhighlight %}
+
+#### Loading Other Hadoop InputFormats
+
+PySpark can also read any Hadoop InputFormat, for both 'new' and 'old' Hadoop APIs. If required,
+a Hadoop configuration can be passed in as a Python dict. Here is an example using the
+Elasticsearch ESInputFormat:
+
+{% highlight python %}
+$ SPARK_CLASSPATH=/path/to/elasticsearch-hadoop.jar ./bin/pyspark
+>>> conf = {"es.resource" : "index/type"} # assume Elasticsearch is running on localhost defaults
+>>> rdd = sc.newAPIHadoopRDD("org.elasticsearch.hadoop.mr.EsInputFormat",\
+ "org.apache.hadoop.io.NullWritable", "org.elasticsearch.hadoop.mr.LinkedMapWritable", conf=conf)
+>>> rdd.first() # the result is a MapWritable that is converted to a Python dict
+(u'Elasticsearch ID',
+ {u'field1': True,
+ u'field2': u'Some Text',
+ u'field3': 12345})
+{% endhighlight %}
+
+This approach should work well for cases where the InputFormat simply depends on a Hadoop
+configuration and/or input path, and the key and value classes can easily be converted
+according to the above table.
+
+If you have custom serialized binary data (such as loading data from Cassandra / HBase) or custom
+classes that don't conform to the JavaBean requirements, then you will first need to
+transform that data on the Scala/Java side to something which can be handled by Pyrolite's pickler.
+A [Converter](api/scala/index.html#org.apache.spark.api.python.Converter) trait is provided
+for this. Simply extend this trait and implement your transformation code in the ```convert```
+method. Remember to ensure that this class, along with any dependencies required to access
+your ```InputFormat```, is packaged into your Spark job jar and included on the PySpark
+classpath.
+
+See the [Python examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python) and
+the [Converter examples]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/pythonconverters)
+for examples of using HBase and Cassandra ```InputFormat```.
+
+Future support for writing data out as ```SequenceFileOutputFormat``` and other ```OutputFormats```
+is forthcoming.
+
+
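
Review note on the last table row: a class following JavaBean conventions (such as `TestWritable` in this patch) arrives in Python as a dict of its bean properties plus `__class__`. A hypothetical Scala bean that would be handled this way:

```scala
// Hypothetical bean-style class; per the table above, Pyrolite would pickle an instance
// as something like {'name': ..., 'count': ..., '__class__': '...ExampleBean'} in Python.
class ExampleBean(var name: String, var count: Int) {
  def this() = this("", 0)
  def getName: String = name
  def setName(n: String): Unit = { name = n }
  def getCount: Int = count
  def setCount(c: Int): Unit = { count = c }
}
```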
diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py
new file mode 100644
index 0000000000000..39fa6b0d22ef5
--- /dev/null
+++ b/examples/src/main/python/cassandra_inputformat.py
@@ -0,0 +1,79 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+from pyspark import SparkContext
+
+"""
+Create data in Cassandra first
+(following: https://wiki.apache.org/cassandra/GettingStarted)
+
+cqlsh> CREATE KEYSPACE test
+ ... WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 };
+cqlsh> use test;
+cqlsh:test> CREATE TABLE users (
+ ... user_id int PRIMARY KEY,
+ ... fname text,
+ ... lname text
+ ... );
+cqlsh:test> INSERT INTO users (user_id, fname, lname)
+ ... VALUES (1745, 'john', 'smith');
+cqlsh:test> INSERT INTO users (user_id, fname, lname)
+ ... VALUES (1744, 'john', 'doe');
+cqlsh:test> INSERT INTO users (user_id, fname, lname)
+ ... VALUES (1746, 'john', 'smith');
+cqlsh:test> SELECT * FROM users;
+
+ user_id | fname | lname
+---------+-------+-------
+ 1745 | john | smith
+ 1744 | john | doe
+ 1746 | john | smith
+"""
+if __name__ == "__main__":
+ if len(sys.argv) != 4:
+ print >> sys.stderr, """
+ Usage: cassandra_inputformat <host> <keyspace> <cf>
+
+ Run with example jar:
+ ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_inputformat.py <host> <keyspace> <cf>
+ Assumes you have some data in Cassandra already, running on <host>, in <keyspace> and <cf>
+ """
+ exit(-1)
+
+ host = sys.argv[1]
+ keyspace = sys.argv[2]
+ cf = sys.argv[3]
+ sc = SparkContext(appName="CassandraInputFormat")
+
+ conf = {"cassandra.input.thrift.address":host,
+ "cassandra.input.thrift.port":"9160",
+ "cassandra.input.keyspace":keyspace,
+ "cassandra.input.columnfamily":cf,
+ "cassandra.input.partitioner.class":"Murmur3Partitioner",
+ "cassandra.input.page.row.size":"3"}
+ cass_rdd = sc.newAPIHadoopRDD(
+ "org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat",
+ "java.util.Map",
+ "java.util.Map",
+ keyConverter="org.apache.spark.examples.pythonconverters.CassandraCQLKeyConverter",
+ valueConverter="org.apache.spark.examples.pythonconverters.CassandraCQLValueConverter",
+ conf=conf)
+ output = cass_rdd.collect()
+ for (k, v) in output:
+ print (k, v)
diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py
new file mode 100644
index 0000000000000..3289d9880a0f5
--- /dev/null
+++ b/examples/src/main/python/hbase_inputformat.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import sys
+
+from pyspark import SparkContext
+
+"""
+Create test data in HBase first:
+
+hbase(main):016:0> create 'test', 'f1'
+0 row(s) in 1.0430 seconds
+
+hbase(main):017:0> put 'test', 'row1', 'f1', 'value1'
+0 row(s) in 0.0130 seconds
+
+hbase(main):018:0> put 'test', 'row2', 'f1', 'value2'
+0 row(s) in 0.0030 seconds
+
+hbase(main):019:0> put 'test', 'row3', 'f1', 'value3'
+0 row(s) in 0.0050 seconds
+
+hbase(main):020:0> put 'test', 'row4', 'f1', 'value4'
+0 row(s) in 0.0110 seconds
+
+hbase(main):021:0> scan 'test'
+ROW COLUMN+CELL
+ row1 column=f1:, timestamp=1401883411986, value=value1
+ row2 column=f1:, timestamp=1401883415212, value=value2
+ row3 column=f1:, timestamp=1401883417858, value=value3
+ row4 column=f1:, timestamp=1401883420805, value=value4
+4 row(s) in 0.0240 seconds
+"""
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print >> sys.stderr, """
+ Usage: hbase_inputformat <host> <table>
+
+ Run with example jar:
+ ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_inputformat.py <host> <table>
+ Assumes you have some data in HBase already, running on <host>, in <table>
+ """
+ exit(-1)
+
+ host = sys.argv[1]
+ table = sys.argv[2]
+ sc = SparkContext(appName="HBaseInputFormat")
+
+ conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table}
+ hbase_rdd = sc.newAPIHadoopRDD(
+ "org.apache.hadoop.hbase.mapreduce.TableInputFormat",
+ "org.apache.hadoop.hbase.io.ImmutableBytesWritable",
+ "org.apache.hadoop.hbase.client.Result",
+ valueConverter="org.apache.spark.examples.pythonconverters.HBaseConverter",
+ conf=conf)
+ output = hbase_rdd.collect()
+ for (k, v) in output:
+ print (k, v)
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
index 9a00701f985f0..71f53af68f4d3 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
@@ -33,6 +33,7 @@ import org.apache.hadoop.mapreduce.Job
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.SparkContext._
+
/*
Need to create following keyspace and column family in cassandra before running this example
Start CQL shell using ./bin/cqlsh and execute following commands
diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
index a8c338480e6e2..4893b017ed819 100644
--- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala
@@ -22,7 +22,7 @@ import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor}
import org.apache.hadoop.hbase.mapreduce.TableInputFormat
import org.apache.spark._
-import org.apache.spark.rdd.NewHadoopRDD
+
object HBaseTest {
def main(args: Array[String]) {
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
index b97cb8fb02823..e06f4dcd54442 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala
@@ -124,4 +124,6 @@ class CustomPartitioner(partitions: Int) extends Partitioner {
c.numPartitions == numPartitions
case _ => false
}
+
+ override def hashCode: Int = numPartitions
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala
new file mode 100644
index 0000000000000..29a65c7a5f295
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.pythonconverters
+
+import org.apache.spark.api.python.Converter
+import java.nio.ByteBuffer
+import org.apache.cassandra.utils.ByteBufferUtil
+import collection.JavaConversions.{mapAsJavaMap, mapAsScalaMap}
+
+
+/**
+ * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra
+ * output to a Map[String, Int]
+ */
+class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] {
+ override def convert(obj: Any): java.util.Map[String, Int] = {
+ val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]]
+ mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb)))
+ }
+}
+
+/**
+ * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra
+ * output to a Map[String, String]
+ */
+class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, String]] {
+ override def convert(obj: Any): java.util.Map[String, String] = {
+ val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]]
+ mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb)))
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala
similarity index 54%
rename from sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala
rename to examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala
index a0d29100f505a..42ae960bd64a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala
@@ -15,31 +15,19 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.examples.pythonconverters
-private[sql] object DebugQuery {
- def apply(plan: SparkPlan): SparkPlan = {
- val visited = new collection.mutable.HashSet[Long]()
- plan transform {
- case s: SparkPlan if !visited.contains(s.id) =>
- visited += s.id
- DebugNode(s)
- }
- }
-}
+import org.apache.spark.api.python.Converter
+import org.apache.hadoop.hbase.client.Result
+import org.apache.hadoop.hbase.util.Bytes
-private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode {
- def references = Set.empty
- def output = child.output
- def execute() = {
- val childRdd = child.execute()
- println(
- s"""
- |=========================
- |${child.simpleString}
- |=========================
- """.stripMargin)
- childRdd.foreach(println(_))
- childRdd
+/**
+ * Implementation of [[org.apache.spark.api.python.Converter]] that converts a HBase Result
+ * to a String
+ */
+class HBaseConverter extends Converter[Any, String] {
+ override def convert(obj: Any): String = {
+ val result = obj.asInstanceOf[Result]
+ Bytes.toStringBinary(result.value())
}
}
diff --git a/pom.xml b/pom.xml
index 891468b21bfff..0d46bb4114f73 100644
--- a/pom.xml
+++ b/pom.xml
@@ -209,14 +209,14 @@
spring-releases
Spring Release Repository
- http://repo.spring.io/libs-release
+ http://repo.spring.io/libs-release
true
false
-
+
@@ -987,11 +987,15 @@
avro
+
+ 0.23.10
+
hadoop-2.2
+ 2.2.0
2.5.0
@@ -999,6 +1003,7 @@
hadoop-2.3
+ 2.3.0
2.5.0
0.9.0
@@ -1007,6 +1012,7 @@
hadoop-2.4
+ 2.4.0
2.5.0
0.9.0
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 211918f5a05ec..062bec2381a8f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -342,6 +342,143 @@ def wholeTextFiles(self, path, minPartitions=None):
return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
PairDeserializer(UTF8Deserializer(), UTF8Deserializer()))
+ def _dictToJavaMap(self, d):
+ jm = self._jvm.java.util.HashMap()
+ if not d:
+ d = {}
+ for k, v in d.iteritems():
+ jm[k] = v
+ return jm
+
+ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
+ valueConverter=None, minSplits=None):
+ """
+ Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
+ a local file system (available on all nodes), or any Hadoop-supported file system URI.
+ The mechanism is as follows:
+ 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key
+ and value Writable classes
+ 2. Serialization is attempted via Pyrolite pickling
+ 3. If this fails, the fallback is to call 'toString' on each key and value
+ 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side
+
+ @param path: path to sequencefile
+ @param keyClass: fully qualified classname of key Writable class
+ (e.g. "org.apache.hadoop.io.Text")
+ @param valueClass: fully qualified classname of value Writable class
+ (e.g. "org.apache.hadoop.io.LongWritable")
+ @param keyConverter:
+ @param valueConverter:
+ @param minSplits: minimum splits in dataset
+ (default min(2, sc.defaultParallelism))
+ """
+ minSplits = minSplits or min(self.defaultParallelism, 2)
+ jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
+ keyConverter, valueConverter, minSplits)
+ return RDD(jrdd, self, PickleSerializer())
+
+ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
+ valueConverter=None, conf=None):
+ """
+ Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
+ a local file system (available on all nodes), or any Hadoop-supported file system URI.
+ The mechanism is the same as for sc.sequenceFile.
+
+ A Hadoop configuration can be passed in as a Python dict. This will be converted into a
+ Configuration in Java
+
+ @param path: path to Hadoop file
+ @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
+ @param keyClass: fully qualified classname of key Writable class
+ (e.g. "org.apache.hadoop.io.Text")
+ @param valueClass: fully qualified classname of value Writable class
+ (e.g. "org.apache.hadoop.io.LongWritable")
+ @param keyConverter: (None by default)
+ @param valueConverter: (None by default)
+ @param conf: Hadoop configuration, passed in as a dict
+ (None by default)
+ """
+ jconf = self._dictToJavaMap(conf)
+ jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
+ valueClass, keyConverter, valueConverter, jconf)
+ return RDD(jrdd, self, PickleSerializer())
+
+ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
+ valueConverter=None, conf=None):
+ """
+ Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
+ Hadoop configuration, which is passed in as a Python dict.
+ This will be converted into a Configuration in Java.
+ The mechanism is the same as for sc.sequenceFile.
+
+ @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat")
+ @param keyClass: fully qualified classname of key Writable class
+ (e.g. "org.apache.hadoop.io.Text")
+ @param valueClass: fully qualified classname of value Writable class
+ (e.g. "org.apache.hadoop.io.LongWritable")
+ @param keyConverter: (None by default)
+ @param valueConverter: (None by default)
+ @param conf: Hadoop configuration, passed in as a dict
+ (None by default)
+ """
+ jconf = self._dictToJavaMap(conf)
+ jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
+ valueClass, keyConverter, valueConverter, jconf)
+ return RDD(jrdd, self, PickleSerializer())
+
+ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
+ valueConverter=None, conf=None):
+ """
+ Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
+ a local file system (available on all nodes), or any Hadoop-supported file system URI.
+ The mechanism is the same as for sc.sequenceFile.
+
+ A Hadoop configuration can be passed in as a Python dict. This will be converted into a
+ Configuration in Java.
+
+ @param path: path to Hadoop file
+ @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ (e.g. "org.apache.hadoop.mapred.TextInputFormat")
+ @param keyClass: fully qualified classname of key Writable class
+ (e.g. "org.apache.hadoop.io.Text")
+ @param valueClass: fully qualified classname of value Writable class
+ (e.g. "org.apache.hadoop.io.LongWritable")
+ @param keyConverter: (None by default)
+ @param valueConverter: (None by default)
+ @param conf: Hadoop configuration, passed in as a dict
+ (None by default)
+ """
+ jconf = self._dictToJavaMap(conf)
+ jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
+ valueClass, keyConverter, valueConverter, jconf)
+ return RDD(jrdd, self, PickleSerializer())
+
+ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
+ valueConverter=None, conf=None):
+ """
+ Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
+ Hadoop configuration, which is passed in as a Python dict.
+ This will be converted into a Configuration in Java.
+ The mechanism is the same as for sc.sequenceFile.
+
+ @param inputFormatClass: fully qualified classname of Hadoop InputFormat
+ (e.g. "org.apache.hadoop.mapred.TextInputFormat")
+ @param keyClass: fully qualified classname of key Writable class
+ (e.g. "org.apache.hadoop.io.Text")
+ @param valueClass: fully qualified classname of value Writable class
+ (e.g. "org.apache.hadoop.io.LongWritable")
+ @param keyConverter: (None by default)
+ @param valueConverter: (None by default)
+ @param conf: Hadoop configuration, passed in as a dict
+ (None by default)
+ """
+ jconf = self._dictToJavaMap(conf)
+ jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass,
+ keyConverter, valueConverter, jconf)
+ return RDD(jrdd, self, PickleSerializer())
+
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
return RDD(jrdd, self, input_deserializer)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index ca0a95578fd28..9c69c79236edc 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -250,7 +250,7 @@ def getCheckpointFile(self):
def map(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each element of this RDD.
-
+
>>> rdd = sc.parallelize(["b", "a", "c"])
>>> sorted(rdd.map(lambda x: (x, 1)).collect())
[('a', 1), ('b', 1), ('c', 1)]
@@ -312,6 +312,15 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False):
"use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2)
return self.mapPartitionsWithIndex(f, preservesPartitioning)
+ def getNumPartitions(self):
+ """
+ Returns the number of partitions in this RDD.
+
+ >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
+ >>> rdd.getNumPartitions()
+ 2
+ """
+ return self._jrdd.splits().size()
+
def filter(self, f):
"""
Return a new RDD containing only the elements that satisfy a predicate.
@@ -413,9 +422,9 @@ def union(self, other):
def intersection(self, other):
"""
- Return the intersection of this RDD and another one. The output will not
+ Return the intersection of this RDD and another one. The output will not
contain any duplicate elements, even if the input RDDs did.
-
+
Note that this method performs a shuffle internally.
>>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5])
@@ -571,14 +580,14 @@ def foreachPartition(self, f):
"""
Applies a function to each partition of this RDD.
- >>> def f(iterator):
- ... for x in iterator:
- ... print x
+ >>> def f(iterator):
+ ... for x in iterator:
+ ... print x
... yield None
>>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
"""
self.mapPartitions(f).collect() # Force evaluation
-
+
def collect(self):
"""
Return a list that contains all of the elements in this RDD.
@@ -673,7 +682,7 @@ def func(iterator):
yield acc
return self.mapPartitions(func).fold(zeroValue, combOp)
-
+
def max(self):
"""
@@ -692,7 +701,7 @@ def min(self):
1.0
"""
return self.reduce(min)
-
+
def sum(self):
"""
Add up the elements in this RDD.
@@ -786,7 +795,7 @@ def mergeMaps(m1, m2):
m1[k] += v
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
-
+
def top(self, num):
"""
Get the top N elements from a RDD.
@@ -814,7 +823,7 @@ def merge(a, b):
def takeOrdered(self, num, key=None):
"""
Get the N elements from a RDD ordered in ascending order or as specified
- by the optional key function.
+ by the optional key function.
>>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6)
[1, 2, 3, 4, 5, 6]
@@ -834,7 +843,7 @@ def unKey(x, key_=None):
if key_ != None:
x = [i[1] for i in x]
return x
-
+
def merge(a, b):
return next(topNKeyedElems(a + b))
result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge)
@@ -1169,12 +1178,12 @@ def _mergeCombiners(iterator):
combiners[k] = mergeCombiners(combiners[k], v)
return combiners.iteritems()
return shuffled.mapPartitions(_mergeCombiners)
-
+
def foldByKey(self, zeroValue, func, numPartitions=None):
"""
Merge the values for each key using an associative function "func" and a neutral "zeroValue"
- which may be added to the result an arbitrary number of times, and must not change
- the result (e.g., 0 for addition, or 1 for multiplication.).
+ which may be added to the result an arbitrary number of times, and must not change
+        the result (e.g., 0 for addition, or 1 for multiplication).
>>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
>>> from operator import add
@@ -1182,8 +1191,8 @@ def foldByKey(self, zeroValue, func, numPartitions=None):
[('a', 2), ('b', 1)]
"""
return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions)
-
-
+
+
# TODO: support variant with custom partitioner
def groupByKey(self, numPartitions=None):
"""
@@ -1302,7 +1311,7 @@ def keyBy(self, f):
def repartition(self, numPartitions):
"""
Return a new RDD that has exactly numPartitions partitions.
-
+
Can increase or decrease the level of parallelism in this RDD. Internally, this uses
a shuffle to redistribute data.
If you are decreasing the number of partitions in this RDD, consider using `coalesce`,
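A short interactive sketch of how the new getNumPartitions pairs with repartition and coalesce (a live SparkContext `sc` is assumed; the numbers are only what local mode would report):

    rdd = sc.parallelize(range(100), 4)
    rdd.getNumPartitions()                   # 4
    rdd.repartition(8).getNumPartitions()    # 8, redistributes via a shuffle
    rdd.coalesce(2).getNumPartitions()       # 2, shrinks without a full shuffle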
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 1f2a6ea941cf2..184ee810b861b 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -198,6 +198,151 @@ def func(x):
self.sc.parallelize([1]).foreach(func)
+class TestInputFormat(PySparkTestCase):
+
+ def setUp(self):
+ PySparkTestCase.setUp(self)
+ self.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(self.tempdir.name)
+ self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc)
+
+ def tearDown(self):
+ PySparkTestCase.tearDown(self)
+ shutil.rmtree(self.tempdir.name)
+
+ def test_sequencefiles(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/",
+ "org.apache.hadoop.io.DoubleWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')]
+ self.assertEqual(doubles, ed)
+
+ text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/",
+ "org.apache.hadoop.io.Text",
+ "org.apache.hadoop.io.Text").collect())
+ et = [(u'1', u'aa'),
+ (u'1', u'aa'),
+ (u'2', u'aa'),
+ (u'2', u'bb'),
+ (u'2', u'bb'),
+ (u'3', u'cc')]
+ self.assertEqual(text, et)
+
+ bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.BooleanWritable").collect())
+ eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)]
+ self.assertEqual(bools, eb)
+
+ nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.BooleanWritable").collect())
+ en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)]
+ self.assertEqual(nulls, en)
+
+ maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable").collect())
+ em = [(1, {2.0: u'aa'}),
+ (1, {3.0: u'bb'}),
+ (2, {1.0: u'aa'}),
+ (2, {1.0: u'cc'}),
+ (2, {3.0: u'bb'}),
+ (3, {2.0: u'dd'})]
+ self.assertEqual(maps, em)
+
+ clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
+ "org.apache.hadoop.io.Text",
+ "org.apache.spark.api.python.TestWritable").collect())
+ ec = (u'1',
+ {u'__class__': u'org.apache.spark.api.python.TestWritable',
+ u'double': 54.0, u'int': 123, u'str': u'test1'})
+ self.assertEqual(clazz[0], ec)
+
+ def test_oldhadoop(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ hello = self.sc.hadoopFile(hellopath,
+ "org.apache.hadoop.mapred.TextInputFormat",
+ "org.apache.hadoop.io.LongWritable",
+ "org.apache.hadoop.io.Text").collect()
+ result = [(0, u'Hello World!')]
+ self.assertEqual(hello, result)
+
+ def test_newhadoop(self):
+ basepath = self.tempdir.name
+ ints = sorted(self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text").collect())
+ ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
+ self.assertEqual(ints, ei)
+
+ hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ hello = self.sc.newAPIHadoopFile(hellopath,
+ "org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
+ "org.apache.hadoop.io.LongWritable",
+ "org.apache.hadoop.io.Text").collect()
+ result = [(0, u'Hello World!')]
+ self.assertEqual(hello, result)
+
+ def test_newolderror(self):
+ basepath = self.tempdir.name
+ self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.SequenceFileInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ def test_bad_inputs(self):
+ basepath = self.tempdir.name
+ self.assertRaises(Exception, lambda: self.sc.sequenceFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.io.NotValidWritable",
+ "org.apache.hadoop.io.Text"))
+ self.assertRaises(Exception, lambda: self.sc.hadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapred.NotValidInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+ self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile(
+ basepath + "/sftestdata/sfint/",
+ "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.Text"))
+
+ def test_converter(self):
+ basepath = self.tempdir.name
+ maps = sorted(self.sc.sequenceFile(
+ basepath + "/sftestdata/sfmap/",
+ "org.apache.hadoop.io.IntWritable",
+ "org.apache.hadoop.io.MapWritable",
+ valueConverter="org.apache.spark.api.python.TestConverter").collect())
+ em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])]
+ self.assertEqual(maps, em)
+
+
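The converter test above expects org.apache.spark.api.python.TestConverter to turn each MapWritable value into the list of its keys, e.g. {2.0: u'aa'} becomes [2.0]. In Python terms the asserted behaviour is roughly the following sketch (the converter itself lives on the JVM side):

    def expected_conversion(map_value):
        # mirrors what test_converter asserts about the JVM-side TestConverter
        return list(map_value.keys())

    expected_conversion({2.0: u'aa'})   # [2.0]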
class TestDaemon(unittest.TestCase):
def connect(self, port):
from socket import socket, AF_INET, SOCK_STREAM
diff --git a/python/run-tests b/python/run-tests
index 36a96121cbc0d..3b4501178c89f 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -32,7 +32,8 @@ rm -f unit-tests.log
rm -rf metastore warehouse
function run_test() {
- SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log
+ echo "Running test: $1"
+ SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log
FAILED=$((PIPESTATUS[0]||$FAILED))
# Fail and exit on the first test failure.
@@ -46,15 +47,17 @@ function run_test() {
}
+echo "Running PySpark tests. Output is in python/unit-tests.log."
+
run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "pyspark/conf.py"
if [ -n "$_RUN_SQL_TESTS" ]; then
run_test "pyspark/sql.py"
fi
-run_test "-m doctest pyspark/broadcast.py"
-run_test "-m doctest pyspark/accumulators.py"
-run_test "-m doctest pyspark/serializers.py"
+run_test "pyspark/broadcast.py"
+run_test "pyspark/accumulators.py"
+run_test "pyspark/serializers.py"
run_test "pyspark/tests.py"
run_test "pyspark/mllib/_common.py"
run_test "pyspark/mllib/classification.py"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index a404e7441a1bd..36758f3114e59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -41,10 +41,25 @@ import org.apache.spark.sql.catalyst.types._
* for a SQL like language should checkout the HiveQL support in the sql/hive sub-project.
*/
class SqlParser extends StandardTokenParsers with PackratParsers {
+
def apply(input: String): LogicalPlan = {
- phrase(query)(new lexical.Scanner(input)) match {
- case Success(r, x) => r
- case x => sys.error(x.toString)
+    // Special-case SET commands, since the value fields can be
+ // complex to handle without RegexParsers. Also this approach
+ // is clearer for the several possible cases of set commands.
+ if (input.trim.toLowerCase.startsWith("set")) {
+ input.trim.drop(3).split("=", 2).map(_.trim) match {
+ case Array("") => // "set"
+ SetCommand(None, None)
+ case Array(key) => // "set key"
+ SetCommand(Some(key), None)
+ case Array(key, value) => // "set key=value"
+ SetCommand(Some(key), Some(value))
+ }
+ } else {
+ phrase(query)(new lexical.Scanner(input)) match {
+ case Success(r, x) => r
+ case x => sys.error(x.toString)
+ }
}
}
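The SET special case above reduces to a trim, a 3-character drop, and a 2-way split on '='. A rough Python sketch of the same decision, for illustration only:

    def parse_set(command):
        parts = [p.strip() for p in command.strip()[3:].split("=", 1)]
        if parts == [""]:
            return (None, None)           # "set"            -> show all pairs
        elif len(parts) == 1:
            return (parts[0], None)       # "set key"        -> query one key
        else:
            return (parts[0], parts[1])   # "set key=value"  -> assign

    parse_set("set")                           # (None, None)
    parse_set("set mapred.reduce.tasks")       # ('mapred.reduce.tasks', None)
    parse_set("set mapred.reduce.tasks = 40")  # ('mapred.reduce.tasks', '40')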
@@ -131,6 +146,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val OUTER = Keyword("OUTER")
protected val RIGHT = Keyword("RIGHT")
protected val SELECT = Keyword("SELECT")
+ protected val SEMI = Keyword("SEMI")
protected val STRING = Keyword("STRING")
protected val SUM = Keyword("SUM")
protected val TRUE = Keyword("TRUE")
@@ -168,11 +184,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
}
}
- protected lazy val query: Parser[LogicalPlan] =
+ protected lazy val query: Parser[LogicalPlan] = (
select * (
- UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
- UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
- ) | insert
+ UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } |
+ UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) }
+ )
+ | insert
+ )
protected lazy val select: Parser[LogicalPlan] =
SELECT ~> opt(DISTINCT) ~ projections ~
@@ -241,6 +259,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected lazy val joinType: Parser[JoinType] =
INNER ^^^ Inner |
+ LEFT ~ SEMI ^^^ LeftSemi |
LEFT ~ opt(OUTER) ^^^ LeftOuter |
RIGHT ~ opt(OUTER) ^^^ RightOuter |
FULL ~ opt(OUTER) ^^^ FullOuter
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 406ffd6801e98..ccb8245cc2e7d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -19,6 +19,10 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
+import org.apache.spark.sql.catalyst.plans.LeftSemi
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.types._
@@ -34,7 +38,7 @@ object Optimizer extends RuleExecutor[LogicalPlan] {
Batch("Filter Pushdown", FixedPoint(100),
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughInnerJoin,
+ PushPredicateThroughJoin,
ColumnPruning) :: Nil
}
@@ -254,28 +258,98 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
/**
* Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be
- * evaluated using only the attributes of the left or right side of an inner join. Other
+ * evaluated using only the attributes of the left or right side of a join. Other
* [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the
* [[catalyst.plans.logical.Join Join]].
+ * It also pushes down the join filter, where the `condition` can be evaluated using only the
+ * attributes of the left or right side of a subquery, when applicable.
+ *
+ * See https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details.
*/
-object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper {
+object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper {
+  // Splits the condition expressions into three parts:
+  // (conditions evaluable on the left side only, on the right side only, and those needing both)
+ private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = {
+ val (leftEvaluateCondition, rest) =
+ condition.partition(_.references subsetOf left.outputSet)
+ val (rightEvaluateCondition, commonCondition) =
+ rest.partition(_.references subsetOf right.outputSet)
+
+ (leftEvaluateCondition, rightEvaluateCondition, commonCondition)
+ }
+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) =>
- val allConditions =
- splitConjunctivePredicates(filterCondition) ++
- joinCondition.map(splitConjunctivePredicates).getOrElse(Nil)
-
- // Split the predicates into those that can be evaluated on the left, right, and those that
- // must be evaluated after the join.
- val (rightConditions, leftOrJoinConditions) =
- allConditions.partition(_.references subsetOf right.outputSet)
- val (leftConditions, joinConditions) =
- leftOrJoinConditions.partition(_.references subsetOf left.outputSet)
-
- // Build the new left and right side, optionally with the pushed down filters.
- val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
- val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
- Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And))
+ // push the where condition down into join filter
+ case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) =>
+ val (leftFilterConditions, rightFilterConditions, commonFilterCondition) =
+ split(splitConjunctivePredicates(filterCondition), left, right)
+
+ joinType match {
+ case Inner =>
+ // push down the single side `where` condition into respective sides
+ val newLeft = leftFilterConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightFilterConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, Inner, newJoinCond)
+ case RightOuter =>
+ // push down the right side only `where` condition
+ val newLeft = left
+ val newRight = rightFilterConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = joinCondition
+ val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond)
+
+ (leftFilterConditions ++ commonFilterCondition).
+ reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+ case _ @ (LeftOuter | LeftSemi) =>
+ // push down the left side only `where` condition
+ val newLeft = leftFilterConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = right
+ val newJoinCond = joinCondition
+ val newJoin = Join(newLeft, newRight, joinType, newJoinCond)
+
+ (rightFilterConditions ++ commonFilterCondition).
+ reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin)
+        case FullOuter => f // Do nothing for full outer join
+ }
+
+ // push down the join filter into sub query scanning if applicable
+ case f @ Join(left, right, joinType, joinCondition) =>
+ val (leftJoinConditions, rightJoinConditions, commonJoinCondition) =
+ split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right)
+
+ joinType match {
+ case Inner =>
+ // push down the single side only join filter for both sides sub queries
+ val newLeft = leftJoinConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = rightJoinConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = commonJoinCondition.reduceLeftOption(And)
+
+ Join(newLeft, newRight, Inner, newJoinCond)
+ case RightOuter =>
+ // push down the left side only join filter for left side sub query
+ val newLeft = leftJoinConditions.
+ reduceLeftOption(And).map(Filter(_, left)).getOrElse(left)
+ val newRight = right
+ val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, RightOuter, newJoinCond)
+ case _ @ (LeftOuter | LeftSemi) =>
+ // push down the right side only join filter for right sub query
+ val newLeft = left
+ val newRight = rightJoinConditions.
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right)
+ val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And)
+
+ Join(newLeft, newRight, joinType, newJoinCond)
+ case FullOuter => f
+ }
}
}
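The heart of PushPredicateThroughJoin is the split helper: partition the conjuncts by whether the attributes they reference fall entirely on the left side, entirely on the right side, or span both. A small Python sketch of that partitioning, with attribute-name sets standing in for outputSet (illustrative only):

    def split_conditions(conditions, left_attrs, right_attrs):
        # conditions: list of (predicate, referenced_attribute_set) pairs
        left_only, right_only, common = [], [], []
        for pred, refs in conditions:
            if refs <= left_attrs:
                left_only.append(pred)
            elif refs <= right_attrs:
                right_only.append(pred)
            else:
                common.append(pred)
        return left_only, right_only, common

    conds = [("x.b = 1", {"x.b"}), ("y.b = 2", {"y.b"}), ("x.c = y.c", {"x.c", "y.c"})]
    split_conditions(conds, {"x.a", "x.b", "x.c"}, {"y.a", "y.b", "y.c"})
    # (['x.b = 1'], ['y.b = 2'], ['x.c = y.c'])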
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 4544b32958c7e..820ecfb78b52e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper {
case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) =>
logger.debug(s"Considering hash inner join on: ${predicates ++ condition}")
splitPredicates(predicates ++ condition, join)
+      // All predicates can be evaluated for left semi join (those that are in the WHERE
+      // clause can only come from the left table, so they can all be pushed down).
+ case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) =>
+ logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}")
+ splitPredicates(predicates ++ condition, join)
case join @ Join(left, right, joinType, condition) =>
logger.debug(s"Considering hash join on: $condition")
splitPredicates(condition.toSeq, join)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
index ae8d7d3e4257f..613f4bb09daf5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala
@@ -22,3 +22,4 @@ case object Inner extends JoinType
case object LeftOuter extends JoinType
case object RightOuter extends JoinType
case object FullOuter extends JoinType
+case object LeftSemi extends JoinType
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 2b8fbdcde9d37..7eeb98aea6368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.QueryPlan
-import org.apache.spark.sql.catalyst.types.StructType
+import org.apache.spark.sql.catalyst.types.{StringType, StructType}
import org.apache.spark.sql.catalyst.trees
abstract class LogicalPlan extends QueryPlan[LogicalPlan] {
@@ -102,7 +102,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] {
*/
abstract class Command extends LeafNode {
self: Product =>
- def output = Seq.empty
+ def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this
}
/**
@@ -111,11 +111,23 @@ abstract class Command extends LeafNode {
*/
case class NativeCommand(cmd: String) extends Command
+/**
+ * Commands of the form "SET (key) (= value)".
+ */
+case class SetCommand(key: Option[String], value: Option[String]) extends Command {
+ override def output = Seq(
+ AttributeReference("key", StringType, nullable = false)(),
+ AttributeReference("value", StringType, nullable = false)()
+ )
+}
+
/**
* Returned by a parser when the user only wants to see what query plan would be executed, without
* actually performing the execution.
*/
-case class ExplainCommand(plan: LogicalPlan) extends Command
+case class ExplainCommand(plan: LogicalPlan) extends Command {
+ override def output = Seq(AttributeReference("plan", StringType, nullable = false)())
+}
/**
* A logical plan node with single child.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 732708e146b04..d3347b622f3d8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.JoinType
+import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType}
import org.apache.spark.sql.catalyst.types._
case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
@@ -81,7 +81,12 @@ case class Join(
condition: Option[Expression]) extends BinaryNode {
def references = condition.map(_.references).getOrElse(Set.empty)
- def output = left.output ++ right.output
+ def output = joinType match {
+ case LeftSemi =>
+ left.output
+ case _ =>
+ left.output ++ right.output
+ }
}
case class InsertIntoTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index ef47850455a37..02cc665f8a8c7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -20,11 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators
import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.Inner
+import org.apache.spark.sql.catalyst.plans.FullOuter
+import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.catalyst.plans.RightOuter
import org.apache.spark.sql.catalyst.rules._
-
-/* Implicit conversions */
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.junit.Test
class FilterPushdownSuite extends OptimizerTest {
@@ -35,7 +38,7 @@ class FilterPushdownSuite extends OptimizerTest {
Batch("Filter Pushdown", Once,
CombineFilters,
PushPredicateThroughProject,
- PushPredicateThroughInnerJoin) :: Nil
+ PushPredicateThroughJoin) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
@@ -161,6 +164,184 @@ class FilterPushdownSuite extends OptimizerTest {
comparePlans(optimized, correctAnswer)
}
+
+ test("joins: push down left outer join #1") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter)
+ .where("x.b".attr === 1 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 1)
+ val correctAnswer =
+ left.join(y, LeftOuter).where("y.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #1") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter)
+ .where("x.b".attr === 1 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ x.join(right, RightOuter).where("x.b".attr === 1).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #2") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("x.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #2") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('d)
+ val correctAnswer =
+ x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #3") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+ left.join(right, LeftOuter).where("r.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #3") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #4") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+ left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #4") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.subquery('l)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ left.join(right, RightOuter, Some("r.b".attr === 1)).
+ where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down left outer join #5") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('b === 2).subquery('l)
+ val right = testRelation.where('b === 1).subquery('r)
+ val correctAnswer =
+      left.join(right, LeftOuter, Some("l.a".attr === 3)).
+ where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("joins: push down right outer join #5") {
+ val x = testRelation.subquery('x)
+ val y = testRelation.subquery('y)
+
+ val originalQuery = {
+ x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3))
+ .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr)
+ }
+
+ val optimized = Optimize(originalQuery.analyze)
+ val left = testRelation.where('a === 3).subquery('l)
+ val right = testRelation.where('b === 2).subquery('r)
+ val correctAnswer =
+ left.join(right, RightOuter, Some("r.b".attr === 1)).
+ where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
test("joins: can't push down") {
val x = testRelation.subquery('x)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
new file mode 100644
index 0000000000000..b378252ba2f55
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.util.Properties
+
+import scala.collection.JavaConverters._
+
+/**
+ * SQLConf holds mutable config parameters and hints. These can be set and
+ * queried either by passing SET commands into Spark SQL's DSL
+ * functions (sql(), hql(), etc.), or by programmatically using setters and
+ * getters of this class. This class is thread-safe.
+ */
+trait SQLConf {
+
+ /** Number of partitions to use for shuffle operators. */
+ private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
+
+ @transient
+ private val settings = java.util.Collections.synchronizedMap(
+ new java.util.HashMap[String, String]())
+
+ def set(props: Properties): Unit = {
+ props.asScala.foreach { case (k, v) => this.settings.put(k, v) }
+ }
+
+ def set(key: String, value: String): Unit = {
+ require(key != null, "key cannot be null")
+ require(value != null, s"value cannot be null for ${key}")
+ settings.put(key, value)
+ }
+
+ def get(key: String): String = {
+ if (!settings.containsKey(key)) {
+ throw new NoSuchElementException(key)
+ }
+ settings.get(key)
+ }
+
+ def get(key: String, defaultValue: String): String = {
+ if (!settings.containsKey(key)) defaultValue else settings.get(key)
+ }
+
+ def getAll: Array[(String, String)] = settings.asScala.toArray
+
+ def getOption(key: String): Option[String] = {
+ if (!settings.containsKey(key)) None else Some(settings.get(key))
+ }
+
+ def contains(key: String): Boolean = settings.containsKey(key)
+
+ def toDebugString: String = {
+ settings.synchronized {
+      settings.asScala.toArray.sorted.map { case (k, v) => s"$k=$v" }.mkString("\n")
+ }
+ }
+
+ private[spark] def clear() {
+ settings.clear()
+ }
+
+}
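SQLConf is essentially a synchronized string-to-string map with defaulting getters plus a handful of typed accessors such as numShufflePartitions. A minimal Python sketch of the same contract (thread-safety omitted; illustrative only, not part of the patch):

    class SQLConfSketch(object):
        def __init__(self):
            self._settings = {}

        def set(self, key, value):
            if key is None or value is None:
                raise ValueError("key and value cannot be None")
            self._settings[key] = value

        def get(self, key, default=None):
            if key in self._settings:
                return self._settings[key]
            if default is None:
                raise KeyError(key)   # mirrors the NoSuchElementException above
            return default

        def num_shuffle_partitions(self):
            return int(self.get("spark.sql.shuffle.partitions", "200"))

    conf = SQLConfSketch()
    conf.num_shuffle_partitions()                    # 200, the default
    conf.set("spark.sql.shuffle.partitions", "10")
    conf.num_shuffle_partitions()                    # 10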
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 043be58edc91b..021e0e8245a0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.catalyst.optimizer.Optimizer
-import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
@@ -52,6 +52,7 @@ import org.apache.spark.sql.parquet.ParquetRelation
@AlphaComponent
class SQLContext(@transient val sparkContext: SparkContext)
extends Logging
+ with SQLConf
with dsl.ExpressionConversions
with Serializable {
@@ -190,9 +191,13 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected[sql] class SparkPlanner extends SparkStrategies {
val sparkContext = self.sparkContext
+ def numPartitions = self.numShufflePartitions
+
val strategies: Seq[Strategy] =
+ CommandStrategy(self) ::
TakeOrdered ::
PartialAggregation ::
+ LeftSemiJoin ::
HashJoin ::
ParquetOperations ::
BasicOperators ::
@@ -244,6 +249,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val planner = new SparkPlanner
+ @transient
+ protected[sql] lazy val emptyResult =
+ sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
+
/**
* Prepares a planned SparkPlan for execution by binding references to specific ordinals, and
* inserting shuffle operations as needed.
@@ -251,7 +260,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
@transient
protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] {
val batches =
- Batch("Add exchange", Once, AddExchange) ::
+ Batch("Add exchange", Once, AddExchange(self)) ::
Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil
}
@@ -262,6 +271,22 @@ class SQLContext(@transient val sparkContext: SparkContext)
protected abstract class QueryExecution {
def logical: LogicalPlan
+ def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match {
+ case SetCommand(key, value) =>
+ // Only this case needs to be executed eagerly. The other cases will
+ // be taken care of when the actual results are being extracted.
+ // In the case of HiveContext, sqlConf is overridden to also pass the
+ // pair into its HiveConf.
+ if (key.isDefined && value.isDefined) {
+ set(key.get, value.get)
+ }
+ // It doesn't matter what we return here, since this is only used
+ // to force the evaluation to happen eagerly. To query the results,
+ // one must use SchemaRDD operations to extract them.
+ emptyResult
+ case _ => executedPlan.execute()
+ }
+
lazy val analyzed = analyzer(logical)
lazy val optimizedPlan = optimizer(analyzed)
// TODO: Don't just pick the first one...
@@ -269,7 +294,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan)
/** Internal version of the RDD. Avoids copies and has no schema */
- lazy val toRdd: RDD[Row] = executedPlan.execute()
+ lazy val toRdd: RDD[Row] = {
+ logical match {
+ case s: SetCommand => eagerlyProcess(s)
+ case _ => executedPlan.execute()
+ }
+ }
protected def stringOrError[A](f: => A): String =
try f.toString catch { case e: Throwable => e.toString }
@@ -284,11 +314,6 @@ class SQLContext(@transient val sparkContext: SparkContext)
|== Physical Plan ==
|${stringOrError(executedPlan)}
""".stripMargin.trim
-
- /**
- * Runs the query after interposing operators that print the result of each intermediate step.
- */
- def debugExec() = DebugQuery(executedPlan).execute().collect()
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 604914e547790..34d88fe4bd7de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -77,8 +77,7 @@ case class Aggregate(
resultAttribute: AttributeReference)
/** A list of aggregates that need to be computed for each group. */
- @transient
- private[this] lazy val computedAggregates = aggregateExpressions.flatMap { agg =>
+ private[this] val computedAggregates = aggregateExpressions.flatMap { agg =>
agg.collect {
case a: AggregateExpression =>
ComputedAggregate(
@@ -89,8 +88,7 @@ case class Aggregate(
}.toArray
/** The schema of the result of all aggregate evaluations */
- @transient
- private[this] lazy val computedSchema = computedAggregates.map(_.resultAttribute)
+ private[this] val computedSchema = computedAggregates.map(_.resultAttribute)
/** Creates a new aggregate buffer for a group. */
private[this] def newAggregateBuffer(): Array[AggregateFunction] = {
@@ -104,8 +102,7 @@ case class Aggregate(
}
/** Named attributes used to substitute grouping attributes into the final result. */
- @transient
- private[this] lazy val namedGroups = groupingExpressions.map {
+ private[this] val namedGroups = groupingExpressions.map {
case ne: NamedExpression => ne -> ne.toAttribute
case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute
}
@@ -114,16 +111,14 @@ case class Aggregate(
* A map of substitutions that are used to insert the aggregate expressions and grouping
* expression into the final result expression.
*/
- @transient
- private[this] lazy val resultMap =
+ private[this] val resultMap =
(computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap
/**
* Substituted version of aggregateExpressions expressions which are used to compute final
* output rows given a group and the result of all aggregate computations.
*/
- @transient
- private[this] lazy val resultExpressions = aggregateExpressions.map { agg =>
+ private[this] val resultExpressions = aggregateExpressions.map { agg =>
agg.transform {
case e: Expression if resultMap.contains(e) => resultMap(e)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 3b4acb72e87b5..cef294167f146 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf}
import org.apache.spark.rdd.ShuffledRDD
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.{SQLConf, SQLContext, Row}
import org.apache.spark.sql.catalyst.errors.attachTree
import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering}
import org.apache.spark.sql.catalyst.plans.physical._
@@ -86,9 +86,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
* [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting
* [[Exchange]] Operators where required.
*/
-private[sql] object AddExchange extends Rule[SparkPlan] {
+private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
// TODO: Determine the number of partitions.
- val numPartitions = 150
+ def numPartitions = sqlContext.numShufflePartitions
def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
case operator: SparkPlan =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index cfa8bdae58b11..0455748d40eec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{SQLContext, execution}
+import org.apache.spark.sql.{SQLConf, SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
@@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
+ object LeftSemiJoin extends Strategy with PredicateHelper {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ // Find left semi joins where at least some predicates can be evaluated by matching hash
+ // keys using the HashFilteredJoin pattern.
+ case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) =>
+ val semiJoin = execution.LeftSemiJoinHash(
+ leftKeys, rightKeys, planLater(left), planLater(right))
+ condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil
+ // no predicate can be evaluated by matching hash keys
+ case logical.Join(left, right, LeftSemi, condition) =>
+ execution.LeftSemiJoinBNL(
+ planLater(left), planLater(right), condition)(sparkContext) :: Nil
+ case _ => Nil
+ }
+ }
+
object HashJoin extends Strategy with PredicateHelper {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
// Find inner joins where at least some predicates can be evaluated by matching hash keys
@@ -177,8 +193,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
// Can we automate these 'pass through' operations?
object BasicOperators extends Strategy {
- // TODO: Set
- val numPartitions = 200
+ def numPartitions = self.numPartitions
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Distinct(child) =>
execution.Aggregate(
@@ -217,4 +233,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case _ => Nil
}
}
+
+ case class CommandStrategy(context: SQLContext) extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.SetCommand(key, value) =>
+ Seq(execution.SetCommandPhysical(key, value, plan.output)(context))
+ case logical.ExplainCommand(child) =>
+ val qe = context.executePlan(child)
+ Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context))
+ case _ => Nil
+ }
+ }
+
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
new file mode 100644
index 0000000000000..9364506691f38
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{SQLContext, Row}
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute])
+ (@transient context: SQLContext) extends LeafNode {
+ def execute(): RDD[Row] = (key, value) match {
+ // Set value for key k; the action itself would
+ // have been performed in QueryExecution eagerly.
+ case (Some(k), Some(v)) => context.emptyResult
+ // Query the value bound to key k.
+ case (Some(k), None) =>
+ val resultString = context.getOption(k) match {
+ case Some(v) => s"$k=$v"
+ case None => s"$k is undefined"
+ }
+ context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1)
+ // Query all key-value pairs that are set in the SQLConf of the context.
+ case (None, None) =>
+ val pairs = context.getAll
+ val rows = pairs.map { case (k, v) =>
+ new GenericRow(Array[Any](s"$k=$v"))
+ }.toSeq
+      // Assume the config parameters fit into a single partition (machine).
+ context.sparkContext.parallelize(rows, 1)
+ // The only other case is invalid semantics and is impossible.
+ case _ => context.emptyResult
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ */
+@DeveloperApi
+case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute])
+ (@transient context: SQLContext) extends UnaryNode {
+ def execute(): RDD[Row] = {
+ val planString = new GenericRow(Array[Any](child.toString))
+ context.sparkContext.parallelize(Seq(planString))
+ }
+
+ override def otherCopyArgs = context :: Nil
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
new file mode 100644
index 0000000000000..c6fbd6d2f6930
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -0,0 +1,119 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import scala.collection.mutable.HashSet
+
+import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.SparkContext._
+import org.apache.spark.sql.{SchemaRDD, Row}
+
+/**
+ * :: DeveloperApi ::
+ * Contains methods for debugging query execution.
+ *
+ * Usage:
+ * {{{
+ * sql("SELECT key FROM src").debug
+ * }}}
+ */
+package object debug {
+
+ /**
+ * :: DeveloperApi ::
+ * Augments SchemaRDDs with debug methods.
+ */
+ @DeveloperApi
+ implicit class DebugQuery(query: SchemaRDD) {
+ def debug(implicit sc: SparkContext): Unit = {
+ val plan = query.queryExecution.executedPlan
+ val visited = new collection.mutable.HashSet[Long]()
+ val debugPlan = plan transform {
+ case s: SparkPlan if !visited.contains(s.id) =>
+ visited += s.id
+ DebugNode(sc, s)
+ }
+ println(s"Results returned: ${debugPlan.execute().count()}")
+ debugPlan.foreach {
+ case d: DebugNode => d.dumpStats()
+ case _ =>
+ }
+ }
+ }
+
+ private[sql] case class DebugNode(
+ @transient sparkContext: SparkContext,
+ child: SparkPlan) extends UnaryNode {
+ def references = Set.empty
+
+ def output = child.output
+
+ implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] {
+ def zero(initialValue: HashSet[String]): HashSet[String] = {
+ initialValue.clear()
+ initialValue
+ }
+
+ def addInPlace(v1: HashSet[String], v2: HashSet[String]): HashSet[String] = {
+ v1 ++= v2
+ v1
+ }
+ }
+
+ /**
+ * A collection of stats for each column of output.
+ * @param elementTypes the actual runtime types for the output. Useful when there are bugs
+ * causing the wrong data to be projected.
+ */
+ case class ColumnStat(
+ elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty))
+ val tupleCount = sparkContext.accumulator[Int](0)
+
+ val numColumns = child.output.size
+ val columnStats = Array.fill(child.output.size)(new ColumnStat())
+
+ def dumpStats(): Unit = {
+ println(s"== ${child.simpleString} ==")
+ println(s"Tuples output: ${tupleCount.value}")
+ child.output.zip(columnStats).foreach { case(attr, stat) =>
+        val actualDataTypes = stat.elementTypes.value.mkString("{", ",", "}")
+ println(s" ${attr.name} ${attr.dataType}: $actualDataTypes")
+ }
+ }
+
+ def execute() = {
+ child.execute().mapPartitions { iter =>
+ new Iterator[Row] {
+ def hasNext = iter.hasNext
+ def next() = {
+ val currentRow = iter.next()
+ tupleCount += 1
+ var i = 0
+ while (i < numColumns) {
+ val value = currentRow(i)
+ columnStats(i).elementTypes += HashSet(value.getClass.getName)
+ i += 1
+ }
+ currentRow
+ }
+ }
+ }
+ }
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 31cc26962ad93..88ff3d49a79b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -140,6 +140,137 @@ case class HashJoin(
}
}
+/**
+ * :: DeveloperApi ::
+ * Builds the right table's join keys into a hash set, then iterates over the left
+ * table to check whether each row's join keys are present in that set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ val (buildPlan, streamedPlan) = (right, left)
+ val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+
+ def output = left.output
+
+ @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
+ @transient lazy val streamSideKeyGenerator =
+ () => new MutableProjection(streamedKeys, streamedPlan.output)
+
+ def execute() = {
+
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashTable = new java.util.HashSet[Row]()
+ var currentRow: Row = null
+
+ // Create a Hash set of buildKeys
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+        if (!rowKey.anyNull) {
+ val keyExists = hashTable.contains(rowKey)
+ if (!keyExists) {
+ hashTable.add(rowKey)
+ }
+ }
+ }
+
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatched: Boolean = false
+
+ private[this] val joinKeys = streamSideKeyGenerator()
+
+ override final def hasNext: Boolean =
+ streamIter.hasNext && fetchNext()
+
+ override final def next() = {
+ currentStreamedRow
+ }
+
+ /**
+         * Searches the streamed iterator for the next row that has at least one match in the hash table.
+         *
+         * @return true if the search is successful, and false if the streamed iterator runs out of
+         *         tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatched = false
+ while (!currentHashMatched && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatched = hashTable.contains(joinKeys.currentValue)
+ }
+ }
+ currentHashMatched
+ }
+ }
+ }
+ }
+}
+
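Stripped of the Spark plumbing, LeftSemiJoinHash amounts to: collect the build-side (right) join keys into a hash set, then stream the left side and keep each row whose non-null key appears in the set, emitting it at most once. A plain-Python sketch of the idea, for illustration only:

    def left_semi_join_hash(left_rows, right_rows, left_key, right_key):
        # Build a hash set of the right-side join keys, skipping null keys.
        build_keys = set()
        for row in right_rows:
            key = right_key(row)
            if key is not None:
                build_keys.add(key)
        # Stream the left side; a row is emitted once if its key has any match.
        for row in left_rows:
            key = left_key(row)
            if key is not None and key in build_keys:
                yield row

    left  = [(1, "a"), (2, "b"), (3, "c")]
    right = [(1, "x"), (1, "y"), (4, "z")]
    list(left_semi_join_hash(left, right, lambda r: r[0], lambda r: r[0]))
    # [(1, 'a')] -- duplicates on the build side do not duplicate output rows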
+/**
+ * :: DeveloperApi ::
+ * Uses BroadcastNestedLoopJoin to compute the left semi join result when there are no join keys
+ * suitable for a hash join.
+ */
+@DeveloperApi
+case class LeftSemiJoinBNL(
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
+ (@transient sc: SparkContext)
+ extends BinaryNode {
+ // TODO: Override requiredChildDistribution.
+
+ override def outputPartitioning: Partitioning = streamed.outputPartitioning
+
+ override def otherCopyArgs = sc :: Nil
+
+ def output = left.output
+
+ /** The Streamed Relation */
+ def left = streamed
+ /** The Broadcast relation */
+ def right = broadcast
+
+ @transient lazy val boundCondition =
+ InterpretedPredicate(
+ condition
+ .map(c => BindReferences.bindReference(c, left.output ++ right.output))
+ .getOrElse(Literal(true)))
+
+
+ def execute() = {
+ val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+
+ streamed.execute().mapPartitions { streamedIter =>
+ val joinedRow = new JoinedRow
+
+ streamedIter.filter(streamedRow => {
+ var i = 0
+ var matched = false
+
+ while (i < broadcastedRelation.value.size && !matched) {
+ val broadcastedRow = broadcastedRelation.value(i)
+ if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ matched = true
+ }
+ i += 1
+ }
+ matched
+ })
+ }
+ }
+}
+
/**
* :: DeveloperApi ::
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index d6072b402a044..d7f6abaf5d381 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -44,7 +44,7 @@ class QueryTest extends FunSuite {
fail(
s"""
|Exception thrown while executing query:
- |${rdd.logicalPlan}
+ |${rdd.queryExecution}
|== Exception ==
|$e
""".stripMargin)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
new file mode 100644
index 0000000000000..5eb73a4eff980
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -0,0 +1,71 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test._
+
+/* Implicits */
+import TestSQLContext._
+
+class SQLConfSuite extends QueryTest {
+
+ val testKey = "test.key.0"
+ val testVal = "test.val.0"
+
+ test("programmatic ways of basic setting and getting") {
+ assert(getOption(testKey).isEmpty)
+ assert(getAll.toSet === Set())
+
+ set(testKey, testVal)
+ assert(get(testKey) == testVal)
+ assert(get(testKey, testVal + "_") == testVal)
+ assert(getOption(testKey) == Some(testVal))
+ assert(contains(testKey))
+
+ // Tests that SQLConf, as accessed from a SQLContext, is mutable after
+ // the latter is initialized, unlike SparkConf inside a SparkContext.
+ assert(TestSQLContext.get(testKey) == testVal)
+ assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
+ assert(TestSQLContext.getOption(testKey) == Some(testVal))
+ assert(TestSQLContext.contains(testKey))
+
+ clear()
+ }
+
+ test("parse SQL set commands") {
+ sql(s"set $testKey=$testVal")
+ assert(get(testKey, testVal + "_") == testVal)
+ assert(TestSQLContext.get(testKey, testVal + "_") == testVal)
+
+ sql("set mapred.reduce.tasks=20")
+ assert(get("mapred.reduce.tasks", "0") == "20")
+ sql("set mapred.reduce.tasks = 40")
+ assert(get("mapred.reduce.tasks", "0") == "40")
+
+ val key = "spark.sql.key"
+ val vs = "val0,val_1,val2.3,my_table"
+ sql(s"set $key=$vs")
+ assert(get(key, "0") == vs)
+
+ sql(s"set $key=")
+ assert(get(key, "0") == "")
+
+ clear()
+ }
+
+}
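The suite above exercises a plain mutable key/value contract: set, get with and without a default, getOption, contains, getAll, and clear, with the same settings visible through the enclosing context. A minimal self-contained sketch of such a contract (a hypothetical class, not the SQLConf implementation) could look like this:

    import scala.collection.mutable

    // Hypothetical mutable configuration mirroring the get/set contract exercised above.
    class SimpleConf {
      private val settings = mutable.Map.empty[String, String]

      def set(key: String, value: String): Unit = settings(key) = value
      def get(key: String, default: String): String = settings.getOrElse(key, default)
      def getOption(key: String): Option[String] = settings.get(key)
      def contains(key: String): Boolean = settings.contains(key)
      def getAll: Map[String, String] = settings.toMap
      def clear(): Unit = settings.clear()
    }

    val conf = new SimpleConf
    conf.set("test.key.0", "test.val.0")
    assert(conf.get("test.key.0", "fallback") == "test.val.0")
    assert(conf.getOption("nonexistent").isEmpty)
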
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index aa0c426f6fcb3..de02bbc7e4700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest {
arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq)
}
+ test("left semi greater than predicate") {
+ checkAnswer(
+ sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"),
+ Seq((3, 1), (3, 2))
+ )
+ }
+
test("index into array of arrays") {
checkAnswer(
sql(
@@ -129,6 +136,12 @@ class SQLQuerySuite extends QueryTest {
2.0)
}
+ test("average overflow") {
+ checkAnswer(
+ sql("SELECT AVG(a), b FROM largeAndSmallInts GROUP BY b"),
+ Seq((2147483645.0, 1), (2.0, 2)))
+ }
+
test("count") {
checkAnswer(
sql("SELECT COUNT(*) FROM testData2"),
@@ -354,6 +367,41 @@ class SQLQuerySuite extends QueryTest {
(1, "abc"),
(2, "abc"),
(3, null)))
- }
-
+ }
+
+ test("SET commands semantics using sql()") {
+ clear()
+ val testKey = "test.key.0"
+ val testVal = "test.val.0"
+ val nonexistentKey = "nonexistent"
+
+ // "set" itself returns all config variables currently specified in SQLConf.
+ assert(sql("SET").collect().size == 0)
+
+ // "set key=val"
+ sql(s"SET $testKey=$testVal")
+ checkAnswer(
+ sql("SET"),
+ Seq(Seq(s"$testKey=$testVal"))
+ )
+
+ sql(s"SET ${testKey + testKey}=${testVal + testVal}")
+ checkAnswer(
+ sql("set"),
+ Seq(
+ Seq(s"$testKey=$testVal"),
+ Seq(s"${testKey + testKey}=${testVal + testVal}"))
+ )
+
+ // "set key"
+ checkAnswer(
+ sql(s"SET $testKey"),
+ Seq(Seq(s"$testKey=$testVal"))
+ )
+ checkAnswer(
+ sql(s"SET $nonexistentKey"),
+ Seq(Seq(s"$nonexistentKey is undefined"))
+ )
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 05de736bbce1b..330b20b315d63 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -30,6 +30,17 @@ object TestData {
(1 to 100).map(i => TestData(i, i.toString)))
testData.registerAsTable("testData")
+ case class LargeAndSmallInts(a: Int, b: Int)
+ val largeAndSmallInts: SchemaRDD =
+ TestSQLContext.sparkContext.parallelize(
+ LargeAndSmallInts(2147483644, 1) ::
+ LargeAndSmallInts(1, 2) ::
+ LargeAndSmallInts(2147483645, 1) ::
+ LargeAndSmallInts(2, 2) ::
+ LargeAndSmallInts(2147483646, 1) ::
+ LargeAndSmallInts(3, 2) :: Nil)
+ largeAndSmallInts.registerAsTable("largeAndSmallInts")
+
case class TestData2(a: Int, b: Int)
val testData2: SchemaRDD =
TestSQLContext.sparkContext.parallelize(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index c563d63627544..df6b118360d01 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -30,8 +30,8 @@ class PlannerSuite extends FunSuite {
test("unions are collapsed") {
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
- val logicalUnions = query collect { case u: logical.Union => u}
- val physicalUnions = planned collect { case u: execution.Union => u}
+ val logicalUnions = query collect { case u: logical.Union => u }
+ val physicalUnions = planned collect { case u: execution.Union => u }
assert(logicalUnions.size === 2)
assert(physicalUnions.size === 1)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index b21f24dad785d..64978215542ec 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -18,11 +18,11 @@
package org.apache.spark.sql
package hive
-import scala.language.implicitConversions
-
import java.io.{BufferedReader, File, InputStreamReader, PrintStream}
import java.util.{ArrayList => JArrayList}
+import scala.collection.JavaConversions._
+import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import org.apache.hadoop.hive.conf.HiveConf
@@ -30,20 +30,15 @@ import org.apache.hadoop.hive.ql.Driver
import org.apache.hadoop.hive.ql.processors._
import org.apache.hadoop.hive.ql.session.SessionState
-import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog}
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema}
-import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand, ExplainCommand}
-import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.types._
import org.apache.spark.sql.execution._
-/* Implicit conversions */
-import scala.collection.JavaConversions._
-
/**
* Starts up an instance of Hive where metadata is stored locally. An in-process metastore
* database is created with data stored in ./metadata. Warehouse data is stored in ./warehouse.
@@ -55,10 +50,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) {
/** Sets up the system initially or after a RESET command */
protected def configure() {
- // TODO: refactor this so we can work with other databases.
- runSqlHive(
- s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true")
- runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath)
+ set("javax.jdo.option.ConnectionURL",
+ s"jdbc:derby:;databaseName=$metastorePath;create=true")
+ set("hive.metastore.warehouse.dir", warehousePath)
}
configure() // Must be called before initializing the catalog below.
@@ -129,12 +123,27 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
}
}
+ /**
+ * SQLConf and HiveConf contracts: when the Hive session is first initialized, params in
+ * HiveConf will get picked up by the SQLConf. Additionally, any properties set by
+ * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as*
+ * in the HiveConf.
+ */
@transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState])
- @transient protected[hive] lazy val sessionState = new SessionState(hiveconf)
+ @transient protected[hive] lazy val sessionState = {
+ val ss = new SessionState(hiveconf)
+ set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf.
+ ss
+ }
sessionState.err = new PrintStream(outputBuffer, true, "UTF-8")
sessionState.out = new PrintStream(outputBuffer, true, "UTF-8")
+ override def set(key: String, value: String): Unit = {
+ super.set(key, value)
+ runSqlHive(s"SET $key=$value")
+ }
+
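The override above implements the dual-write contract described in the comment: every set() updates the SQLConf view and is also pushed down to Hive via a SET statement. A standalone sketch of that pattern, with a stubbed engine in place of runSqlHive (all names hypothetical):

    import scala.collection.mutable

    // Hypothetical sketch of the dual-write pattern: set() updates the local map and
    // forwards the same setting to an external engine (stubbed here as a callback).
    class BaseConf {
      protected val settings = mutable.Map.empty[String, String]
      def set(key: String, value: String): Unit = settings(key) = value
      def get(key: String, default: String): String = settings.getOrElse(key, default)
    }

    class EngineAwareConf(runOnEngine: String => Unit) extends BaseConf {
      override def set(key: String, value: String): Unit = {
        super.set(key, value)           // keep the local view in sync
        runOnEngine(s"SET $key=$value") // and propagate to the external engine
      }
    }

    val executed = mutable.Buffer.empty[String]
    val conf = new EngineAwareConf(executed += _)
    conf.set("hive.exec.dynamic.partition", "true")
    assert(conf.get("hive.exec.dynamic.partition", "false") == "true")
    assert(executed == Seq("SET hive.exec.dynamic.partition=true"))
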
/* A catalyst metadata catalog that points to the Hive Metastore. */
@transient
override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog {
@@ -218,12 +227,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
val hiveContext = self
override val strategies: Seq[Strategy] = Seq(
+ CommandStrategy(self),
TakeOrdered,
ParquetOperations,
HiveTableScans,
DataSinks,
Scripts,
PartialAggregation,
+ LeftSemiJoin,
HashJoin,
BasicOperators,
CartesianProduct,
@@ -234,30 +245,31 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
@transient
override protected[sql] val planner = hivePlanner
- @transient
- protected lazy val emptyResult =
- sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1)
-
/** Extends QueryExecution with hive specific features. */
protected[sql] abstract class QueryExecution extends super.QueryExecution {
// TODO: Create mixin for the analyzer instead of overriding things here.
override lazy val optimizedPlan =
optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed)))
- override lazy val toRdd: RDD[Row] =
- analyzed match {
- case NativeCommand(cmd) =>
- val output = runSqlHive(cmd)
+ override lazy val toRdd: RDD[Row] = {
+ def processCmd(cmd: String): RDD[Row] = {
+ val output = runSqlHive(cmd)
+ if (output.size == 0) {
+ emptyResult
+ } else {
+ val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
+ sparkContext.parallelize(asRows, 1)
+ }
+ }
- if (output.size == 0) {
- emptyResult
- } else {
- val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]]))
- sparkContext.parallelize(asRows, 1)
- }
- case _ =>
- executedPlan.execute().map(_.copy())
+ logical match {
+ case s: SetCommand => eagerlyProcess(s)
+ case _ => analyzed match {
+ case NativeCommand(cmd) => processCmd(cmd)
+ case _ => executedPlan.execute().map(_.copy())
+ }
}
+ }
protected val primitiveTypes =
Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType,
@@ -303,7 +315,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
*/
def stringResult(): Seq[String] = analyzed match {
case NativeCommand(cmd) => runSqlHive(cmd)
- case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n")
+ case ExplainCommand(plan) => executePlan(plan).toString.split("\n")
case query =>
val result: Seq[Seq[Any]] = toRdd.collect().toSeq
// We need the types so we can output struct field names
@@ -316,6 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
override def simpleString: String =
logical match {
case _: NativeCommand => ""
+ case _: SetCommand => ""
case _ => executedPlan.toString
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 7e91c16c6b93a..4e74d9bc909fa 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -207,8 +207,17 @@ private[hive] object HiveQl {
/** Returns a LogicalPlan for a given HiveQL string. */
def parseSql(sql: String): LogicalPlan = {
try {
- if (sql.toLowerCase.startsWith("set")) {
- NativeCommand(sql)
+ if (sql.trim.toLowerCase.startsWith("set")) {
+ // Split into two parts, since we treat the part before the first "=" as the key
+ // and the part after as the value, which may contain other "=" signs.
+ sql.trim.drop(3).split("=", 2).map(_.trim) match {
+ case Array("") => // "set"
+ SetCommand(None, None)
+ case Array(key) => // "set key"
+ SetCommand(Some(key), None)
+ case Array(key, value) => // "set key=value"
+ SetCommand(Some(key), Some(value))
+ }
} else if (sql.toLowerCase.startsWith("add jar")) {
AddJar(sql.drop(8))
} else if (sql.toLowerCase.startsWith("add file")) {
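To make the parsing rule above concrete: the "set" keyword is dropped and the remainder is split on the first "=" only, so values that themselves contain "=" survive intact. A minimal standalone sketch mirroring that match (hypothetical helper):

    // Split only on the first "=": everything after it, including further "=", is the value.
    def parseSet(sql: String): (Option[String], Option[String]) =
      sql.trim.drop(3).split("=", 2).map(_.trim) match {
        case Array("")         => (None, None)             // "set"
        case Array(key)        => (Some(key), None)        // "set key"
        case Array(key, value) => (Some(key), Some(value)) // "set key=value"
      }

    assert(parseSet("set") == (None -> None))
    assert(parseSet("set mapred.reduce.tasks") == (Some("mapred.reduce.tasks") -> None))
    assert(parseSet("set a=b=c") == (Some("a") -> Some("b=c")))
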
@@ -685,6 +694,7 @@ private[hive] object HiveQl {
case "TOK_RIGHTOUTERJOIN" => RightOuter
case "TOK_LEFTOUTERJOIN" => LeftOuter
case "TOK_FULLOUTERJOIN" => FullOuter
+ case "TOK_LEFTSEMIJOIN" => LeftSemi
}
assert(other.size <= 1, "Unhandled join clauses.")
Join(nodeToRelation(relation1),
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala
similarity index 100%
rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala
rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-0-80b6466213face7fbcb0de044611e1f5 b/sql/hive/src/test/resources/golden/leftsemijoin-0-80b6466213face7fbcb0de044611e1f5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-1-d1f6a3dea28a5f0fee08026bf33d9129 b/sql/hive/src/test/resources/golden/leftsemijoin-1-d1f6a3dea28a5f0fee08026bf33d9129
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea
new file mode 100644
index 0000000000000..25ce912507d55
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea
@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-11-80b6466213face7fbcb0de044611e1f5 b/sql/hive/src/test/resources/golden/leftsemijoin-11-80b6466213face7fbcb0de044611e1f5
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-12-d1f6a3dea28a5f0fee08026bf33d9129 b/sql/hive/src/test/resources/golden/leftsemijoin-12-d1f6a3dea28a5f0fee08026bf33d9129
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a b/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-3-b07d292423312aafa5e5762a579decd2 b/sql/hive/src/test/resources/golden/leftsemijoin-3-b07d292423312aafa5e5762a579decd2
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-4-3ac2226efe7cb5d999c1c5e4ac2114be b/sql/hive/src/test/resources/golden/leftsemijoin-4-3ac2226efe7cb5d999c1c5e4ac2114be
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-5-9c307c0559d735960ce77efa95b2b17b b/sql/hive/src/test/resources/golden/leftsemijoin-5-9c307c0559d735960ce77efa95b2b17b
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-6-82921fc96eef547ec0f71027ee88298c b/sql/hive/src/test/resources/golden/leftsemijoin-6-82921fc96eef547ec0f71027ee88298c
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-7-b30aa3b4a45db6b64bb46b4d9bd32ff0 b/sql/hive/src/test/resources/golden/leftsemijoin-7-b30aa3b4a45db6b64bb46b4d9bd32ff0
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013
new file mode 100644
index 0000000000000..25ce912507d55
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013
@@ -0,0 +1,4 @@
+Hank 2
+Hank 2
+Joe 2
+Joe 2
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e b/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e
new file mode 100644
index 0000000000000..f1470bad5782b
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e
@@ -0,0 +1,2 @@
+2 Tie
+2 Tie
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-0-7087fb6281a34d00f1812d2ff4ba8b75 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-0-7087fb6281a34d00f1812d2ff4ba8b75
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-1-aa3f07f028027ffd13ab5535dc821593 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-1-aa3f07f028027ffd13ab5535dc821593
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7
new file mode 100644
index 0000000000000..6ed281c757a96
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7
@@ -0,0 +1,2 @@
+1
+1
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-2-3f65953ae60375156367c54533978782 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-2-3f65953ae60375156367c54533978782
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-3-645cf8b871c9b27418d6fa1d1bda9a52 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-3-645cf8b871c9b27418d6fa1d1bda9a52
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f b/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f
new file mode 100644
index 0000000000000..6ed281c757a96
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f
@@ -0,0 +1,2 @@
+1
+1
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9
new file mode 100644
index 0000000000000..179ef0e0209e9
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9
@@ -0,0 +1,20 @@
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
+1
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b b/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b
new file mode 100644
index 0000000000000..573541ac9702d
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b
@@ -0,0 +1 @@
+0
diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7
new file mode 100644
index 0000000000000..6ed281c757a96
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7
@@ -0,0 +1,2 @@
+1
+1
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 0f954103a85f2..357c7e654bd20 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -138,6 +138,9 @@ abstract class HiveComparisonTest
val orderedAnswer = hiveQuery.logical match {
// Clean out non-deterministic time schema info.
+ // Hack: Hive simply prints the result of a SET command to the screen
+ // and does not return it as a query answer.
+ case _: SetCommand => Seq("0")
case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "")
case _: ExplainCommand => answer
case plan => if (isSorted(plan)) answer else answer.sorted
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 9031abf733cd4..fb8f272d5abfe 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"lateral_view",
"lateral_view_cp",
"lateral_view_ppd",
+ "leftsemijoin",
+ "leftsemijoin_mr",
"lineage1",
"literal_double",
"literal_ints",
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 125cc18bfb2b5..6c239b02ed09a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -17,7 +17,9 @@
package org.apache.spark.sql.hive.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.hive.test.TestHive._
+import org.apache.spark.sql.hive.test.TestHive
/**
* A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution.
@@ -159,4 +161,89 @@ class HiveQuerySuite extends HiveComparisonTest {
hql("SHOW TABLES").toString
hql("SELECT * FROM src").toString
}
+
+ test("SPARK-1704: Explain commands as a SchemaRDD") {
+ hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)")
+ val rdd = hql("explain select key, count(value) from src group by key")
+ assert(rdd.collect().size == 1)
+ assert(rdd.toString.contains("ExplainCommand"))
+ assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0,
+ "actual contents of the result should be the plans of the query to be explained")
+ TestHive.reset()
+ }
+
+ test("parse HQL set commands") {
+ // Adapted from its SQL counterpart.
+ val testKey = "spark.sql.key.usedfortestonly"
+ val testVal = "val0,val_1,val2.3,my_table"
+
+ hql(s"set $testKey=$testVal")
+ assert(get(testKey, testVal + "_") == testVal)
+
+ hql("set mapred.reduce.tasks=20")
+ assert(get("mapred.reduce.tasks", "0") == "20")
+ hql("set mapred.reduce.tasks = 40")
+ assert(get("mapred.reduce.tasks", "0") == "40")
+
+ hql(s"set $testKey=$testVal")
+ assert(get(testKey, "0") == testVal)
+
+ hql(s"set $testKey=")
+ assert(get(testKey, "0") == "")
+ }
+
+ test("SET commands semantics for a HiveContext") {
+ // Adapted from its SQL counterpart.
+ val testKey = "spark.sql.key.usedfortestonly"
+ val testVal = "test.val.0"
+ val nonexistentKey = "nonexistent"
+ def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0))
+
+ clear()
+
+ // "set" itself returns all config variables currently specified in SQLConf.
+ assert(hql("set").collect().size == 0)
+
+ // "set key=val"
+ hql(s"SET $testKey=$testVal")
+ assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal"))
+ assert(hiveconf.get(testKey, "") == testVal)
+
+ hql(s"SET ${testKey + testKey}=${testVal + testVal}")
+ assert(fromRows(hql("SET").collect()) sameElements
+ Array(
+ s"$testKey=$testVal",
+ s"${testKey + testKey}=${testVal + testVal}"))
+ assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
+
+ // "set key"
+ assert(fromRows(hql(s"SET $testKey").collect()) sameElements
+ Array(s"$testKey=$testVal"))
+ assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements
+ Array(s"$nonexistentKey is undefined"))
+
+ // Assert that sql() should have the same effects as hql() by repeating the above using sql().
+ clear()
+ assert(sql("set").collect().size == 0)
+
+ sql(s"SET $testKey=$testVal")
+ assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal"))
+ assert(hiveconf.get(testKey, "") == testVal)
+
+ sql(s"SET ${testKey + testKey}=${testVal + testVal}")
+ assert(fromRows(sql("SET").collect()) sameElements
+ Array(
+ s"$testKey=$testVal",
+ s"${testKey + testKey}=${testVal + testVal}"))
+ assert(hiveconf.get(testKey + testKey, "") == testVal + testVal)
+
+ assert(fromRows(sql(s"SET $testKey").collect()) sameElements
+ Array(s"$testKey=$testVal"))
+ assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements
+ Array(s"$nonexistentKey is undefined"))
+ }
+
+ // Put tests that depend on specific Hive settings before these last two tests,
+ // since they modify and clear those settings.
+
}
diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
index aeb3f0062df3b..801e8b381588f 100644
--- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala
@@ -23,6 +23,7 @@ import java.nio.ByteBuffer
import scala.collection.JavaConversions._
import scala.collection.mutable.{HashMap, ListBuffer, Map}
+import scala.util.{Try, Success, Failure}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs._
@@ -220,7 +221,7 @@ trait ClientBase extends Logging {
}
}
- var cachedSecondaryJarLinks = ListBuffer.empty[String]
+ val cachedSecondaryJarLinks = ListBuffer.empty[String]
val fileLists = List( (args.addJars, LocalResourceType.FILE, true),
(args.files, LocalResourceType.FILE, false),
(args.archives, LocalResourceType.ARCHIVE, false) )
@@ -378,7 +379,7 @@ trait ClientBase extends Logging {
}
}
-object ClientBase {
+object ClientBase extends Logging {
val SPARK_JAR: String = "__spark__.jar"
val APP_JAR: String = "__app__.jar"
val LOG4J_PROP: String = "log4j.properties"
@@ -388,37 +389,47 @@ object ClientBase {
def getSparkJar = sys.env.get("SPARK_JAR").getOrElse(SparkContext.jarOfClass(this.getClass).head)
- // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps
- def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) {
- val classpathEntries = Option(conf.getStrings(
- YarnConfiguration.YARN_APPLICATION_CLASSPATH)).getOrElse(
- getDefaultYarnApplicationClasspath())
- if (classpathEntries != null) {
- for (c <- classpathEntries) {
- YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim,
- File.pathSeparator)
- }
+ def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = {
+ val classPathElementsToAdd = getYarnAppClasspath(conf) ++ getMRAppClasspath(conf)
+ for (c <- classPathElementsToAdd.flatten) {
+ YarnSparkHadoopUtil.addToEnvironment(
+ env,
+ Environment.CLASSPATH.name,
+ c.trim,
+ File.pathSeparator)
}
+ classPathElementsToAdd
+ }
- val mrClasspathEntries = Option(conf.getStrings(
- "mapreduce.application.classpath")).getOrElse(
- getDefaultMRApplicationClasspath())
- if (mrClasspathEntries != null) {
- for (c <- mrClasspathEntries) {
- YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim,
- File.pathSeparator)
- }
- }
+ private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultYarnApplicationClasspath
}
- def getDefaultYarnApplicationClasspath(): Array[String] = {
- try {
- val field = classOf[MRJobConfig].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
- field.get(null).asInstanceOf[Array[String]]
- } catch {
- case err: NoSuchFieldError => null
- case err: NoSuchFieldException => null
+ private def getMRAppClasspath(conf: Configuration): Option[Seq[String]] =
+ Option(conf.getStrings("mapreduce.application.classpath")) match {
+ case Some(s) => Some(s.toSeq)
+ case None => getDefaultMRApplicationClasspath
+ }
+
+ def getDefaultYarnApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
+ val field = classOf[YarnConfiguration].getField("DEFAULT_YARN_APPLICATION_CLASSPATH")
+ val value = field.get(null).asInstanceOf[Array[String]]
+ value.toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
}
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default YARN Application classpath.", f.exception)
+ case s: Success[_] =>
+ logDebug(s"Using the default YARN application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
}
/**
@@ -426,20 +437,30 @@ object ClientBase {
* classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String.
* So we need to use reflection to retrieve it.
*/
- def getDefaultMRApplicationClasspath(): Array[String] = {
- try {
+ def getDefaultMRApplicationClasspath: Option[Seq[String]] = {
+ val triedDefault = Try[Seq[String]] {
val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH")
- if (field.getType == classOf[String]) {
- StringUtils.getStrings(field.get(null).asInstanceOf[String])
+ val value = if (field.getType == classOf[String]) {
+ StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray
} else {
field.get(null).asInstanceOf[Array[String]]
}
- } catch {
- case err: NoSuchFieldError => null
- case err: NoSuchFieldException => null
+ value.toSeq
+ } recoverWith {
+ case e: NoSuchFieldException => Success(Seq.empty[String])
}
+
+ triedDefault match {
+ case f: Failure[_] =>
+ logError("Unable to obtain the default MR Application classpath.", f.exception)
+ case s: Success[_] =>
+ logDebug(s"Using the default MR application classpath: ${s.get.mkString(",")}")
+ }
+
+ triedDefault.toOption
}
+
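The two methods above share a pattern: reflectively read a static field that may not exist in the running Hadoop version, recover NoSuchFieldException into an empty default, and log the outcome. A generic standalone sketch of that pattern (hypothetical helper; println stands in for the logger):

    import scala.util.{Failure, Success, Try}

    // Hypothetical helper: read a static field via reflection, falling back when the
    // field does not exist in the class version on the classpath.
    def staticFieldOrDefault[A](clazz: Class[_], field: String, default: => Seq[A])
                               (convert: AnyRef => Seq[A]): Seq[A] = {
      val tried = Try(convert(clazz.getField(field).get(null))) recoverWith {
        case _: NoSuchFieldException => Success(default)
      }
      tried match {
        case Success(value) => value
        case Failure(e)     => println(s"Could not read $field: $e"); default
      }
    }

    // The first field exists on java.lang.String, the second does not.
    val present = staticFieldOrDefault[String](classOf[String], "CASE_INSENSITIVE_ORDER",
      Seq.empty)(v => Seq(v.getClass.getName))
    val missing = staticFieldOrDefault[String](classOf[String], "NO_SUCH_FIELD",
      Seq("fallback"))(v => Seq(v.toString))
    assert(present.nonEmpty && missing == Seq("fallback"))
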
/**
* Returns the java command line argument for setting up log4j. If there is a log4j.properties
* in the given local resources, it is used, otherwise the SPARK_LOG4J_CONF environment variable
@@ -481,12 +502,14 @@ object ClientBase {
def addClasspathEntry(path: String) = YarnSparkHadoopUtil.addToEnvironment(env,
Environment.CLASSPATH.name, path, File.pathSeparator)
/** Add entry to the classpath. Interpreted as a path relative to the working directory. */
- def addPwdClasspathEntry(entry: String) = addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry)
+ def addPwdClasspathEntry(entry: String) =
+ addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry)
extraClassPath.foreach(addClasspathEntry)
val cachedSecondaryJarLinks =
sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS).getOrElse("").split(",")
+ .filter(_.nonEmpty)
// Normally the users app.jar is last in case conflicts with spark jars
if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) {
addPwdClasspathEntry(APP_JAR)
diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index e01ed5a57d697..039cf4f276119 100644
--- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -112,7 +112,7 @@ private[spark] class YarnClientSchedulerBackend(
override def stop() {
super.stop()
- client.stop()
+ client.stop
logInfo("Stopped")
}
diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
new file mode 100644
index 0000000000000..608c6e92624c6
--- /dev/null
+++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.URI
+
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.MRJobConfig
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers._
+
+import scala.collection.JavaConversions._
+import scala.collection.mutable.{ HashMap => MutableHashMap }
+import scala.util.Try
+
+
+class ClientBaseSuite extends FunSuite {
+
+ test("default Yarn application classpath") {
+ ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP))
+ }
+
+ test("default MR application classpath") {
+ ClientBase.getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP))
+ }
+
+ test("resultant classpath for an application that defines a classpath for YARN") {
+ withAppConf(Fixtures.mapYARNAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(Fixtures.knownYARNAppCP, ClientBase.getDefaultMRApplicationClasspath))
+ }
+ }
+
+ test("resultant classpath for an application that defines a classpath for MR") {
+ withAppConf(Fixtures.mapMRAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(
+ flatten(ClientBase.getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP))
+ }
+ }
+
+ test("resultant classpath for an application that defines both classpaths, YARN and MR") {
+ withAppConf(Fixtures.mapAppConf) { conf =>
+ val env = newEnv
+ ClientBase.populateHadoopClasspath(conf, env)
+ classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP))
+ }
+ }
+
+ object Fixtures {
+
+ val knownDefYarnAppCP: Seq[String] =
+ getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration],
+ "DEFAULT_YARN_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.toSeq)
+
+ val knownDefMRAppCP: Seq[String] =
+ getFieldValue[String, Seq[String]](classOf[MRJobConfig],
+ "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH",
+ Seq[String]())(a => a.split(","))
+
+ val knownYARNAppCP = Some(Seq("/known/yarn/path"))
+
+ val knownMRAppCP = Some(Seq("/known/mr/path"))
+
+ val mapMRAppConf =
+ Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get)
+
+ val mapYARNAppConf =
+ Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get)
+
+ val mapAppConf = mapYARNAppConf ++ mapMRAppConf
+ }
+
+ def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) {
+ val conf = new Configuration
+ m.foreach { case (k, v) => conf.set(k, v, "ClientBaseSpec") }
+ testCode(conf)
+ }
+
+ def newEnv = MutableHashMap[String, String]()
+
+ def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;")
+
+ def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray
+
+ def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B =
+ Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A])
+ .toOption.map(mapTo).getOrElse(defaults)
+
+}
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index c1dfe3f53b40b..33a60d978c586 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -252,15 +252,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
try {
logInfo("Allocating " + args.numExecutors + " executors.")
// Wait until all containers have finished
- // TODO: This is a bit ugly. Can we make it nicer?
- // TODO: Handle container failure
yarnAllocator.addResourceRequests(args.numExecutors)
+ yarnAllocator.allocateResources()
// Exits the loop if the user thread exits.
while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) {
- if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
- finishApplicationMaster(FinalApplicationStatus.FAILED,
- "max number of executor failures reached")
- }
+ checkNumExecutorsFailed()
+ allocateMissingExecutor()
yarnAllocator.allocateResources()
ApplicationMaster.incrementAllocatorLoop(1)
Thread.sleep(100)
@@ -289,23 +286,31 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration,
}
}
+ private def allocateMissingExecutor() {
+ val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning -
+ yarnAllocator.getNumPendingAllocate
+ if (missingExecutorCount > 0) {
+ logInfo("Allocating %d containers to make up for (potentially) lost containers".
+ format(missingExecutorCount))
+ yarnAllocator.addResourceRequests(missingExecutorCount)
+ }
+ }
+
+ private def checkNumExecutorsFailed() {
+ if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
+ finishApplicationMaster(FinalApplicationStatus.FAILED,
+ "max number of executor failures reached")
+ }
+ }
+
private def launchReporterThread(_sleepTime: Long): Thread = {
val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime
val t = new Thread {
override def run() {
while (userThread.isAlive) {
- if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) {
- finishApplicationMaster(FinalApplicationStatus.FAILED,
- "max number of executor failures reached")
- }
- val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning -
- yarnAllocator.getNumPendingAllocate
- if (missingExecutorCount > 0) {
- logInfo("Allocating %d containers to make up for (potentially) lost containers".
- format(missingExecutorCount))
- yarnAllocator.addResourceRequests(missingExecutorCount)
- }
+ checkNumExecutorsFailed()
+ allocateMissingExecutor()
sendProgress()
Thread.sleep(sleepTime)
}
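The helper factored out above boils down to simple reconciliation arithmetic: request the difference between the desired executor count and what is already running or pending. A minimal sketch of that calculation (hypothetical names, not the YarnAllocator API):

    // Hypothetical sketch of the reconciliation arithmetic used by allocateMissingExecutor().
    def missingExecutors(requested: Int, running: Int, pending: Int): Int =
      requested - running - pending

    val missing = missingExecutors(requested = 10, running = 6, pending = 3)
    // Here missing == 1, so one more container would be requested.
    if (missing > 0) println(s"would request $missing more container(s)")
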
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 1b6bfb42a5c1c..393edd1f2d670 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.yarn.api._
import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
import org.apache.hadoop.yarn.api.protocolrecords._
import org.apache.hadoop.yarn.api.records._
-import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl
+import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration
import org.apache.hadoop.yarn.ipc.YarnRPC
import org.apache.hadoop.yarn.util.{Apps, Records}
@@ -37,7 +37,9 @@ import org.apache.spark.{Logging, SparkConf}
* Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API.
*/
class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf)
- extends YarnClientImpl with ClientBase with Logging {
+ extends ClientBase with Logging {
+
+ val yarnClient = YarnClient.createYarnClient
def this(clientArgs: ClientArguments, spConf: SparkConf) =
this(clientArgs, new Configuration(), spConf)
@@ -53,8 +55,8 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def runApp(): ApplicationId = {
validateArgs()
// Initialize and start the client service.
- init(yarnConf)
- start()
+ yarnClient.init(yarnConf)
+ yarnClient.start()
// Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers).
logClusterResourceDetails()
@@ -63,7 +65,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
// interface).
// Get a new client application.
- val newApp = super.createApplication()
+ val newApp = yarnClient.createApplication()
val newAppResponse = newApp.getNewApplicationResponse()
val appId = newAppResponse.getApplicationId()
@@ -99,11 +101,11 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
}
def logClusterResourceDetails() {
- val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics
+ val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics
logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " +
clusterMetrics.getNumNodeManagers)
- val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue)
+ val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue)
logInfo( """Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s,
queueApplicationCount = %s, queueChildQueueCount = %s""".format(
queueInfo.getQueueName,
@@ -132,15 +134,20 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa
def submitApp(appContext: ApplicationSubmissionContext) = {
// Submit the application to the applications manager.
logInfo("Submitting application to ASM")
- super.submitApplication(appContext)
+ yarnClient.submitApplication(appContext)
}
+ def getApplicationReport(appId: ApplicationId) =
+ yarnClient.getApplicationReport(appId)
+
+ def stop = yarnClient.stop
+
def monitorApplication(appId: ApplicationId): Boolean = {
val interval = sparkConf.getLong("spark.yarn.report.interval", 1000)
while (true) {
Thread.sleep(interval)
- val report = super.getApplicationReport(appId)
+ val report = yarnClient.getApplicationReport(appId)
logInfo("Application report from ASM: \n" +
"\t application identifier: " + appId.toString() + "\n" +
diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
index a4ce8766d347c..d93e5bb0225d5 100644
--- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
+++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala
@@ -200,17 +200,25 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
logInfo("Allocating " + args.numExecutors + " executors.")
// Wait until all containers have finished
- // TODO: This is a bit ugly. Can we make it nicer?
- // TODO: Handle container failure
-
yarnAllocator.addResourceRequests(args.numExecutors)
+ yarnAllocator.allocateResources()
while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) {
+ allocateMissingExecutor()
yarnAllocator.allocateResources()
Thread.sleep(100)
}
logInfo("All executors have launched.")
+ }
+ private def allocateMissingExecutor() {
+ val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning -
+ yarnAllocator.getNumPendingAllocate
+ if (missingExecutorCount > 0) {
+ logInfo("Allocating %d containers to make up for (potentially) lost containers".
+ format(missingExecutorCount))
+ yarnAllocator.addResourceRequests(missingExecutorCount)
+ }
}
// TODO: We might want to extend this to allocate more containers in case they die !
@@ -220,13 +228,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp
val t = new Thread {
override def run() {
while (!driverClosed) {
- val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning -
- yarnAllocator.getNumPendingAllocate
- if (missingExecutorCount > 0) {
- logInfo("Allocating %d containers to make up for (potentially) lost containers".
- format(missingExecutorCount))
- yarnAllocator.addResourceRequests(missingExecutorCount)
- }
+ allocateMissingExecutor()
sendProgress()
Thread.sleep(sleepTime)
}