diff --git a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala index 110bd0a9a0c41..55241d33cd3f0 100644 --- a/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala +++ b/bagel/src/test/scala/org/apache/spark/bagel/BagelSuite.scala @@ -80,7 +80,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo test("large number of iterations") { // This tests whether jobs with a large number of iterations finish in a reasonable time, // because non-memoized recursion in RDD or DAGScheduler used to cause them to hang - failAfter(10 seconds) { + failAfter(30 seconds) { sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) @@ -101,7 +101,7 @@ class BagelSuite extends FunSuite with Assertions with BeforeAndAfter with Timeo sc = new SparkContext("local", "test") val verts = sc.parallelize((1 to 4).map(id => (id.toString, new TestVertex(true, 0)))) val msgs = sc.parallelize(Array[(String, TestMessage)]()) - val numSupersteps = 50 + val numSupersteps = 20 val result = Bagel.run(sc, verts, msgs, sc.defaultParallelism, StorageLevel.DISK_ONLY) { (self: TestVertex, msgs: Option[Array[TestMessage]], superstep: Int) => diff --git a/bin/pyspark b/bin/pyspark index d0fa56f31913f..114cbbc3a8a8e 100755 --- a/bin/pyspark +++ b/bin/pyspark @@ -86,6 +86,10 @@ else if [[ "$IPYTHON" = "1" ]]; then exec ipython $IPYTHON_OPTS else - exec "$PYSPARK_PYTHON" + if [[ -n $SPARK_TESTING ]]; then + exec "$PYSPARK_PYTHON" -m doctest + else + exec "$PYSPARK_PYTHON" + fi fi fi diff --git a/bin/run-example b/bin/run-example index 7caab31daef39..e7a5fe3914fbd 100755 --- a/bin/run-example +++ b/bin/run-example @@ -51,7 +51,7 @@ if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then EXAMPLE_CLASS="org.apache.spark.examples.$EXAMPLE_CLASS" fi -./bin/spark-submit \ +"$FWDIR"/bin/spark-submit \ --master $EXAMPLE_MASTER \ --class $EXAMPLE_CLASS \ "$SPARK_EXAMPLES_JAR" \ diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 9155159cf6aeb..e7f75481939a8 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -83,11 +83,17 @@ class HashPartitioner(partitions: Int) extends Partitioner { case _ => false } + + override def hashCode: Int = numPartitions } /** * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly * equal ranges. The ranges are determined by sampling the content of the RDD passed in. + * + * Note that the actual number of partitions created by the RangePartitioner might not be the same + * as the `partitions` parameter, in the case where the number of sampled records is less than + * the value of `partitions`. 
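A minimal Scala sketch of what the two partitioner changes above provide: partitioners that compare equal now also hash equally, and a RangePartitioner built from fewer sampled records than requested partitions reports a smaller numPartitions. The standalone driver below is illustrative only; the class names are real Spark classes.

import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._

object PartitionerSketch {
  def main(args: Array[String]): Unit = {
    // Equal partitioners must produce equal hash codes, so that hash-based
    // lookups involving partitioners stay consistent with equals().
    val p1 = new HashPartitioner(8)
    val p2 = new HashPartitioner(8)
    assert(p1 == p2 && p1.hashCode == p2.hashCode)

    // Sorting 3 records into 10 requested partitions: the RangePartitioner
    // samples only 3 keys, so fewer than 10 partitions may be created.
    val sc = new SparkContext("local", "partitioner-sketch")
    val sorted = sc.parallelize(Seq((3, "c"), (1, "a"), (2, "b"))).sortByKey(true, 10)
    println(sorted.partitions.length)  // at most rangeBounds.length + 1
    sc.stop()
  }
}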
*/ class RangePartitioner[K : Ordering : ClassTag, V]( partitions: Int, @@ -119,7 +125,7 @@ class RangePartitioner[K : Ordering : ClassTag, V]( } } - def numPartitions = partitions + def numPartitions = rangeBounds.length + 1 private val binarySearch: ((Array[K], K) => Int) = CollectionsUtils.makeBinarySearch[K] @@ -155,4 +161,16 @@ class RangePartitioner[K : Ordering : ClassTag, V]( case _ => false } + + override def hashCode(): Int = { + val prime = 31 + var result = 1 + var i = 0 + while (i < rangeBounds.length) { + result = prime * result + rangeBounds(i).hashCode + i += 1 + } + result = prime * result + ascending.hashCode + result + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d941aea9d7eb2..d721aba709600 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -455,7 +455,7 @@ class SparkContext(config: SparkConf) extends Logging { */ def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = { hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], - minPartitions).map(pair => pair._2.toString) + minPartitions).map(pair => pair._2.toString).setName(path) } /** @@ -496,7 +496,7 @@ class SparkContext(config: SparkConf) extends Logging { classOf[String], classOf[String], updateConf, - minPartitions) + minPartitions).setName(path) } /** @@ -551,7 +551,7 @@ class SparkContext(config: SparkConf) extends Logging { inputFormatClass, keyClass, valueClass, - minPartitions) + minPartitions).setName(path) } /** @@ -623,7 +623,7 @@ class SparkContext(config: SparkConf) extends Logging { val job = new NewHadoopJob(conf) NewFileInputFormat.addInputPath(job, new Path(path)) val updatedConf = job.getConfiguration - new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf) + new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } /** diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala new file mode 100644 index 0000000000000..adaa1ef6cf9ff --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.io._ +import scala.util.{Failure, Success, Try} +import org.apache.spark.annotation.Experimental + + +/** + * :: Experimental :: + * A trait for use with reading custom classes in PySpark. 
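An implementation of this trait might look like the following sketch (the converter class and its use case are hypothetical; the trait itself is declared just below). The fully qualified class name would be passed from Python as the keyConverter or valueConverter argument.

import java.text.SimpleDateFormat
import java.util.Date

import org.apache.hadoop.io.LongWritable
import org.apache.spark.api.python.Converter

// Hypothetical converter: hand epoch-millisecond LongWritable keys to Python
// as formatted date strings rather than raw longs.
class EpochMillisToStringConverter extends Converter[Any, String] {
  override def convert(obj: Any): String = {
    val millis = obj.asInstanceOf[LongWritable].get()
    new SimpleDateFormat("yyyy-MM-dd HH:mm:ss").format(new Date(millis))
  }
}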
Implement this trait and add custom + * transformation code by overriding the convert method. + */ +@Experimental +trait Converter[T, U] extends Serializable { + def convert(obj: T): U +} + +private[python] object Converter extends Logging { + + def getInstance(converterClass: Option[String]): Converter[Any, Any] = { + converterClass.map { cc => + Try { + val c = Class.forName(cc).newInstance().asInstanceOf[Converter[Any, Any]] + logInfo(s"Loaded converter: $cc") + c + } match { + case Success(c) => c + case Failure(err) => + logError(s"Failed to load converter: $cc") + throw err + } + }.getOrElse { new DefaultConverter } + } +} + +/** + * A converter that handles conversion of common [[org.apache.hadoop.io.Writable]] objects. + * Other objects are passed through without conversion. + */ +private[python] class DefaultConverter extends Converter[Any, Any] { + + /** + * Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or + * object representation + */ + private def convertWritable(writable: Writable): Any = { + import collection.JavaConversions._ + writable match { + case iw: IntWritable => iw.get() + case dw: DoubleWritable => dw.get() + case lw: LongWritable => lw.get() + case fw: FloatWritable => fw.get() + case t: Text => t.toString + case bw: BooleanWritable => bw.get() + case byw: BytesWritable => byw.getBytes + case n: NullWritable => null + case aw: ArrayWritable => aw.get().map(convertWritable(_)) + case mw: MapWritable => mapAsJavaMap(mw.map { case (k, v) => + (convertWritable(k), convertWritable(v)) + }.toMap) + case other => other + } + } + + def convert(obj: Any): Any = { + obj match { + case writable: Writable => + convertWritable(writable) + case _ => + obj + } + } +} + +/** Utilities for working with Python objects <-> Hadoop-related objects */ +private[python] object PythonHadoopUtil { + + /** + * Convert a [[java.util.Map]] of properties to a [[org.apache.hadoop.conf.Configuration]] + */ + def mapToConf(map: java.util.Map[String, String]): Configuration = { + import collection.JavaConversions._ + val conf = new Configuration() + map.foreach{ case (k, v) => conf.set(k, v) } + conf + } + + /** + * Merges two configurations, returns a copy of left with keys from right overwriting + * any matching keys in left + */ + def mergeConfs(left: Configuration, right: Configuration): Configuration = { + import collection.JavaConversions._ + val copy = new Configuration(left) + right.iterator().foreach(entry => copy.set(entry.getKey, entry.getValue)) + copy + } + + /** + * Converts an RDD of key-value pairs, where key and/or value could be instances of + * [[org.apache.hadoop.io.Writable]], into an RDD[(K, V)] + */ + def convertRDD[K, V](rdd: RDD[(K, V)], + keyConverter: Converter[Any, Any], + valueConverter: Converter[Any, Any]): RDD[(Any, Any)] = { + rdd.map { case (k, v) => (keyConverter.convert(k), valueConverter.convert(v)) } + } + +} diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala index 95bec5030bfdd..e230d222b8604 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonPartitioner.scala @@ -50,4 +50,6 @@ private[spark] class PythonPartitioner( case _ => false } + + override def hashCode: Int = 31 * numPartitions + pyPartitionFunctionId.hashCode } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala 
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d1df99300c5b1..f6570d335757a 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -28,6 +28,9 @@ import scala.util.Try import net.razorvine.pickle.{Pickler, Unpickler} +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapred.{InputFormat, JobConf} +import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import org.apache.spark._ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast @@ -266,7 +269,7 @@ private object SpecialLengths { val TIMING_DATA = -3 } -private[spark] object PythonRDD { +private[spark] object PythonRDD extends Logging { val UTF8 = Charset.forName("UTF-8") /** @@ -346,6 +349,180 @@ private[spark] object PythonRDD { } } + /** + * Create an RDD from a path using [[org.apache.hadoop.mapred.SequenceFileInputFormat]], + * key and value class. + * A key and/or value converter class can optionally be passed in + * (see [[org.apache.spark.api.python.Converter]]) + */ + def sequenceFile[K, V]( + sc: JavaSparkContext, + path: String, + keyClassMaybeNull: String, + valueClassMaybeNull: String, + keyConverterClass: String, + valueConverterClass: String, + minSplits: Int) = { + val keyClass = Option(keyClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") + val valueClass = Option(valueClassMaybeNull).getOrElse("org.apache.hadoop.io.Text") + implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] + implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] + val kc = kcm.runtimeClass.asInstanceOf[Class[K]] + val vc = vcm.runtimeClass.asInstanceOf[Class[V]] + + val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits) + val keyConverter = Converter.getInstance(Option(keyConverterClass)) + val valueConverter = Converter.getInstance(Option(valueConverterClass)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + } + + /** + * Create an RDD from a file path, using an arbitrary [[org.apache.hadoop.mapreduce.InputFormat]], + * key and value class. + * A key and/or value converter class can optionally be passed in + * (see [[org.apache.spark.api.python.Converter]]) + */ + def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]]( + sc: JavaSparkContext, + path: String, + inputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val baseConf = sc.hadoopConfiguration() + val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + val rdd = + newAPIHadoopRDDFromClassNames[K, V, F](sc, + Some(path), inputFormatClass, keyClass, valueClass, mergedConf) + val keyConverter = Converter.getInstance(Option(keyConverterClass)) + val valueConverter = Converter.getInstance(Option(valueConverterClass)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + } + + /** + * Create an RDD from a [[org.apache.hadoop.conf.Configuration]] converted from a map that is + * passed in from Python, using an arbitrary [[org.apache.hadoop.mapreduce.InputFormat]], + * key and value class. 
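These JVM entry points back the PySpark methods added to python/pyspark/context.py later in this diff. A minimal PySpark usage sketch (paths are illustrative; the Writable and InputFormat class names are standard Hadoop classes):

from pyspark import SparkContext

sc = SparkContext("local", "pyspark-input-sketch")

# A SequenceFile of (IntWritable, Text) arrives in Python as (int, unicode) pairs.
ints = sc.sequenceFile("/tmp/sftestdata/sfint/",
                       "org.apache.hadoop.io.IntWritable",
                       "org.apache.hadoop.io.Text")
print ints.first()

# Any 'new API' InputFormat can be read the same way, e.g. plain text files.
lines = sc.newAPIHadoopFile("/tmp/hello.txt",
                            "org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
                            "org.apache.hadoop.io.LongWritable",
                            "org.apache.hadoop.io.Text")
print lines.first()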
+ * A key and/or value converter class can optionally be passed in + * (see [[org.apache.spark.api.python.Converter]]) + */ + def newAPIHadoopRDD[K, V, F <: NewInputFormat[K, V]]( + sc: JavaSparkContext, + inputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val rdd = + newAPIHadoopRDDFromClassNames[K, V, F](sc, + None, inputFormatClass, keyClass, valueClass, conf) + val keyConverter = Converter.getInstance(Option(keyConverterClass)) + val valueConverter = Converter.getInstance(Option(valueConverterClass)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + } + + private def newAPIHadoopRDDFromClassNames[K, V, F <: NewInputFormat[K, V]]( + sc: JavaSparkContext, + path: Option[String] = None, + inputFormatClass: String, + keyClass: String, + valueClass: String, + conf: Configuration) = { + implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] + implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] + implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] + val kc = kcm.runtimeClass.asInstanceOf[Class[K]] + val vc = vcm.runtimeClass.asInstanceOf[Class[V]] + val fc = fcm.runtimeClass.asInstanceOf[Class[F]] + val rdd = if (path.isDefined) { + sc.sc.newAPIHadoopFile[K, V, F](path.get, fc, kc, vc, conf) + } else { + sc.sc.newAPIHadoopRDD[K, V, F](conf, fc, kc, vc) + } + rdd + } + + /** + * Create an RDD from a file path, using an arbitrary [[org.apache.hadoop.mapred.InputFormat]], + * key and value class. + * A key and/or value converter class can optionally be passed in + * (see [[org.apache.spark.api.python.Converter]]) + */ + def hadoopFile[K, V, F <: InputFormat[K, V]]( + sc: JavaSparkContext, + path: String, + inputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val baseConf = sc.hadoopConfiguration() + val mergedConf = PythonHadoopUtil.mergeConfs(baseConf, conf) + val rdd = + hadoopRDDFromClassNames[K, V, F](sc, + Some(path), inputFormatClass, keyClass, valueClass, mergedConf) + val keyConverter = Converter.getInstance(Option(keyConverterClass)) + val valueConverter = Converter.getInstance(Option(valueConverterClass)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + } + + /** + * Create an RDD from a [[org.apache.hadoop.conf.Configuration]] converted from a map + * that is passed in from Python, using an arbitrary [[org.apache.hadoop.mapred.InputFormat]], + * key and value class + * A key and/or value converter class can optionally be passed in + * (see [[org.apache.spark.api.python.Converter]]) + */ + def hadoopRDD[K, V, F <: InputFormat[K, V]]( + sc: JavaSparkContext, + inputFormatClass: String, + keyClass: String, + valueClass: String, + keyConverterClass: String, + valueConverterClass: String, + confAsMap: java.util.HashMap[String, String]) = { + val conf = PythonHadoopUtil.mapToConf(confAsMap) + val rdd = + hadoopRDDFromClassNames[K, V, F](sc, + None, inputFormatClass, keyClass, valueClass, conf) + val keyConverter = Converter.getInstance(Option(keyConverterClass)) + val 
valueConverter = Converter.getInstance(Option(valueConverterClass)) + val converted = PythonHadoopUtil.convertRDD[K, V](rdd, keyConverter, valueConverter) + JavaRDD.fromRDD(SerDeUtil.rddToPython(converted)) + } + + private def hadoopRDDFromClassNames[K, V, F <: InputFormat[K, V]]( + sc: JavaSparkContext, + path: Option[String] = None, + inputFormatClass: String, + keyClass: String, + valueClass: String, + conf: Configuration) = { + implicit val kcm = ClassTag(Class.forName(keyClass)).asInstanceOf[ClassTag[K]] + implicit val vcm = ClassTag(Class.forName(valueClass)).asInstanceOf[ClassTag[V]] + implicit val fcm = ClassTag(Class.forName(inputFormatClass)).asInstanceOf[ClassTag[F]] + val kc = kcm.runtimeClass.asInstanceOf[Class[K]] + val vc = vcm.runtimeClass.asInstanceOf[Class[V]] + val fc = fcm.runtimeClass.asInstanceOf[Class[F]] + val rdd = if (path.isDefined) { + sc.sc.hadoopFile(path.get, fc, kc, vc) + } else { + sc.sc.hadoopRDD(new JobConf(conf), fc, kc, vc) + } + rdd + } + def writeUTF(str: String, dataOut: DataOutputStream) { val bytes = str.getBytes(UTF8) dataOut.writeInt(bytes.length) diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala new file mode 100644 index 0000000000000..9a012e7254901 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import scala.util.Try +import org.apache.spark.rdd.RDD +import org.apache.spark.Logging +import scala.util.Success +import scala.util.Failure +import net.razorvine.pickle.Pickler + + +/** Utilities for serialization / deserialization between Python and Java, using Pickle. */ +private[python] object SerDeUtil extends Logging { + + private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = { + val pickle = new Pickler + val kt = Try { + pickle.dumps(t._1) + } + val vt = Try { + pickle.dumps(t._2) + } + (kt, vt) match { + case (Failure(kf), Failure(vf)) => + logWarning(s""" + |Failed to pickle Java object as key: ${t._1.getClass.getSimpleName}, falling back + |to 'toString'. Error: ${kf.getMessage}""".stripMargin) + logWarning(s""" + |Failed to pickle Java object as value: ${t._2.getClass.getSimpleName}, falling back + |to 'toString'. Error: ${vf.getMessage}""".stripMargin) + (true, true) + case (Failure(kf), _) => + logWarning(s""" + |Failed to pickle Java object as key: ${t._1.getClass.getSimpleName}, falling back + |to 'toString'. Error: ${kf.getMessage}""".stripMargin) + (true, false) + case (_, Failure(vf)) => + logWarning(s""" + |Failed to pickle Java object as value: ${t._2.getClass.getSimpleName}, falling back + |to 'toString'. 
Error: ${vf.getMessage}""".stripMargin) + (false, true) + case _ => + (false, false) + } + } + + /** + * Convert an RDD of key-value pairs to an RDD of serialized Python objects, that is usable + * by PySpark. By default, if serialization fails, toString is called and the string + * representation is serialized + */ + def rddToPython(rdd: RDD[(Any, Any)]): RDD[Array[Byte]] = { + val (keyFailed, valueFailed) = checkPickle(rdd.first()) + rdd.mapPartitions { iter => + val pickle = new Pickler + iter.map { case (k, v) => + if (keyFailed && valueFailed) { + pickle.dumps(Array(k.toString, v.toString)) + } else if (keyFailed) { + pickle.dumps(Array(k.toString, v)) + } else if (!keyFailed && valueFailed) { + pickle.dumps(Array(k, v.toString)) + } else { + pickle.dumps(Array(k, v)) + } + } + } + } + +} + diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala new file mode 100644 index 0000000000000..f0e3fb9aff5a0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -0,0 +1,150 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.api.python + +import org.apache.spark.SparkContext +import org.apache.hadoop.io._ +import scala.Array +import java.io.{DataOutput, DataInput} +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat +import org.apache.spark.api.java.JavaSparkContext + +/** + * A class to test MsgPack serialization on the Scala side, that will be deserialized + * in Python + * @param str + * @param int + * @param double + */ +case class TestWritable(var str: String, var int: Int, var double: Double) extends Writable { + def this() = this("", 0, 0.0) + + def getStr = str + def setStr(str: String) { this.str = str } + def getInt = int + def setInt(int: Int) { this.int = int } + def getDouble = double + def setDouble(double: Double) { this.double = double } + + def write(out: DataOutput) = { + out.writeUTF(str) + out.writeInt(int) + out.writeDouble(double) + } + + def readFields(in: DataInput) = { + str = in.readUTF() + int = in.readInt() + double = in.readDouble() + } +} + +class TestConverter extends Converter[Any, Any] { + import collection.JavaConversions._ + override def convert(obj: Any) = { + val m = obj.asInstanceOf[MapWritable] + seqAsJavaList(m.keySet.map(w => w.asInstanceOf[DoubleWritable].get()).toSeq) + } +} + +/** + * This object contains method to generate SequenceFile test data and write it to a + * given directory (probably a temp directory) + */ +object WriteInputFormatTestDataGenerator { + import SparkContext._ + + def main(args: Array[String]) { + val path = args(0) + val sc = new JavaSparkContext("local[4]", "test-writables") + generateData(path, sc) + } + + def generateData(path: String, jsc: JavaSparkContext) { + val sc = jsc.sc + + val basePath = s"$path/sftestdata/" + val textPath = s"$basePath/sftext/" + val intPath = s"$basePath/sfint/" + val doublePath = s"$basePath/sfdouble/" + val arrPath = s"$basePath/sfarray/" + val mapPath = s"$basePath/sfmap/" + val classPath = s"$basePath/sfclass/" + val bytesPath = s"$basePath/sfbytes/" + val boolPath = s"$basePath/sfbool/" + val nullPath = s"$basePath/sfnull/" + + /* + * Create test data for IntWritable, DoubleWritable, Text, BytesWritable, + * BooleanWritable and NullWritable + */ + val intKeys = Seq((1, "aa"), (2, "bb"), (2, "aa"), (3, "cc"), (2, "bb"), (1, "aa")) + sc.parallelize(intKeys).saveAsSequenceFile(intPath) + sc.parallelize(intKeys.map{ case (k, v) => (k.toDouble, v) }).saveAsSequenceFile(doublePath) + sc.parallelize(intKeys.map{ case (k, v) => (k.toString, v) }).saveAsSequenceFile(textPath) + sc.parallelize(intKeys.map{ case (k, v) => (k, v.getBytes) }).saveAsSequenceFile(bytesPath) + val bools = Seq((1, true), (2, true), (2, false), (3, true), (2, false), (1, false)) + sc.parallelize(bools).saveAsSequenceFile(boolPath) + sc.parallelize(intKeys).map{ case (k, v) => + (new IntWritable(k), NullWritable.get()) + }.saveAsSequenceFile(nullPath) + + // Create test data for ArrayWritable + val data = Seq( + (1, Array(1.0, 2.0, 3.0)), + (2, Array(3.0, 4.0, 5.0)), + (3, Array(4.0, 5.0, 6.0)) + ) + sc.parallelize(data, numSlices = 2) + .map{ case (k, v) => + (new IntWritable(k), new ArrayWritable(classOf[DoubleWritable], v.map(new DoubleWritable(_)))) + }.saveAsNewAPIHadoopFile[SequenceFileOutputFormat[IntWritable, ArrayWritable]](arrPath) + + // Create test data for MapWritable, with keys DoubleWritable and values Text + val mapData = Seq( + (1, Map(2.0 -> "aa")), + (2, Map(3.0 -> "bb")), + (2, Map(1.0 -> "cc")), + (3, Map(2.0 -> "dd")), + (2, Map(1.0 -> "aa")), + (1, Map(3.0 -> "bb")) + ) 
+ sc.parallelize(mapData, numSlices = 2).map{ case (i, m) => + val mw = new MapWritable() + val k = m.keys.head + val v = m.values.head + mw.put(new DoubleWritable(k), new Text(v)) + (new IntWritable(i), mw) + }.saveAsSequenceFile(mapPath) + + // Create test data for arbitrary custom writable TestWritable + val testClass = Seq( + ("1", TestWritable("test1", 123, 54.0)), + ("2", TestWritable("test2", 456, 8762.3)), + ("1", TestWritable("test3", 123, 423.1)), + ("3", TestWritable("test56", 456, 423.5)), + ("2", TestWritable("test2", 123, 5435.2)) + ) + val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) } + rdd.saveAsNewAPIHadoopFile(classPath, + classOf[Text], classOf[TestWritable], + classOf[SequenceFileOutputFormat[Text, TestWritable]]) + } + + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index 153eee3bc5889..f1032ea8dbada 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -360,6 +360,9 @@ private[spark] class SparkSubmitArguments(args: Seq[String]) { | | --executor-memory MEM Memory per executor (e.g. 1000M, 2G) (Default: 1G). | + | --help, -h Show this help message and exit + | --verbose, -v Print additional debug output + | | Spark standalone with cluster deploy mode only: | --driver-cores NUM Cores for driver (Default: 1). | --supervise If given, restarts the driver on failure. diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a41286d3e4a00..9cd79d262ea53 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1024,7 +1024,7 @@ private[spark] class BlockManager( if (blockId.isShuffle) { // Reducer may need to read many local shuffle blocks and will wrap them into Iterators // at the beginning. The wrapping will cost some memory (compression instance - // initialization, etc.). Reducer read shuffle blocks one by one so we could do the + // initialization, etc.). Reducer reads shuffle blocks one by one so we could do the // wrapping lazily to save memory. class LazyProxyIterator(f: => Iterator[Any]) extends Iterator[Any] { lazy val proxy = f diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index a43314f48112f..1b104253d545d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -168,7 +168,7 @@ private[spark] object UIUtils extends Logging { diff --git a/examples/src/main/python/cassandra_inputformat.py b/examples/src/main/python/cassandra_inputformat.py new file mode 100644 index 0000000000000..39fa6b0d22ef5 --- /dev/null +++ b/examples/src/main/python/cassandra_inputformat.py @@ -0,0 +1,79 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import sys + +from pyspark import SparkContext + +""" +Create data in Cassandra fist +(following: https://wiki.apache.org/cassandra/GettingStarted) + +cqlsh> CREATE KEYSPACE test + ... WITH REPLICATION = { 'class' : 'SimpleStrategy', 'replication_factor' : 1 }; +cqlsh> use test; +cqlsh:test> CREATE TABLE users ( + ... user_id int PRIMARY KEY, + ... fname text, + ... lname text + ... ); +cqlsh:test> INSERT INTO users (user_id, fname, lname) + ... VALUES (1745, 'john', 'smith'); +cqlsh:test> INSERT INTO users (user_id, fname, lname) + ... VALUES (1744, 'john', 'doe'); +cqlsh:test> INSERT INTO users (user_id, fname, lname) + ... VALUES (1746, 'john', 'smith'); +cqlsh:test> SELECT * FROM users; + + user_id | fname | lname +---------+-------+------- + 1745 | john | smith + 1744 | john | doe + 1746 | john | smith +""" +if __name__ == "__main__": + if len(sys.argv) != 4: + print >> sys.stderr, """ + Usage: cassandra_inputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/cassandra_inputformat.py + Assumes you have some data in Cassandra already, running on , in and + """ + exit(-1) + + host = sys.argv[1] + keyspace = sys.argv[2] + cf = sys.argv[3] + sc = SparkContext(appName="CassandraInputFormat") + + conf = {"cassandra.input.thrift.address":host, + "cassandra.input.thrift.port":"9160", + "cassandra.input.keyspace":keyspace, + "cassandra.input.columnfamily":cf, + "cassandra.input.partitioner.class":"Murmur3Partitioner", + "cassandra.input.page.row.size":"3"} + cass_rdd = sc.newAPIHadoopRDD( + "org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat", + "java.util.Map", + "java.util.Map", + keyConverter="org.apache.spark.examples.pythonconverters.CassandraCQLKeyConverter", + valueConverter="org.apache.spark.examples.pythonconverters.CassandraCQLValueConverter", + conf=conf) + output = cass_rdd.collect() + for (k, v) in output: + print (k, v) diff --git a/examples/src/main/python/hbase_inputformat.py b/examples/src/main/python/hbase_inputformat.py new file mode 100644 index 0000000000000..3289d9880a0f5 --- /dev/null +++ b/examples/src/main/python/hbase_inputformat.py @@ -0,0 +1,72 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import sys + +from pyspark import SparkContext + +""" +Create test data in HBase first: + +hbase(main):016:0> create 'test', 'f1' +0 row(s) in 1.0430 seconds + +hbase(main):017:0> put 'test', 'row1', 'f1', 'value1' +0 row(s) in 0.0130 seconds + +hbase(main):018:0> put 'test', 'row2', 'f1', 'value2' +0 row(s) in 0.0030 seconds + +hbase(main):019:0> put 'test', 'row3', 'f1', 'value3' +0 row(s) in 0.0050 seconds + +hbase(main):020:0> put 'test', 'row4', 'f1', 'value4' +0 row(s) in 0.0110 seconds + +hbase(main):021:0> scan 'test' +ROW COLUMN+CELL + row1 column=f1:, timestamp=1401883411986, value=value1 + row2 column=f1:, timestamp=1401883415212, value=value2 + row3 column=f1:, timestamp=1401883417858, value=value3 + row4 column=f1:, timestamp=1401883420805, value=value4 +4 row(s) in 0.0240 seconds +""" +if __name__ == "__main__": + if len(sys.argv) != 3: + print >> sys.stderr, """ + Usage: hbase_inputformat + + Run with example jar: + ./bin/spark-submit --driver-class-path /path/to/example/jar /path/to/examples/hbase_inputformat.py
+ Assumes you have some data in HBase already, running on <host>, in <table>
+ """ + exit(-1) + + host = sys.argv[1] + table = sys.argv[2] + sc = SparkContext(appName="HBaseInputFormat") + + conf = {"hbase.zookeeper.quorum": host, "hbase.mapreduce.inputtable": table} + hbase_rdd = sc.newAPIHadoopRDD( + "org.apache.hadoop.hbase.mapreduce.TableInputFormat", + "org.apache.hadoop.hbase.io.ImmutableBytesWritable", + "org.apache.hadoop.hbase.client.Result", + valueConverter="org.apache.spark.examples.pythonconverters.HBaseConverter", + conf=conf) + output = hbase_rdd.collect() + for (k, v) in output: + print (k, v) diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 9a00701f985f0..71f53af68f4d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -33,6 +33,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ + /* Need to create following keyspace and column family in cassandra before running this example Start CQL shell using ./bin/cqlsh and execute following commands diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index a8c338480e6e2..4893b017ed819 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -22,7 +22,7 @@ import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor} import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ -import org.apache.spark.rdd.NewHadoopRDD + object HBaseTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala index b97cb8fb02823..e06f4dcd54442 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/PageRankUtils.scala @@ -124,4 +124,6 @@ class CustomPartitioner(partitions: Int) extends Partitioner { c.numPartitions == numPartitions case _ => false } + + override def hashCode: Int = numPartitions } diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala new file mode 100644 index 0000000000000..29a65c7a5f295 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/CassandraConverters.scala @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.pythonconverters + +import org.apache.spark.api.python.Converter +import java.nio.ByteBuffer +import org.apache.cassandra.utils.ByteBufferUtil +import collection.JavaConversions.{mapAsJavaMap, mapAsScalaMap} + + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra + * output to a Map[String, Int] + */ +class CassandraCQLKeyConverter extends Converter[Any, java.util.Map[String, Int]] { + override def convert(obj: Any): java.util.Map[String, Int] = { + val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] + mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.toInt(bb))) + } +} + +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts Cassandra + * output to a Map[String, String] + */ +class CassandraCQLValueConverter extends Converter[Any, java.util.Map[String, String]] { + override def convert(obj: Any): java.util.Map[String, String] = { + val result = obj.asInstanceOf[java.util.Map[String, ByteBuffer]] + mapAsJavaMap(result.mapValues(bb => ByteBufferUtil.string(bb))) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala similarity index 54% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala rename to examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala index a0d29100f505a..42ae960bd64a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverter.scala @@ -15,31 +15,19 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution +package org.apache.spark.examples.pythonconverters -private[sql] object DebugQuery { - def apply(plan: SparkPlan): SparkPlan = { - val visited = new collection.mutable.HashSet[Long]() - plan transform { - case s: SparkPlan if !visited.contains(s.id) => - visited += s.id - DebugNode(s) - } - } -} +import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.client.Result +import org.apache.hadoop.hbase.util.Bytes -private[sql] case class DebugNode(child: SparkPlan) extends UnaryNode { - def references = Set.empty - def output = child.output - def execute() = { - val childRdd = child.execute() - println( - s""" - |========================= - |${child.simpleString} - |========================= - """.stripMargin) - childRdd.foreach(println(_)) - childRdd +/** + * Implementation of [[org.apache.spark.api.python.Converter]] that converts a HBase Result + * to a String + */ +class HBaseConverter extends Converter[Any, String] { + override def convert(obj: Any): String = { + val result = obj.asInstanceOf[Result] + Bytes.toStringBinary(result.value()) } } diff --git a/pom.xml b/pom.xml index 891468b21bfff..0d46bb4114f73 100644 --- a/pom.xml +++ b/pom.xml @@ -209,14 +209,14 @@ spring-releases Spring Release Repository - http://repo.spring.io/libs-release + http://repo.spring.io/libs-release true false - + @@ -987,11 +987,15 @@ avro + + 0.23.10 + hadoop-2.2 + 2.2.0 2.5.0 @@ -999,6 +1003,7 @@ hadoop-2.3 + 2.3.0 2.5.0 0.9.0 @@ -1007,6 +1012,7 @@ hadoop-2.4 + 2.4.0 2.5.0 0.9.0 diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 211918f5a05ec..062bec2381a8f 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -342,6 +342,143 @@ def wholeTextFiles(self, path, minPartitions=None): return RDD(self._jsc.wholeTextFiles(path, minPartitions), self, PairDeserializer(UTF8Deserializer(), UTF8Deserializer())) + def _dictToJavaMap(self, d): + jm = self._jvm.java.util.HashMap() + if not d: + d = {} + for k, v in d.iteritems(): + jm[k] = v + return jm + + def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None, + valueConverter=None, minSplits=None): + """ + Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS, + a local file system (available on all nodes), or any Hadoop-supported file system URI. + The mechanism is as follows: + 1. A Java RDD is created from the SequenceFile or other InputFormat, and the key + and value Writable classes + 2. Serialization is attempted via Pyrolite pickling + 3. If this fails, the fallback is to call 'toString' on each key and value + 4. C{PickleSerializer} is used to deserialize pickled objects on the Python side + + @param path: path to sequncefile + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.Text") + @param valueClass: fully qualified classname of value Writable class + (e.g. 
"org.apache.hadoop.io.LongWritable") + @param keyConverter: + @param valueConverter: + @param minSplits: minimum splits in dataset + (default min(2, sc.defaultParallelism)) + """ + minSplits = minSplits or min(self.defaultParallelism, 2) + jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass, + keyConverter, valueConverter, minSplits) + return RDD(jrdd, self, PickleSerializer()) + + def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, + valueConverter=None, conf=None): + """ + Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS, + a local file system (available on all nodes), or any Hadoop-supported file system URI. + The mechanism is the same as for sc.sequenceFile. + + A Hadoop configuration can be passed in as a Python dict. This will be converted into a + Configuration in Java + + @param path: path to Hadoop file + @param inputFormatClass: fully qualified classname of Hadoop InputFormat + (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.Text") + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.LongWritable") + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop configuration, passed in as a dict + (None by default) + """ + jconf = self._dictToJavaMap(conf) + jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, jconf) + return RDD(jrdd, self, PickleSerializer()) + + def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, + valueConverter=None, conf=None): + """ + Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary + Hadoop configuration, which is passed in as a Python dict. + This will be converted into a Configuration in Java. + The mechanism is the same as for sc.sequenceFile. + + @param inputFormatClass: fully qualified classname of Hadoop InputFormat + (e.g. "org.apache.hadoop.mapreduce.lib.input.TextInputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.Text") + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.LongWritable") + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop configuration, passed in as a dict + (None by default) + """ + jconf = self._dictToJavaMap(conf) + jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, jconf) + return RDD(jrdd, self, PickleSerializer()) + + def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None, + valueConverter=None, conf=None): + """ + Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS, + a local file system (available on all nodes), or any Hadoop-supported file system URI. + The mechanism is the same as for sc.sequenceFile. + + A Hadoop configuration can be passed in as a Python dict. This will be converted into a + Configuration in Java. + + @param path: path to Hadoop file + @param inputFormatClass: fully qualified classname of Hadoop InputFormat + (e.g. "org.apache.hadoop.mapred.TextInputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. 
"org.apache.hadoop.io.Text") + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.LongWritable") + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop configuration, passed in as a dict + (None by default) + """ + jconf = self._dictToJavaMap(conf) + jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass, + valueClass, keyConverter, valueConverter, jconf) + return RDD(jrdd, self, PickleSerializer()) + + def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None, + valueConverter=None, conf=None): + """ + Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary + Hadoop configuration, which is passed in as a Python dict. + This will be converted into a Configuration in Java. + The mechanism is the same as for sc.sequenceFile. + + @param inputFormatClass: fully qualified classname of Hadoop InputFormat + (e.g. "org.apache.hadoop.mapred.TextInputFormat") + @param keyClass: fully qualified classname of key Writable class + (e.g. "org.apache.hadoop.io.Text") + @param valueClass: fully qualified classname of value Writable class + (e.g. "org.apache.hadoop.io.LongWritable") + @param keyConverter: (None by default) + @param valueConverter: (None by default) + @param conf: Hadoop configuration, passed in as a dict + (None by default) + """ + jconf = self._dictToJavaMap(conf) + jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass, valueClass, + keyConverter, valueConverter, jconf) + return RDD(jrdd, self, PickleSerializer()) + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) return RDD(jrdd, self, input_deserializer) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index ca0a95578fd28..9c69c79236edc 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -250,7 +250,7 @@ def getCheckpointFile(self): def map(self, f, preservesPartitioning=False): """ Return a new RDD by applying a function to each element of this RDD. - + >>> rdd = sc.parallelize(["b", "a", "c"]) >>> sorted(rdd.map(lambda x: (x, 1)).collect()) [('a', 1), ('b', 1), ('c', 1)] @@ -312,6 +312,15 @@ def mapPartitionsWithSplit(self, f, preservesPartitioning=False): "use mapPartitionsWithIndex instead", DeprecationWarning, stacklevel=2) return self.mapPartitionsWithIndex(f, preservesPartitioning) + def getNumPartitions(self): + """ + Returns the number of partitions in RDD + >>> rdd = sc.parallelize([1, 2, 3, 4], 2) + >>> rdd.getNumPartitions() + 2 + """ + return self._jrdd.splits().size() + def filter(self, f): """ Return a new RDD containing only the elements that satisfy a predicate. @@ -413,9 +422,9 @@ def union(self, other): def intersection(self, other): """ - Return the intersection of this RDD and another one. The output will not + Return the intersection of this RDD and another one. The output will not contain any duplicate elements, even if the input RDDs did. - + Note that this method performs a shuffle internally. >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) @@ -571,14 +580,14 @@ def foreachPartition(self, f): """ Applies a function to each partition of this RDD. - >>> def f(iterator): - ... for x in iterator: - ... print x + >>> def f(iterator): + ... for x in iterator: + ... print x ... 
yield None >>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f) """ self.mapPartitions(f).collect() # Force evaluation - + def collect(self): """ Return a list that contains all of the elements in this RDD. @@ -673,7 +682,7 @@ def func(iterator): yield acc return self.mapPartitions(func).fold(zeroValue, combOp) - + def max(self): """ @@ -692,7 +701,7 @@ def min(self): 1.0 """ return self.reduce(min) - + def sum(self): """ Add up the elements in this RDD. @@ -786,7 +795,7 @@ def mergeMaps(m1, m2): m1[k] += v return m1 return self.mapPartitions(countPartition).reduce(mergeMaps) - + def top(self, num): """ Get the top N elements from a RDD. @@ -814,7 +823,7 @@ def merge(a, b): def takeOrdered(self, num, key=None): """ Get the N elements from a RDD ordered in ascending order or as specified - by the optional key function. + by the optional key function. >>> sc.parallelize([10, 1, 2, 9, 3, 4, 5, 6, 7]).takeOrdered(6) [1, 2, 3, 4, 5, 6] @@ -834,7 +843,7 @@ def unKey(x, key_=None): if key_ != None: x = [i[1] for i in x] return x - + def merge(a, b): return next(topNKeyedElems(a + b)) result = self.mapPartitions(lambda i: topNKeyedElems(i, key)).reduce(merge) @@ -1169,12 +1178,12 @@ def _mergeCombiners(iterator): combiners[k] = mergeCombiners(combiners[k], v) return combiners.iteritems() return shuffled.mapPartitions(_mergeCombiners) - + def foldByKey(self, zeroValue, func, numPartitions=None): """ Merge the values for each key using an associative function "func" and a neutral "zeroValue" - which may be added to the result an arbitrary number of times, and must not change - the result (e.g., 0 for addition, or 1 for multiplication.). + which may be added to the result an arbitrary number of times, and must not change + the result (e.g., 0 for addition, or 1 for multiplication.). >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)]) >>> from operator import add @@ -1182,8 +1191,8 @@ def foldByKey(self, zeroValue, func, numPartitions=None): [('a', 2), ('b', 1)] """ return self.combineByKey(lambda v: func(zeroValue, v), func, func, numPartitions) - - + + # TODO: support variant with custom partitioner def groupByKey(self, numPartitions=None): """ @@ -1302,7 +1311,7 @@ def keyBy(self, f): def repartition(self, numPartitions): """ Return a new RDD that has exactly numPartitions partitions. - + Can increase or decrease the level of parallelism in this RDD. Internally, this uses a shuffle to redistribute data. 
If you are decreasing the number of partitions in this RDD, consider using `coalesce`, diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 1f2a6ea941cf2..184ee810b861b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -198,6 +198,151 @@ def func(x): self.sc.parallelize([1]).foreach(func) +class TestInputFormat(PySparkTestCase): + + def setUp(self): + PySparkTestCase.setUp(self) + self.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(self.tempdir.name) + self.sc._jvm.WriteInputFormatTestDataGenerator.generateData(self.tempdir.name, self.sc._jsc) + + def tearDown(self): + PySparkTestCase.tearDown(self) + shutil.rmtree(self.tempdir.name) + + def test_sequencefiles(self): + basepath = self.tempdir.name + ints = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + doubles = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfdouble/", + "org.apache.hadoop.io.DoubleWritable", + "org.apache.hadoop.io.Text").collect()) + ed = [(1.0, u'aa'), (1.0, u'aa'), (2.0, u'aa'), (2.0, u'bb'), (2.0, u'bb'), (3.0, u'cc')] + self.assertEqual(doubles, ed) + + text = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sftext/", + "org.apache.hadoop.io.Text", + "org.apache.hadoop.io.Text").collect()) + et = [(u'1', u'aa'), + (u'1', u'aa'), + (u'2', u'aa'), + (u'2', u'bb'), + (u'2', u'bb'), + (u'3', u'cc')] + self.assertEqual(text, et) + + bools = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfbool/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + eb = [(1, False), (1, True), (2, False), (2, False), (2, True), (3, True)] + self.assertEqual(bools, eb) + + nulls = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfnull/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.BooleanWritable").collect()) + en = [(1, None), (1, None), (2, None), (2, None), (2, None), (3, None)] + self.assertEqual(nulls, en) + + maps = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable").collect()) + em = [(1, {2.0: u'aa'}), + (1, {3.0: u'bb'}), + (2, {1.0: u'aa'}), + (2, {1.0: u'cc'}), + (2, {3.0: u'bb'}), + (3, {2.0: u'dd'})] + self.assertEqual(maps, em) + + clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/", + "org.apache.hadoop.io.Text", + "org.apache.spark.api.python.TestWritable").collect()) + ec = (u'1', + {u'__class__': u'org.apache.spark.api.python.TestWritable', + u'double': 54.0, u'int': 123, u'str': u'test1'}) + self.assertEqual(clazz[0], ec) + + def test_oldhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.hadoopFile(basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + hello = self.sc.hadoopFile(hellopath, + "org.apache.hadoop.mapred.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text").collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newhadoop(self): + basepath = self.tempdir.name + ints = sorted(self.sc.newAPIHadoopFile( 
+ basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text").collect()) + ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')] + self.assertEqual(ints, ei) + + hellopath = os.path.join(SPARK_HOME, "python/test_support/hello.txt") + hello = self.sc.newAPIHadoopFile(hellopath, + "org.apache.hadoop.mapreduce.lib.input.TextInputFormat", + "org.apache.hadoop.io.LongWritable", + "org.apache.hadoop.io.Text").collect() + result = [(0, u'Hello World!')] + self.assertEqual(hello, result) + + def test_newolderror(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.SequenceFileInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_bad_inputs(self): + basepath = self.tempdir.name + self.assertRaises(Exception, lambda: self.sc.sequenceFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.io.NotValidWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.hadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapred.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + self.assertRaises(Exception, lambda: self.sc.newAPIHadoopFile( + basepath + "/sftestdata/sfint/", + "org.apache.hadoop.mapreduce.lib.input.NotValidInputFormat", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.Text")) + + def test_converter(self): + basepath = self.tempdir.name + maps = sorted(self.sc.sequenceFile( + basepath + "/sftestdata/sfmap/", + "org.apache.hadoop.io.IntWritable", + "org.apache.hadoop.io.MapWritable", + valueConverter="org.apache.spark.api.python.TestConverter").collect()) + em = [(1, [2.0]), (1, [3.0]), (2, [1.0]), (2, [1.0]), (2, [3.0]), (3, [2.0])] + self.assertEqual(maps, em) + + class TestDaemon(unittest.TestCase): def connect(self, port): from socket import socket, AF_INET, SOCK_STREAM diff --git a/python/run-tests b/python/run-tests index 36a96121cbc0d..3b4501178c89f 100755 --- a/python/run-tests +++ b/python/run-tests @@ -32,7 +32,8 @@ rm -f unit-tests.log rm -rf metastore warehouse function run_test() { - SPARK_TESTING=0 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log + echo "Running test: $1" + SPARK_TESTING=1 $FWDIR/bin/pyspark $1 2>&1 | tee -a > unit-tests.log FAILED=$((PIPESTATUS[0]||$FAILED)) # Fail and exit on the first test failure. @@ -46,15 +47,17 @@ function run_test() { } +echo "Running PySpark tests. Output is in python/unit-tests.log." 
+ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "pyspark/conf.py" if [ -n "$_RUN_SQL_TESTS" ]; then run_test "pyspark/sql.py" fi -run_test "-m doctest pyspark/broadcast.py" -run_test "-m doctest pyspark/accumulators.py" -run_test "-m doctest pyspark/serializers.py" +run_test "pyspark/broadcast.py" +run_test "pyspark/accumulators.py" +run_test "pyspark/serializers.py" run_test "pyspark/tests.py" run_test "pyspark/mllib/_common.py" run_test "pyspark/mllib/classification.py" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala index a404e7441a1bd..36758f3114e59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala @@ -41,10 +41,25 @@ import org.apache.spark.sql.catalyst.types._ * for a SQL like language should checkout the HiveQL support in the sql/hive sub-project. */ class SqlParser extends StandardTokenParsers with PackratParsers { + def apply(input: String): LogicalPlan = { - phrase(query)(new lexical.Scanner(input)) match { - case Success(r, x) => r - case x => sys.error(x.toString) + // Special-case out set commands since the value fields can be + // complex to handle without RegexParsers. Also this approach + // is clearer for the several possible cases of set commands. + if (input.trim.toLowerCase.startsWith("set")) { + input.trim.drop(3).split("=", 2).map(_.trim) match { + case Array("") => // "set" + SetCommand(None, None) + case Array(key) => // "set key" + SetCommand(Some(key), None) + case Array(key, value) => // "set key=value" + SetCommand(Some(key), Some(value)) + } + } else { + phrase(query)(new lexical.Scanner(input)) match { + case Success(r, x) => r + case x => sys.error(x.toString) + } } } @@ -131,6 +146,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected val OUTER = Keyword("OUTER") protected val RIGHT = Keyword("RIGHT") protected val SELECT = Keyword("SELECT") + protected val SEMI = Keyword("SEMI") protected val STRING = Keyword("STRING") protected val SUM = Keyword("SUM") protected val TRUE = Keyword("TRUE") @@ -168,11 +184,13 @@ class SqlParser extends StandardTokenParsers with PackratParsers { } } - protected lazy val query: Parser[LogicalPlan] = + protected lazy val query: Parser[LogicalPlan] = ( select * ( - UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | - UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } - ) | insert + UNION ~ ALL ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Union(q1, q2) } | + UNION ~ opt(DISTINCT) ^^^ { (q1: LogicalPlan, q2: LogicalPlan) => Distinct(Union(q1, q2)) } + ) + | insert + ) protected lazy val select: Parser[LogicalPlan] = SELECT ~> opt(DISTINCT) ~ projections ~ @@ -241,6 +259,7 @@ class SqlParser extends StandardTokenParsers with PackratParsers { protected lazy val joinType: Parser[JoinType] = INNER ^^^ Inner | + LEFT ~ SEMI ^^^ LeftSemi | LEFT ~ opt(OUTER) ^^^ LeftOuter | RIGHT ~ opt(OUTER) ^^^ RightOuter | FULL ~ opt(OUTER) ^^^ FullOuter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 406ffd6801e98..ccb8245cc2e7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,6 +19,10 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.RightOuter +import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.types._ @@ -34,7 +38,7 @@ object Optimizer extends RuleExecutor[LogicalPlan] { Batch("Filter Pushdown", FixedPoint(100), CombineFilters, PushPredicateThroughProject, - PushPredicateThroughInnerJoin, + PushPredicateThroughJoin, ColumnPruning) :: Nil } @@ -254,28 +258,98 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { /** * Pushes down [[catalyst.plans.logical.Filter Filter]] operators where the `condition` can be - * evaluated using only the attributes of the left or right side of an inner join. Other + * evaluated using only the attributes of the left or right side of a join. Other * [[catalyst.plans.logical.Filter Filter]] conditions are moved into the `condition` of the * [[catalyst.plans.logical.Join Join]]. + * And also Pushes down the join filter, where the `condition` can be evaluated using only the + * attributes of the left or right side of sub query when applicable. + * + * Check https://cwiki.apache.org/confluence/display/Hive/OuterJoinBehavior for more details */ -object PushPredicateThroughInnerJoin extends Rule[LogicalPlan] with PredicateHelper { +object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { + // split the condition expression into 3 parts, + // (canEvaluateInLeftSide, canEvaluateInRightSide, haveToEvaluateWithBothSide) + private def split(condition: Seq[Expression], left: LogicalPlan, right: LogicalPlan) = { + val (leftEvaluateCondition, rest) = + condition.partition(_.references subsetOf left.outputSet) + val (rightEvaluateCondition, commonCondition) = + rest.partition(_.references subsetOf right.outputSet) + + (leftEvaluateCondition, rightEvaluateCondition, commonCondition) + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case f @ Filter(filterCondition, Join(left, right, Inner, joinCondition)) => - val allConditions = - splitConjunctivePredicates(filterCondition) ++ - joinCondition.map(splitConjunctivePredicates).getOrElse(Nil) - - // Split the predicates into those that can be evaluated on the left, right, and those that - // must be evaluated after the join. - val (rightConditions, leftOrJoinConditions) = - allConditions.partition(_.references subsetOf right.outputSet) - val (leftConditions, joinConditions) = - leftOrJoinConditions.partition(_.references subsetOf left.outputSet) - - // Build the new left and right side, optionally with the pushed down filters. 
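For orientation, the three-way `split` helper introduced above reduces to two successive partitions on attribute containment: predicates that reference only left-side attributes, only right-side attributes, or both. A standalone sketch of the same idea in plain Scala, with an invented `Pred` type standing in for Catalyst's `Expression` and plain string sets standing in for `outputSet` (all names here are illustrative, not Spark API):

{{{
// Minimal sketch: partition predicates by which side's columns they reference.
final case class Pred(name: String, refs: Set[String])

def split(conds: Seq[Pred], leftCols: Set[String], rightCols: Set[String])
    : (Seq[Pred], Seq[Pred], Seq[Pred]) = {
  val (leftOnly, rest)    = conds.partition(_.refs.subsetOf(leftCols))
  val (rightOnly, common) = rest.partition(_.refs.subsetOf(rightCols))
  (leftOnly, rightOnly, common)
}

// Example: x.b = 1 goes left, y.b = 2 goes right, x.c = y.c must stay with the join.
val (l, r, c) = split(
  Seq(
    Pred("x.b = 1",   Set("x.b")),
    Pred("y.b = 2",   Set("y.b")),
    Pred("x.c = y.c", Set("x.c", "y.c"))),
  leftCols  = Set("x.a", "x.b", "x.c"),
  rightCols = Set("y.a", "y.b", "y.c"))
}}}

In the rule itself, the single-side groups become Filter nodes pushed beneath the join, while the common group either stays in the join condition (inner joins) or is re-applied above the new join (the outer-join cases handled below).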
- val newLeft = leftConditions.reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) - val newRight = rightConditions.reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) - Join(newLeft, newRight, Inner, joinConditions.reduceLeftOption(And)) + // push the where condition down into join filter + case f @ Filter(filterCondition, Join(left, right, joinType, joinCondition)) => + val (leftFilterConditions, rightFilterConditions, commonFilterCondition) = + split(splitConjunctivePredicates(filterCondition), left, right) + + joinType match { + case Inner => + // push down the single side `where` condition into respective sides + val newLeft = leftFilterConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = rightFilterConditions. + reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newJoinCond = (commonFilterCondition ++ joinCondition).reduceLeftOption(And) + + Join(newLeft, newRight, Inner, newJoinCond) + case RightOuter => + // push down the right side only `where` condition + val newLeft = left + val newRight = rightFilterConditions. + reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newJoinCond = joinCondition + val newJoin = Join(newLeft, newRight, RightOuter, newJoinCond) + + (leftFilterConditions ++ commonFilterCondition). + reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) + case _ @ (LeftOuter | LeftSemi) => + // push down the left side only `where` condition + val newLeft = leftFilterConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = right + val newJoinCond = joinCondition + val newJoin = Join(newLeft, newRight, joinType, newJoinCond) + + (rightFilterConditions ++ commonFilterCondition). + reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) + case FullOuter => f // DO Nothing for Full Outer Join + } + + // push down the join filter into sub query scanning if applicable + case f @ Join(left, right, joinType, joinCondition) => + val (leftJoinConditions, rightJoinConditions, commonJoinCondition) = + split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) + + joinType match { + case Inner => + // push down the single side only join filter for both sides sub queries + val newLeft = leftJoinConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = rightJoinConditions. + reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newJoinCond = commonJoinCondition.reduceLeftOption(And) + + Join(newLeft, newRight, Inner, newJoinCond) + case RightOuter => + // push down the left side only join filter for left side sub query + val newLeft = leftJoinConditions. + reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) + val newRight = right + val newJoinCond = (rightJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + + Join(newLeft, newRight, RightOuter, newJoinCond) + case _ @ (LeftOuter | LeftSemi) => + // push down the right side only join filter for right sub query + val newLeft = left + val newRight = rightJoinConditions. 
+ reduceLeftOption(And).map(Filter(_, right)).getOrElse(right) + val newJoinCond = (leftJoinConditions ++ commonJoinCondition).reduceLeftOption(And) + + Join(newLeft, newRight, joinType, newJoinCond) + case FullOuter => f + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 4544b32958c7e..820ecfb78b52e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -119,6 +119,11 @@ object HashFilteredJoin extends Logging with PredicateHelper { case FilteredOperation(predicates, join @ Join(left, right, Inner, condition)) => logger.debug(s"Considering hash inner join on: ${predicates ++ condition}") splitPredicates(predicates ++ condition, join) + // All predicates can be evaluated for left semi join (those that are in the WHERE + // clause can only from left table, so they can all be pushed down.) + case FilteredOperation(predicates, join @ Join(left, right, LeftSemi, condition)) => + logger.debug(s"Considering hash left semi join on: ${predicates ++ condition}") + splitPredicates(predicates ++ condition, join) case join @ Join(left, right, joinType, condition) => logger.debug(s"Considering hash join on: $condition") splitPredicates(condition.toSeq, join) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index ae8d7d3e4257f..613f4bb09daf5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -22,3 +22,4 @@ case object Inner extends JoinType case object LeftOuter extends JoinType case object RightOuter extends JoinType case object FullOuter extends JoinType +case object LeftSemi extends JoinType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 2b8fbdcde9d37..7eeb98aea6368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.types.StructType +import org.apache.spark.sql.catalyst.types.{StringType, StructType} import org.apache.spark.sql.catalyst.trees abstract class LogicalPlan extends QueryPlan[LogicalPlan] { @@ -102,7 +102,7 @@ abstract class LeafNode extends LogicalPlan with trees.LeafNode[LogicalPlan] { */ abstract class Command extends LeafNode { self: Product => - def output = Seq.empty + def output: Seq[Attribute] = Seq.empty // TODO: SPARK-2081 should fix this } /** @@ -111,11 +111,23 @@ abstract class Command extends LeafNode { */ case class NativeCommand(cmd: String) extends Command +/** + * Commands of the form "SET (key) (= value)". 
+ */ +case class SetCommand(key: Option[String], value: Option[String]) extends Command { + override def output = Seq( + AttributeReference("key", StringType, nullable = false)(), + AttributeReference("value", StringType, nullable = false)() + ) +} + /** * Returned by a parser when the users only wants to see what query plan would be executed, without * actually performing the execution. */ -case class ExplainCommand(plan: LogicalPlan) extends Command +case class ExplainCommand(plan: LogicalPlan) extends Command { + override def output = Seq(AttributeReference("plan", StringType, nullable = false)()) +} /** * A logical plan node with single child. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 732708e146b04..d3347b622f3d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.JoinType +import org.apache.spark.sql.catalyst.plans.{LeftSemi, JoinType} import org.apache.spark.sql.catalyst.types._ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { @@ -81,7 +81,12 @@ case class Join( condition: Option[Expression]) extends BinaryNode { def references = condition.map(_.references).getOrElse(Set.empty) - def output = left.output ++ right.output + def output = joinType match { + case LeftSemi => + left.output + case _ => + left.output ++ right.output + } } case class InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index ef47850455a37..02cc665f8a8c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -20,11 +20,14 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateAnalysisOperators import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.Inner +import org.apache.spark.sql.catalyst.plans.FullOuter +import org.apache.spark.sql.catalyst.plans.LeftOuter +import org.apache.spark.sql.catalyst.plans.RightOuter import org.apache.spark.sql.catalyst.rules._ - -/* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.junit.Test class FilterPushdownSuite extends OptimizerTest { @@ -35,7 +38,7 @@ class FilterPushdownSuite extends OptimizerTest { Batch("Filter Pushdown", Once, CombineFilters, PushPredicateThroughProject, - PushPredicateThroughInnerJoin) :: Nil + PushPredicateThroughJoin) :: Nil } val testRelation = LocalRelation('a.int, 'b.int, 'c.int) @@ -161,6 +164,184 @@ class FilterPushdownSuite extends OptimizerTest { comparePlans(optimized, correctAnswer) } + + test("joins: push down left outer join #1") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, LeftOuter) + .where("x.b".attr === 1 && 
"y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 1) + val correctAnswer = + left.join(y, LeftOuter).where("y.b".attr === 2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down right outer join #1") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, RightOuter) + .where("x.b".attr === 1 && "y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val right = testRelation.where('b === 2).subquery('d) + val correctAnswer = + x.join(right, RightOuter).where("x.b".attr === 1).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down left outer join #2") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, LeftOuter, Some("x.b".attr === 1)) + .where("x.b".attr === 2 && "y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 2).subquery('d) + val correctAnswer = + left.join(y, LeftOuter, Some("d.b".attr === 1)).where("y.b".attr === 2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down right outer join #2") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, RightOuter, Some("y.b".attr === 1)) + .where("x.b".attr === 2 && "y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val right = testRelation.where('b === 2).subquery('d) + val correctAnswer = + x.join(right, RightOuter, Some("d.b".attr === 1)).where("x.b".attr === 2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down left outer join #3") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, LeftOuter, Some("y.b".attr === 1)) + .where("x.b".attr === 2 && "y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 2).subquery('l) + val right = testRelation.where('b === 1).subquery('r) + val correctAnswer = + left.join(right, LeftOuter).where("r.b".attr === 2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down right outer join #3") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, RightOuter, Some("y.b".attr === 1)) + .where("x.b".attr === 2 && "y.b".attr === 2) + } + + val optimized = Optimize(originalQuery.analyze) + val right = testRelation.where('b === 2).subquery('r) + val correctAnswer = + x.join(right, RightOuter, Some("r.b".attr === 1)).where("x.b".attr === 2).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down left outer join #4") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, LeftOuter, Some("y.b".attr === 1)) + .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 2).subquery('l) + val right = testRelation.where('b === 1).subquery('r) + val correctAnswer = + left.join(right, LeftOuter).where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down right outer join #4") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, RightOuter, Some("y.b".attr 
=== 1)) + .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.subquery('l) + val right = testRelation.where('b === 2).subquery('r) + val correctAnswer = + left.join(right, RightOuter, Some("r.b".attr === 1)). + where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down left outer join #5") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, LeftOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) + .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('b === 2).subquery('l) + val right = testRelation.where('b === 1).subquery('r) + val correctAnswer = + left.join(right, LeftOuter, Some("l.a".attr===3)). + where("r.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + + comparePlans(optimized, correctAnswer) + } + + test("joins: push down right outer join #5") { + val x = testRelation.subquery('x) + val y = testRelation.subquery('y) + + val originalQuery = { + x.join(y, RightOuter, Some("y.b".attr === 1 && "x.a".attr === 3)) + .where("x.b".attr === 2 && "y.b".attr === 2 && "x.c".attr === "y.c".attr) + } + + val optimized = Optimize(originalQuery.analyze) + val left = testRelation.where('a === 3).subquery('l) + val right = testRelation.where('b === 2).subquery('r) + val correctAnswer = + left.join(right, RightOuter, Some("r.b".attr === 1)). + where("l.b".attr === 2 && "l.c".attr === "r.c".attr).analyze + + comparePlans(optimized, correctAnswer) + } test("joins: can't push down") { val x = testRelation.subquery('x) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala new file mode 100644 index 0000000000000..b378252ba2f55 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.util.Properties + +import scala.collection.JavaConverters._ + +/** + * SQLConf holds mutable config parameters and hints. These can be set and + * queried either by passing SET commands into Spark SQL's DSL + * functions (sql(), hql(), etc.), or by programmatically using setters and + * getters of this class. This class is thread-safe. + */ +trait SQLConf { + + /** Number of partitions to use for shuffle operators. 
*/ + private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt + + @transient + private val settings = java.util.Collections.synchronizedMap( + new java.util.HashMap[String, String]()) + + def set(props: Properties): Unit = { + props.asScala.foreach { case (k, v) => this.settings.put(k, v) } + } + + def set(key: String, value: String): Unit = { + require(key != null, "key cannot be null") + require(value != null, s"value cannot be null for ${key}") + settings.put(key, value) + } + + def get(key: String): String = { + if (!settings.containsKey(key)) { + throw new NoSuchElementException(key) + } + settings.get(key) + } + + def get(key: String, defaultValue: String): String = { + if (!settings.containsKey(key)) defaultValue else settings.get(key) + } + + def getAll: Array[(String, String)] = settings.asScala.toArray + + def getOption(key: String): Option[String] = { + if (!settings.containsKey(key)) None else Some(settings.get(key)) + } + + def contains(key: String): Boolean = settings.containsKey(key) + + def toDebugString: String = { + settings.synchronized { + settings.asScala.toArray.sorted.map{ case (k, v) => s"$k=$v" }.mkString("\n") + } + } + + private[spark] def clear() { + settings.clear() + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 043be58edc91b..021e0e8245a0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.{ScalaReflection, dsl} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.catalyst.optimizer.Optimizer -import org.apache.spark.sql.catalyst.plans.logical.{Subquery, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{SetCommand, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.columnar.InMemoryColumnarTableScan @@ -52,6 +52,7 @@ import org.apache.spark.sql.parquet.ParquetRelation @AlphaComponent class SQLContext(@transient val sparkContext: SparkContext) extends Logging + with SQLConf with dsl.ExpressionConversions with Serializable { @@ -190,9 +191,13 @@ class SQLContext(@transient val sparkContext: SparkContext) protected[sql] class SparkPlanner extends SparkStrategies { val sparkContext = self.sparkContext + def numPartitions = self.numShufflePartitions + val strategies: Seq[Strategy] = + CommandStrategy(self) :: TakeOrdered :: PartialAggregation :: + LeftSemiJoin :: HashJoin :: ParquetOperations :: BasicOperators :: @@ -244,6 +249,10 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val planner = new SparkPlanner + @transient + protected[sql] lazy val emptyResult = + sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) + /** * Prepares a planned SparkPlan for execution by binding references to specific ordinals, and * inserting shuffle operations as needed. 
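As a quick orientation on the new conf plumbing: SQLContext now mixes in SQLConf, so settings can flow either through the programmatic setters defined above or through SET statements once the parser turns them into SetCommand. A usage sketch, assuming an already-constructed SQLContext bound to a variable named `sqlContext` (the variable name is illustrative):

{{{
// Programmatic access to the SQLConf methods added in this patch.
sqlContext.set("spark.sql.shuffle.partitions", "8")
assert(sqlContext.get("spark.sql.shuffle.partitions") == "8")
assert(sqlContext.get("missing.key", "fallback") == "fallback") // default on an absent key
assert(sqlContext.getOption("missing.key").isEmpty)

// The same store backs SQL-level SET statements; collect() forces execution,
// which routes through the eager SetCommand handling shown above.
sqlContext.sql("SET spark.sql.shuffle.partitions=4").collect()
assert(sqlContext.get("spark.sql.shuffle.partitions") == "4")
}}}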
@@ -251,7 +260,7 @@ class SQLContext(@transient val sparkContext: SparkContext) @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = - Batch("Add exchange", Once, AddExchange) :: + Batch("Add exchange", Once, AddExchange(self)) :: Batch("Prepare Expressions", Once, new BindReferences[SparkPlan]) :: Nil } @@ -262,6 +271,22 @@ class SQLContext(@transient val sparkContext: SparkContext) protected abstract class QueryExecution { def logical: LogicalPlan + def eagerlyProcess(plan: LogicalPlan): RDD[Row] = plan match { + case SetCommand(key, value) => + // Only this case needs to be executed eagerly. The other cases will + // be taken care of when the actual results are being extracted. + // In the case of HiveContext, sqlConf is overridden to also pass the + // pair into its HiveConf. + if (key.isDefined && value.isDefined) { + set(key.get, value.get) + } + // It doesn't matter what we return here, since this is only used + // to force the evaluation to happen eagerly. To query the results, + // one must use SchemaRDD operations to extract them. + emptyResult + case _ => executedPlan.execute() + } + lazy val analyzed = analyzer(logical) lazy val optimizedPlan = optimizer(analyzed) // TODO: Don't just pick the first one... @@ -269,7 +294,12 @@ class SQLContext(@transient val sparkContext: SparkContext) lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ - lazy val toRdd: RDD[Row] = executedPlan.execute() + lazy val toRdd: RDD[Row] = { + logical match { + case s: SetCommand => eagerlyProcess(s) + case _ => executedPlan.execute() + } + } protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } @@ -284,11 +314,6 @@ class SQLContext(@transient val sparkContext: SparkContext) |== Physical Plan == |${stringOrError(executedPlan)} """.stripMargin.trim - - /** - * Runs the query after interposing operators that print the result of each intermediate step. - */ - def debugExec() = DebugQuery(executedPlan).execute().collect() } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala index 604914e547790..34d88fe4bd7de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala @@ -77,8 +77,7 @@ case class Aggregate( resultAttribute: AttributeReference) /** A list of aggregates that need to be computed for each group. */ - @transient - private[this] lazy val computedAggregates = aggregateExpressions.flatMap { agg => + private[this] val computedAggregates = aggregateExpressions.flatMap { agg => agg.collect { case a: AggregateExpression => ComputedAggregate( @@ -89,8 +88,7 @@ case class Aggregate( }.toArray /** The schema of the result of all aggregate evaluations */ - @transient - private[this] lazy val computedSchema = computedAggregates.map(_.resultAttribute) + private[this] val computedSchema = computedAggregates.map(_.resultAttribute) /** Creates a new aggregate buffer for a group. */ private[this] def newAggregateBuffer(): Array[AggregateFunction] = { @@ -104,8 +102,7 @@ case class Aggregate( } /** Named attributes used to substitute grouping attributes into the final result. 
*/ - @transient - private[this] lazy val namedGroups = groupingExpressions.map { + private[this] val namedGroups = groupingExpressions.map { case ne: NamedExpression => ne -> ne.toAttribute case e => e -> Alias(e, s"groupingExpr:$e")().toAttribute } @@ -114,16 +111,14 @@ case class Aggregate( * A map of substitutions that are used to insert the aggregate expressions and grouping * expression into the final result expression. */ - @transient - private[this] lazy val resultMap = + private[this] val resultMap = (computedAggregates.map { agg => agg.unbound -> agg.resultAttribute } ++ namedGroups).toMap /** * Substituted version of aggregateExpressions expressions which are used to compute final * output rows given a group and the result of all aggregate computations. */ - @transient - private[this] lazy val resultExpressions = aggregateExpressions.map { agg => + private[this] val resultExpressions = aggregateExpressions.map { agg => agg.transform { case e: Expression if resultMap.contains(e) => resultMap(e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 3b4acb72e87b5..cef294167f146 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.annotation.DeveloperApi import org.apache.spark.{HashPartitioner, RangePartitioner, SparkConf} import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.sql.Row +import org.apache.spark.sql.{SQLConf, SQLContext, Row} import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions.{MutableProjection, RowOrdering} import org.apache.spark.sql.catalyst.plans.physical._ @@ -86,9 +86,9 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una * [[catalyst.plans.physical.Distribution Distribution]] requirements for each operator by inserting * [[Exchange]] Operators where required. */ -private[sql] object AddExchange extends Rule[SparkPlan] { +private[sql] case class AddExchange(sqlContext: SQLContext) extends Rule[SparkPlan] { // TODO: Determine the number of partitions. 
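The hunk that follows turns AddExchange from a singleton object into a case class carrying the SQLContext, so `numPartitions` becomes a def that re-reads `spark.sql.shuffle.partitions` from SQLConf at planning time instead of a value fixed at 150. A minimal sketch of why the def matters, using invented stand-ins (`Conf`, `AddExchangeLike`) rather than the real SQLConf or rule:

{{{
// Stand-alone illustration; these are toy classes, not Spark API.
class Conf { @volatile var shufflePartitions: Int = 200 }

class AddExchangeLike(conf: Conf) {
  // A def, not a val: the count is looked up on every use, so a later SET command
  // that mutates the conf is honored without rebuilding the planner.
  def numPartitions: Int = conf.shufflePartitions
}

val conf = new Conf
val rule = new AddExchangeLike(conf)
assert(rule.numPartitions == 200)
conf.shufflePartitions = 8
assert(rule.numPartitions == 8) // picks up the change, unlike a value captured at construction
}}}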
- val numPartitions = 150 + def numPartitions = sqlContext.numShufflePartitions def apply(plan: SparkPlan): SparkPlan = plan.transformUp { case operator: SparkPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index cfa8bdae58b11..0455748d40eec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{SQLContext, execution} +import org.apache.spark.sql.{SQLConf, SQLContext, execution} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ @@ -28,6 +28,22 @@ import org.apache.spark.sql.parquet._ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SQLContext#SparkPlanner => + object LeftSemiJoin extends Strategy with PredicateHelper { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + // Find left semi joins where at least some predicates can be evaluated by matching hash + // keys using the HashFilteredJoin pattern. + case HashFilteredJoin(LeftSemi, leftKeys, rightKeys, condition, left, right) => + val semiJoin = execution.LeftSemiJoinHash( + leftKeys, rightKeys, planLater(left), planLater(right)) + condition.map(Filter(_, semiJoin)).getOrElse(semiJoin) :: Nil + // no predicate can be evaluated by matching hash keys + case logical.Join(left, right, LeftSemi, condition) => + execution.LeftSemiJoinBNL( + planLater(left), planLater(right), condition)(sparkContext) :: Nil + case _ => Nil + } + } + object HashJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { // Find inner joins where at least some predicates can be evaluated by matching hash keys @@ -177,8 +193,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? object BasicOperators extends Strategy { - // TODO: Set - val numPartitions = 200 + def numPartitions = self.numPartitions + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.Distinct(child) => execution.Aggregate( @@ -217,4 +233,16 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case _ => Nil } } + + case class CommandStrategy(context: SQLContext) extends Strategy { + def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case logical.SetCommand(key, value) => + Seq(execution.SetCommandPhysical(key, value, plan.output)(context)) + case logical.ExplainCommand(child) => + val qe = context.executePlan(child) + Seq(execution.ExplainCommandPhysical(qe.executedPlan, plan.output)(context)) + case _ => Nil + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala new file mode 100644 index 0000000000000..9364506691f38 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class SetCommandPhysical(key: Option[String], value: Option[String], output: Seq[Attribute]) + (@transient context: SQLContext) extends LeafNode { + def execute(): RDD[Row] = (key, value) match { + // Set value for key k; the action itself would + // have been performed in QueryExecution eagerly. + case (Some(k), Some(v)) => context.emptyResult + // Query the value bound to key k. + case (Some(k), None) => + val resultString = context.getOption(k) match { + case Some(v) => s"$k=$v" + case None => s"$k is undefined" + } + context.sparkContext.parallelize(Seq(new GenericRow(Array[Any](resultString))), 1) + // Query all key-value pairs that are set in the SQLConf of the context. + case (None, None) => + val pairs = context.getAll + val rows = pairs.map { case (k, v) => + new GenericRow(Array[Any](s"$k=$v")) + }.toSeq + // Assume config parameters can fit into one split (machine) ;) + context.sparkContext.parallelize(rows, 1) + // The only other case is invalid semantics and is impossible. + case _ => context.emptyResult + } +} + +/** + * :: DeveloperApi :: + */ +@DeveloperApi +case class ExplainCommandPhysical(child: SparkPlan, output: Seq[Attribute]) + (@transient context: SQLContext) extends UnaryNode { + def execute(): RDD[Row] = { + val planString = new GenericRow(Array[Any](child.toString)) + context.sparkContext.parallelize(Seq(planString)) + } + + override def otherCopyArgs = context :: Nil +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala new file mode 100644 index 0000000000000..c6fbd6d2f6930 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution + +import scala.collection.mutable.HashSet + +import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.SparkContext._ +import org.apache.spark.sql.{SchemaRDD, Row} + +/** + * :: DeveloperApi :: + * Contains methods for debugging query execution. + * + * Usage: + * {{{ + * sql("SELECT key FROM src").debug + * }}} + */ +package object debug { + + /** + * :: DeveloperApi :: + * Augments SchemaRDDs with debug methods. + */ + @DeveloperApi + implicit class DebugQuery(query: SchemaRDD) { + def debug(implicit sc: SparkContext): Unit = { + val plan = query.queryExecution.executedPlan + val visited = new collection.mutable.HashSet[Long]() + val debugPlan = plan transform { + case s: SparkPlan if !visited.contains(s.id) => + visited += s.id + DebugNode(sc, s) + } + println(s"Results returned: ${debugPlan.execute().count()}") + debugPlan.foreach { + case d: DebugNode => d.dumpStats() + case _ => + } + } + } + + private[sql] case class DebugNode( + @transient sparkContext: SparkContext, + child: SparkPlan) extends UnaryNode { + def references = Set.empty + + def output = child.output + + implicit object SetAccumulatorParam extends AccumulatorParam[HashSet[String]] { + def zero(initialValue: HashSet[String]): HashSet[String] = { + initialValue.clear() + initialValue + } + + def addInPlace(v1: HashSet[String], v2: HashSet[String]): HashSet[String] = { + v1 ++= v2 + v1 + } + } + + /** + * A collection of stats for each column of output. + * @param elementTypes the actual runtime types for the output. Useful when there are bugs + * causing the wrong data to be projected. + */ + case class ColumnStat( + elementTypes: Accumulator[HashSet[String]] = sparkContext.accumulator(HashSet.empty)) + val tupleCount = sparkContext.accumulator[Int](0) + + val numColumns = child.output.size + val columnStats = Array.fill(child.output.size)(new ColumnStat()) + + def dumpStats(): Unit = { + println(s"== ${child.simpleString} ==") + println(s"Tuples output: ${tupleCount.value}") + child.output.zip(columnStats).foreach { case(attr, stat) => + val actualDataTypes =stat.elementTypes.value.mkString("{", ",", "}") + println(s" ${attr.name} ${attr.dataType}: $actualDataTypes") + } + } + + def execute() = { + child.execute().mapPartitions { iter => + new Iterator[Row] { + def hasNext = iter.hasNext + def next() = { + val currentRow = iter.next() + tupleCount += 1 + var i = 0 + while (i < numColumns) { + val value = currentRow(i) + columnStats(i).elementTypes += HashSet(value.getClass.getName) + i += 1 + } + currentRow + } + } + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala index 31cc26962ad93..88ff3d49a79b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala @@ -140,6 +140,137 @@ case class HashJoin( } } +/** + * :: DeveloperApi :: + * Build the right table's join keys into a HashSet, and iteratively go through the left + * table, to find the if join keys are in the Hash set. 
+ */ +@DeveloperApi +case class LeftSemiJoinHash( + leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + left: SparkPlan, + right: SparkPlan) extends BinaryNode { + + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution = + ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil + + val (buildPlan, streamedPlan) = (right, left) + val (buildKeys, streamedKeys) = (rightKeys, leftKeys) + + def output = left.output + + @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output) + @transient lazy val streamSideKeyGenerator = + () => new MutableProjection(streamedKeys, streamedPlan.output) + + def execute() = { + + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashTable = new java.util.HashSet[Row]() + var currentRow: Row = null + + // Create a Hash set of buildKeys + while (buildIter.hasNext) { + currentRow = buildIter.next() + val rowKey = buildSideKeyGenerator(currentRow) + if(!rowKey.anyNull) { + val keyExists = hashTable.contains(rowKey) + if (!keyExists) { + hashTable.add(rowKey) + } + } + } + + new Iterator[Row] { + private[this] var currentStreamedRow: Row = _ + private[this] var currentHashMatched: Boolean = false + + private[this] val joinKeys = streamSideKeyGenerator() + + override final def hasNext: Boolean = + streamIter.hasNext && fetchNext() + + override final def next() = { + currentStreamedRow + } + + /** + * Searches the streamed iterator for the next row that has at least one match in hashtable. + * + * @return true if the search is successful, and false the streamed iterator runs out of + * tuples. + */ + private final def fetchNext(): Boolean = { + currentHashMatched = false + while (!currentHashMatched && streamIter.hasNext) { + currentStreamedRow = streamIter.next() + if (!joinKeys(currentStreamedRow).anyNull) { + currentHashMatched = hashTable.contains(joinKeys.currentValue) + } + } + currentHashMatched + } + } + } + } +} + +/** + * :: DeveloperApi :: + * Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys + * for hash join. + */ +@DeveloperApi +case class LeftSemiJoinBNL( + streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) + (@transient sc: SparkContext) + extends BinaryNode { + // TODO: Override requiredChildDistribution. 
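To make the semantics concrete: a left-semi join keeps each streamed-side row at most once when its join key appears in the build-side key set, and only left-side columns survive. A plain-Scala sketch on toy data (collections only, not the Spark operator):

{{{
// (joinKey, payload) pairs; the right side plays the role of the hashed build side.
val left  = Seq(("a", 1), ("b", 2), ("b", 3), ("c", 4))
val right = Seq(("b", 10), ("b", 20), ("d", 30))

val buildKeys: Set[String] = right.map(_._1).toSet          // hash the build side's keys
val semiJoined = left.filter { case (k, _) => buildKeys.contains(k) }
// semiJoined == Seq(("b", 2), ("b", 3)): duplicates on the build side do not duplicate
// the output, and no right-side columns appear in the result.
}}}

The LeftSemiJoinBNL operator introduced next covers the same semantics when there are no equi-join keys: a streamed row is kept as soon as any broadcast row satisfies the join condition.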
+ + override def outputPartitioning: Partitioning = streamed.outputPartitioning + + override def otherCopyArgs = sc :: Nil + + def output = left.output + + /** The Streamed Relation */ + def left = streamed + /** The Broadcast relation */ + def right = broadcast + + @transient lazy val boundCondition = + InterpretedPredicate( + condition + .map(c => BindReferences.bindReference(c, left.output ++ right.output)) + .getOrElse(Literal(true))) + + + def execute() = { + val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq) + + streamed.execute().mapPartitions { streamedIter => + val joinedRow = new JoinedRow + + streamedIter.filter(streamedRow => { + var i = 0 + var matched = false + + while (i < broadcastedRelation.value.size && !matched) { + val broadcastedRow = broadcastedRelation.value(i) + if (boundCondition(joinedRow(streamedRow, broadcastedRow))) { + matched = true + } + i += 1 + } + matched + }) + } + } +} + /** * :: DeveloperApi :: */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d6072b402a044..d7f6abaf5d381 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -44,7 +44,7 @@ class QueryTest extends FunSuite { fail( s""" |Exception thrown while executing query: - |${rdd.logicalPlan} + |${rdd.queryExecution} |== Exception == |$e """.stripMargin) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala new file mode 100644 index 0000000000000..5eb73a4eff980 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -0,0 +1,71 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql + +import org.apache.spark.sql.test._ + +/* Implicits */ +import TestSQLContext._ + +class SQLConfSuite extends QueryTest { + + val testKey = "test.key.0" + val testVal = "test.val.0" + + test("programmatic ways of basic setting and getting") { + assert(getOption(testKey).isEmpty) + assert(getAll.toSet === Set()) + + set(testKey, testVal) + assert(get(testKey) == testVal) + assert(get(testKey, testVal + "_") == testVal) + assert(getOption(testKey) == Some(testVal)) + assert(contains(testKey)) + + // Tests SQLConf as accessed from a SQLContext is mutable after + // the latter is initialized, unlike SparkConf inside a SparkContext. 
+ assert(TestSQLContext.get(testKey) == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.getOption(testKey) == Some(testVal)) + assert(TestSQLContext.contains(testKey)) + + clear() + } + + test("parse SQL set commands") { + sql(s"set $testKey=$testVal") + assert(get(testKey, testVal + "_") == testVal) + assert(TestSQLContext.get(testKey, testVal + "_") == testVal) + + sql("set mapred.reduce.tasks=20") + assert(get("mapred.reduce.tasks", "0") == "20") + sql("set mapred.reduce.tasks = 40") + assert(get("mapred.reduce.tasks", "0") == "40") + + val key = "spark.sql.key" + val vs = "val0,val_1,val2.3,my_table" + sql(s"set $key=$vs") + assert(get(key, "0") == vs) + + sql(s"set $key=") + assert(get(key, "0") == "") + + clear() + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index aa0c426f6fcb3..de02bbc7e4700 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -40,6 +40,13 @@ class SQLQuerySuite extends QueryTest { arrayData.map(d => (d.data, d.data(0), d.data(0) + d.data(1), d.data(1))).collect().toSeq) } + test("left semi greater than predicate") { + checkAnswer( + sql("SELECT * FROM testData2 x LEFT SEMI JOIN testData2 y ON x.a >= y.a + 2"), + Seq((3,1), (3,2)) + ) + } + test("index into array of arrays") { checkAnswer( sql( @@ -129,6 +136,12 @@ class SQLQuerySuite extends QueryTest { 2.0) } + test("average overflow") { + checkAnswer( + sql("SELECT AVG(a),b FROM largeAndSmallInts group by b"), + Seq((2147483645.0,1),(2.0,2))) + } + test("count") { checkAnswer( sql("SELECT COUNT(*) FROM testData2"), @@ -354,6 +367,41 @@ class SQLQuerySuite extends QueryTest { (1, "abc"), (2, "abc"), (3, null))) - } - + } + + test("SET commands semantics using sql()") { + clear() + val testKey = "test.key.0" + val testVal = "test.val.0" + val nonexistentKey = "nonexistent" + + // "set" itself returns all config variables currently specified in SQLConf. 
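The SET handling added to SqlParser (and mirrored on the Hive side later in this patch) treats everything after the first '=' as the value, which is why a value like `val0,val_1,val2.3,my_table` in the suite above survives intact. A standalone sketch of that convention in plain Scala; `parseSet` is a made-up helper name, and the caller is assumed to have already checked the leading "set":

{{{
def parseSet(input: String): (Option[String], Option[String]) =
  input.trim.drop(3).split("=", 2).map(_.trim) match {
    case Array("")         => (None, None)             // "SET"      -> list all settings
    case Array(key)        => (Some(key), None)        // "SET k"    -> query one key
    case Array(key, value) => (Some(key), Some(value)) // "SET k=v"  -> assign; value may contain '='
  }

assert(parseSet("set") == (None, None))
assert(parseSet("SET mapred.reduce.tasks = 40") == (Some("mapred.reduce.tasks"), Some("40")))
assert(parseSet("set some.key=a=b,c") == (Some("some.key"), Some("a=b,c")))
}}}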
+ assert(sql("SET").collect().size == 0) + + // "set key=val" + sql(s"SET $testKey=$testVal") + checkAnswer( + sql("SET"), + Seq(Seq(s"$testKey=$testVal")) + ) + + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + checkAnswer( + sql("set"), + Seq( + Seq(s"$testKey=$testVal"), + Seq(s"${testKey + testKey}=${testVal + testVal}")) + ) + + // "set key" + checkAnswer( + sql(s"SET $testKey"), + Seq(Seq(s"$testKey=$testVal")) + ) + checkAnswer( + sql(s"SET $nonexistentKey"), + Seq(Seq(s"$nonexistentKey is undefined")) + ) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 05de736bbce1b..330b20b315d63 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -30,6 +30,17 @@ object TestData { (1 to 100).map(i => TestData(i, i.toString))) testData.registerAsTable("testData") + case class LargeAndSmallInts(a: Int, b: Int) + val largeAndSmallInts: SchemaRDD = + TestSQLContext.sparkContext.parallelize( + LargeAndSmallInts(2147483644, 1) :: + LargeAndSmallInts(1, 2) :: + LargeAndSmallInts(2147483645, 1) :: + LargeAndSmallInts(2, 2) :: + LargeAndSmallInts(2147483646, 1) :: + LargeAndSmallInts(3, 2) :: Nil) + largeAndSmallInts.registerAsTable("largeAndSmallInts") + case class TestData2(a: Int, b: Int) val testData2: SchemaRDD = TestSQLContext.sparkContext.parallelize( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index c563d63627544..df6b118360d01 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -30,8 +30,8 @@ class PlannerSuite extends FunSuite { test("unions are collapsed") { val query = testData.unionAll(testData).unionAll(testData).logicalPlan val planned = BasicOperators(query).head - val logicalUnions = query collect { case u: logical.Union => u} - val physicalUnions = planned collect { case u: execution.Union => u} + val logicalUnions = query collect { case u: logical.Union => u } + val physicalUnions = planned collect { case u: execution.Union => u } assert(logicalUnions.size === 2) assert(physicalUnions.size === 1) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index b21f24dad785d..64978215542ec 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql package hive -import scala.language.implicitConversions - import java.io.{BufferedReader, File, InputStreamReader, PrintStream} import java.util.{ArrayList => JArrayList} +import scala.collection.JavaConversions._ +import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.hive.conf.HiveConf @@ -30,20 +30,15 @@ import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Analyzer, OverrideCatalog} import 
org.apache.spark.sql.catalyst.expressions.GenericRow -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LowerCaseSchema} -import org.apache.spark.sql.catalyst.plans.logical.{NativeCommand, ExplainCommand} -import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.types._ import org.apache.spark.sql.execution._ -/* Implicit conversions */ -import scala.collection.JavaConversions._ - /** * Starts up an instance of hive where metadata is stored locally. An in-process metadata data is * created with data stored in ./metadata. Warehouse data is stored in in ./warehouse. @@ -55,10 +50,9 @@ class LocalHiveContext(sc: SparkContext) extends HiveContext(sc) { /** Sets up the system initially or after a RESET command */ protected def configure() { - // TODO: refactor this so we can work with other databases. - runSqlHive( - s"set javax.jdo.option.ConnectionURL=jdbc:derby:;databaseName=$metastorePath;create=true") - runSqlHive("set hive.metastore.warehouse.dir=" + warehousePath) + set("javax.jdo.option.ConnectionURL", + s"jdbc:derby:;databaseName=$metastorePath;create=true") + set("hive.metastore.warehouse.dir", warehousePath) } configure() // Must be called before initializing the catalog below. @@ -129,12 +123,27 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { } } + /** + * SQLConf and HiveConf contracts: when the hive session is first initialized, params in + * HiveConf will get picked up by the SQLConf. Additionally, any properties set by + * set() or a SET command inside hql() or sql() will be set in the SQLConf *as well as* + * in the HiveConf. + */ @transient protected[hive] lazy val hiveconf = new HiveConf(classOf[SessionState]) - @transient protected[hive] lazy val sessionState = new SessionState(hiveconf) + @transient protected[hive] lazy val sessionState = { + val ss = new SessionState(hiveconf) + set(hiveconf.getAllProperties) // Have SQLConf pick up the initial set of HiveConf. + ss + } sessionState.err = new PrintStream(outputBuffer, true, "UTF-8") sessionState.out = new PrintStream(outputBuffer, true, "UTF-8") + override def set(key: String, value: String): Unit = { + super.set(key, value) + runSqlHive(s"SET $key=$value") + } + /* A catalyst metadata catalog that points to the Hive Metastore. */ @transient override protected[sql] lazy val catalog = new HiveMetastoreCatalog(this) with OverrideCatalog { @@ -218,12 +227,14 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { val hiveContext = self override val strategies: Seq[Strategy] = Seq( + CommandStrategy(self), TakeOrdered, ParquetOperations, HiveTableScans, DataSinks, Scripts, PartialAggregation, + LeftSemiJoin, HashJoin, BasicOperators, CartesianProduct, @@ -234,30 +245,31 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { @transient override protected[sql] val planner = hivePlanner - @transient - protected lazy val emptyResult = - sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): Row), 1) - /** Extends QueryExecution with hive specific features. */ protected[sql] abstract class QueryExecution extends super.QueryExecution { // TODO: Create mixin for the analyzer instead of overriding things here. 
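The HiveContext side keeps two stores in step: the overridden `set()` above writes into SQLConf and then replays "SET key=value" through Hive, while sessionState initialization copies the existing HiveConf properties into SQLConf. A small sketch of that override pattern, using invented stub types (`BaseConf`, `HiveAwareConf`) rather than the real SQLConf and HiveConf:

{{{
import scala.collection.mutable

trait BaseConf {
  private val settings = mutable.Map.empty[String, String]
  def set(key: String, value: String): Unit = settings(key) = value
  def get(key: String, default: String): String = settings.getOrElse(key, default)
}

class HiveAwareConf(hiveSide: mutable.Map[String, String]) extends BaseConf {
  override def set(key: String, value: String): Unit = {
    super.set(key, value)   // keep the SQL-side view...
    hiveSide(key) = value   // ...and mirror the assignment into the Hive-side store
  }
}

val hiveSide = mutable.Map.empty[String, String]
val conf = new HiveAwareConf(hiveSide)
conf.set("hive.metastore.warehouse.dir", "/tmp/warehouse")
assert(conf.get("hive.metastore.warehouse.dir", "") == "/tmp/warehouse")
assert(hiveSide("hive.metastore.warehouse.dir") == "/tmp/warehouse")
}}}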
override lazy val optimizedPlan = optimizer(catalog.PreInsertionCasts(catalog.CreateTables(analyzed))) - override lazy val toRdd: RDD[Row] = - analyzed match { - case NativeCommand(cmd) => - val output = runSqlHive(cmd) + override lazy val toRdd: RDD[Row] = { + def processCmd(cmd: String): RDD[Row] = { + val output = runSqlHive(cmd) + if (output.size == 0) { + emptyResult + } else { + val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) + sparkContext.parallelize(asRows, 1) + } + } - if (output.size == 0) { - emptyResult - } else { - val asRows = output.map(r => new GenericRow(r.split("\t").asInstanceOf[Array[Any]])) - sparkContext.parallelize(asRows, 1) - } - case _ => - executedPlan.execute().map(_.copy()) + logical match { + case s: SetCommand => eagerlyProcess(s) + case _ => analyzed match { + case NativeCommand(cmd) => processCmd(cmd) + case _ => executedPlan.execute().map(_.copy()) + } } + } protected val primitiveTypes = Seq(StringType, IntegerType, LongType, DoubleType, FloatType, BooleanType, ByteType, @@ -303,7 +315,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { */ def stringResult(): Seq[String] = analyzed match { case NativeCommand(cmd) => runSqlHive(cmd) - case ExplainCommand(plan) => new QueryExecution { val logical = plan }.toString.split("\n") + case ExplainCommand(plan) => executePlan(plan).toString.split("\n") case query => val result: Seq[Seq[Any]] = toRdd.collect().toSeq // We need the types so we can output struct field names @@ -316,6 +328,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { override def simpleString: String = logical match { case _: NativeCommand => "" + case _: SetCommand => "" case _ => executedPlan.toString } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7e91c16c6b93a..4e74d9bc909fa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -207,8 +207,17 @@ private[hive] object HiveQl { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = { try { - if (sql.toLowerCase.startsWith("set")) { - NativeCommand(sql) + if (sql.trim.toLowerCase.startsWith("set")) { + // Split in two parts since we treat the part before the first "=" + // as key, and the part after as value, which may contain other "=" signs. 
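Review note (illustrative, not part of the patch): the split described in the comment above uses a limit of 2, so only the first "=" separates key from value and the value may itself contain further "=" characters; the actual match arms follow below. A standalone sketch of the same classification, with classifySet as an invented name:

// Classify a "set ..." statement the same way parseSql does below.
// Assumes the input already starts with the three characters "set"/"SET".
def classifySet(sql: String): (Option[String], Option[String]) =
  sql.trim.drop(3).split("=", 2).map(_.trim) match {
    case Array("")         => (None, None)              // "SET"           -> list all properties
    case Array(key)        => (Some(key), None)         // "SET key"       -> look up one property
    case Array(key, value) => (Some(key), Some(value))  // "SET key=value" -> assign; value may contain '='
  }

// classifySet("set a.b.c=x=y") == (Some("a.b.c"), Some("x=y"))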
+ sql.trim.drop(3).split("=", 2).map(_.trim) match { + case Array("") => // "set" + SetCommand(None, None) + case Array(key) => // "set key" + SetCommand(Some(key), None) + case Array(key, value) => // "set key=value" + SetCommand(Some(key), Some(value)) + } } else if (sql.toLowerCase.startsWith("add jar")) { AddJar(sql.drop(8)) } else if (sql.toLowerCase.startsWith("add file")) { @@ -685,6 +694,7 @@ private[hive] object HiveQl { case "TOK_RIGHTOUTERJOIN" => RightOuter case "TOK_LEFTOUTERJOIN" => LeftOuter case "TOK_FULLOUTERJOIN" => FullOuter + case "TOK_LEFTSEMIJOIN" => LeftSemi } assert(other.size <= 1, "Unhandled join clauses.") Join(nodeToRelation(relation1), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala similarity index 100% rename from sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveOperators.scala rename to sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/hiveOperators.scala diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-0-80b6466213face7fbcb0de044611e1f5 b/sql/hive/src/test/resources/golden/leftsemijoin-0-80b6466213face7fbcb0de044611e1f5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-1-d1f6a3dea28a5f0fee08026bf33d9129 b/sql/hive/src/test/resources/golden/leftsemijoin-1-d1f6a3dea28a5f0fee08026bf33d9129 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea new file mode 100644 index 0000000000000..25ce912507d55 --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin-10-89737a8857b5b61cc909e0c797f86aea @@ -0,0 +1,4 @@ +Hank 2 +Hank 2 +Joe 2 +Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-11-80b6466213face7fbcb0de044611e1f5 b/sql/hive/src/test/resources/golden/leftsemijoin-11-80b6466213face7fbcb0de044611e1f5 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-12-d1f6a3dea28a5f0fee08026bf33d9129 b/sql/hive/src/test/resources/golden/leftsemijoin-12-d1f6a3dea28a5f0fee08026bf33d9129 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a b/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin-2-43d53504df013e6b35f81811138a167a @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-3-b07d292423312aafa5e5762a579decd2 b/sql/hive/src/test/resources/golden/leftsemijoin-3-b07d292423312aafa5e5762a579decd2 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-4-3ac2226efe7cb5d999c1c5e4ac2114be b/sql/hive/src/test/resources/golden/leftsemijoin-4-3ac2226efe7cb5d999c1c5e4ac2114be new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-5-9c307c0559d735960ce77efa95b2b17b b/sql/hive/src/test/resources/golden/leftsemijoin-5-9c307c0559d735960ce77efa95b2b17b new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-6-82921fc96eef547ec0f71027ee88298c 
b/sql/hive/src/test/resources/golden/leftsemijoin-6-82921fc96eef547ec0f71027ee88298c new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-7-b30aa3b4a45db6b64bb46b4d9bd32ff0 b/sql/hive/src/test/resources/golden/leftsemijoin-7-b30aa3b4a45db6b64bb46b4d9bd32ff0 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 new file mode 100644 index 0000000000000..25ce912507d55 --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin-8-73cad58a10a1483ccb15e94a857013 @@ -0,0 +1,4 @@ +Hank 2 +Hank 2 +Joe 2 +Joe 2 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e b/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e new file mode 100644 index 0000000000000..f1470bad5782b --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin-9-c5efa6b8771a51610d655be461670e1e @@ -0,0 +1,2 @@ +2 Tie +2 Tie diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-0-7087fb6281a34d00f1812d2ff4ba8b75 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-0-7087fb6281a34d00f1812d2ff4ba8b75 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-1-aa3f07f028027ffd13ab5535dc821593 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-1-aa3f07f028027ffd13ab5535dc821593 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04 new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-10-9914f44ecb6ae7587b62e5349ff60d04 @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7 new file mode 100644 index 0000000000000..6ed281c757a96 --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-11-2027ecb1495d5550c5d56abf6b95b0a7 @@ -0,0 +1,2 @@ +1 +1 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-2-3f65953ae60375156367c54533978782 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-2-3f65953ae60375156367c54533978782 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-3-645cf8b871c9b27418d6fa1d1bda9a52 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-3-645cf8b871c9b27418d6fa1d1bda9a52 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f b/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f new file mode 100644 index 0000000000000..6ed281c757a96 --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-4-333895fe6abca27c8edb5c91bfe10d2f @@ -0,0 +1,2 @@ +1 +1 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9 new file mode 100644 index 0000000000000..179ef0e0209e9 --- /dev/null +++ 
b/sql/hive/src/test/resources/golden/leftsemijoin_mr-5-896d0948c1df849df9764a6d8ad8fff9 @@ -0,0 +1,20 @@ +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 +1 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b b/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-6-b1e2ade89ae898650f0be4f796d8947b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-7-8e9c2969b999557363e40f9ebb3f6d7c @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b b/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b new file mode 100644 index 0000000000000..573541ac9702d --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-8-c61b972d4409babe41d8963e841af45b @@ -0,0 +1 @@ +0 diff --git a/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7 b/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7 new file mode 100644 index 0000000000000..6ed281c757a96 --- /dev/null +++ b/sql/hive/src/test/resources/golden/leftsemijoin_mr-9-2027ecb1495d5550c5d56abf6b95b0a7 @@ -0,0 +1,2 @@ +1 +1 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 0f954103a85f2..357c7e654bd20 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -138,6 +138,9 @@ abstract class HiveComparisonTest val orderedAnswer = hiveQuery.logical match { // Clean out non-deterministic time schema info. + // Hack: Hive simply prints the result of a SET command to screen, + // and does not return it as a query answer. 
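Review note before the HiveComparisonTest hunk continues below: the new golden files above exercise Hive's LEFT SEMI JOIN, which this patch wires up via the TOK_LEFTSEMIJOIN case in HiveQl and the LeftSemiJoin strategy. A rough sketch of the semantics on plain Scala collections, purely illustrative (leftSemiJoin is an invented helper, not Spark code):

// LEFT SEMI JOIN: keep each left row that has at least one matching key on the right,
// emit only the left-side columns, and do not duplicate a left row per extra right match.
def leftSemiJoin[K, A, B](left: Seq[(K, A)], right: Seq[(K, B)]): Seq[(K, A)] = {
  val rightKeys = right.map(_._1).toSet
  left.filter { case (k, _) => rightKeys.contains(k) }
}

// leftSemiJoin(Seq(1 -> "a", 2 -> "b"), Seq(1 -> "x", 1 -> "y")) == Seq(1 -> "a")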
+ case _: SetCommand => Seq("0") case _: NativeCommand => answer.filterNot(nonDeterministicLine).filterNot(_ == "") case _: ExplainCommand => answer case plan => if (isSorted(plan)) answer else answer.sorted diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9031abf733cd4..fb8f272d5abfe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -480,6 +480,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "lateral_view", "lateral_view_cp", "lateral_view_ppd", + "leftsemijoin", + "leftsemijoin_mr", "lineage1", "literal_double", "literal_ints", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 125cc18bfb2b5..6c239b02ed09a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -17,7 +17,9 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.Row import org.apache.spark.sql.hive.test.TestHive._ +import org.apache.spark.sql.hive.test.TestHive /** * A set of test cases expressed in Hive QL that are not covered by the tests included in the hive distribution. @@ -159,4 +161,89 @@ class HiveQuerySuite extends HiveComparisonTest { hql("SHOW TABLES").toString hql("SELECT * FROM src").toString } + + test("SPARK-1704: Explain commands as a SchemaRDD") { + hql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") + val rdd = hql("explain select key, count(value) from src group by key") + assert(rdd.collect().size == 1) + assert(rdd.toString.contains("ExplainCommand")) + assert(rdd.filter(row => row.toString.contains("ExplainCommand")).collect().size == 0, + "actual contents of the result should be the plans of the query to be explained") + TestHive.reset() + } + + test("parse HQL set commands") { + // Adapted from its SQL counterpart. + val testKey = "spark.sql.key.usedfortestonly" + val testVal = "val0,val_1,val2.3,my_table" + + hql(s"set $testKey=$testVal") + assert(get(testKey, testVal + "_") == testVal) + + hql("set mapred.reduce.tasks=20") + assert(get("mapred.reduce.tasks", "0") == "20") + hql("set mapred.reduce.tasks = 40") + assert(get("mapred.reduce.tasks", "0") == "40") + + hql(s"set $testKey=$testVal") + assert(get(testKey, "0") == testVal) + + hql(s"set $testKey=") + assert(get(testKey, "0") == "") + } + + test("SET commands semantics for a HiveContext") { + // Adapted from its SQL counterpart. + val testKey = "spark.sql.key.usedfortestonly" + var testVal = "test.val.0" + val nonexistentKey = "nonexistent" + def fromRows(row: Array[Row]): Array[String] = row.map(_.getString(0)) + + clear() + + // "set" itself returns all config variables currently specified in SQLConf. 
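Review note (illustrative): with SetCommand handled eagerly, the SET variants come back as single-column string rows instead of console output, which is exactly what the assertions below compare against via the small fromRows helper. A pure sketch of the observable behaviour only, not the actual SetCommand implementation (setOutput is an invented name):

// "SET" lists every k=v pair; "SET k" echoes "k=v" or "k is undefined".
def setOutput(key: Option[String], conf: Map[String, String]): Seq[String] = key match {
  case None    => conf.map { case (k, v) => s"$k=$v" }.toSeq
  case Some(k) => Seq(conf.get(k).map(v => s"$k=$v").getOrElse(s"$k is undefined"))
}

// setOutput(Some("nonexistent"), Map()) == Seq("nonexistent is undefined")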
+ assert(hql("set").collect().size == 0) + + // "set key=val" + hql(s"SET $testKey=$testVal") + assert(fromRows(hql("SET").collect()) sameElements Array(s"$testKey=$testVal")) + assert(hiveconf.get(testKey, "") == testVal) + + hql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(fromRows(hql("SET").collect()) sameElements + Array( + s"$testKey=$testVal", + s"${testKey + testKey}=${testVal + testVal}")) + assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + + // "set key" + assert(fromRows(hql(s"SET $testKey").collect()) sameElements + Array(s"$testKey=$testVal")) + assert(fromRows(hql(s"SET $nonexistentKey").collect()) sameElements + Array(s"$nonexistentKey is undefined")) + + // Assert that sql() should have the same effects as hql() by repeating the above using sql(). + clear() + assert(sql("set").collect().size == 0) + + sql(s"SET $testKey=$testVal") + assert(fromRows(sql("SET").collect()) sameElements Array(s"$testKey=$testVal")) + assert(hiveconf.get(testKey, "") == testVal) + + sql(s"SET ${testKey + testKey}=${testVal + testVal}") + assert(fromRows(sql("SET").collect()) sameElements + Array( + s"$testKey=$testVal", + s"${testKey + testKey}=${testVal + testVal}")) + assert(hiveconf.get(testKey + testKey, "") == testVal + testVal) + + assert(fromRows(sql(s"SET $testKey").collect()) sameElements + Array(s"$testKey=$testVal")) + assert(fromRows(sql(s"SET $nonexistentKey").collect()) sameElements + Array(s"$nonexistentKey is undefined")) + } + + // Put tests that depend on specific Hive settings before these last two test, + // since they modify /clear stuff. + } diff --git a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala index aeb3f0062df3b..801e8b381588f 100644 --- a/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala +++ b/yarn/common/src/main/scala/org/apache/spark/deploy/yarn/ClientBase.scala @@ -23,6 +23,7 @@ import java.nio.ByteBuffer import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, ListBuffer, Map} +import scala.util.{Try, Success, Failure} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ @@ -220,7 +221,7 @@ trait ClientBase extends Logging { } } - var cachedSecondaryJarLinks = ListBuffer.empty[String] + val cachedSecondaryJarLinks = ListBuffer.empty[String] val fileLists = List( (args.addJars, LocalResourceType.FILE, true), (args.files, LocalResourceType.FILE, false), (args.archives, LocalResourceType.ARCHIVE, false) ) @@ -378,7 +379,7 @@ trait ClientBase extends Logging { } } -object ClientBase { +object ClientBase extends Logging { val SPARK_JAR: String = "__spark__.jar" val APP_JAR: String = "__app__.jar" val LOG4J_PROP: String = "log4j.properties" @@ -388,37 +389,47 @@ object ClientBase { def getSparkJar = sys.env.get("SPARK_JAR").getOrElse(SparkContext.jarOfClass(this.getClass).head) - // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps - def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) { - val classpathEntries = Option(conf.getStrings( - YarnConfiguration.YARN_APPLICATION_CLASSPATH)).getOrElse( - getDefaultYarnApplicationClasspath()) - if (classpathEntries != null) { - for (c <- classpathEntries) { - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim, - File.pathSeparator) - } + def populateHadoopClasspath(conf: Configuration, env: HashMap[String, String]) = { + val classPathElementsToAdd = 
getYarnAppClasspath(conf) ++ getMRAppClasspath(conf) + for (c <- classPathElementsToAdd.flatten) { + YarnSparkHadoopUtil.addToEnvironment( + env, + Environment.CLASSPATH.name, + c.trim, + File.pathSeparator) } + classPathElementsToAdd + } - val mrClasspathEntries = Option(conf.getStrings( - "mapreduce.application.classpath")).getOrElse( - getDefaultMRApplicationClasspath()) - if (mrClasspathEntries != null) { - for (c <- mrClasspathEntries) { - YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, c.trim, - File.pathSeparator) - } - } + private def getYarnAppClasspath(conf: Configuration): Option[Seq[String]] = + Option(conf.getStrings(YarnConfiguration.YARN_APPLICATION_CLASSPATH)) match { + case Some(s) => Some(s.toSeq) + case None => getDefaultYarnApplicationClasspath } - def getDefaultYarnApplicationClasspath(): Array[String] = { - try { - val field = classOf[MRJobConfig].getField("DEFAULT_YARN_APPLICATION_CLASSPATH") - field.get(null).asInstanceOf[Array[String]] - } catch { - case err: NoSuchFieldError => null - case err: NoSuchFieldException => null + private def getMRAppClasspath(conf: Configuration): Option[Seq[String]] = + Option(conf.getStrings("mapreduce.application.classpath")) match { + case Some(s) => Some(s.toSeq) + case None => getDefaultMRApplicationClasspath + } + + def getDefaultYarnApplicationClasspath: Option[Seq[String]] = { + val triedDefault = Try[Seq[String]] { + val field = classOf[YarnConfiguration].getField("DEFAULT_YARN_APPLICATION_CLASSPATH") + val value = field.get(null).asInstanceOf[Array[String]] + value.toSeq + } recoverWith { + case e: NoSuchFieldException => Success(Seq.empty[String]) } + + triedDefault match { + case f: Failure[_] => + logError("Unable to obtain the default YARN Application classpath.", f.exception) + case s: Success[_] => + logDebug(s"Using the default YARN application classpath: ${s.get.mkString(",")}") + } + + triedDefault.toOption } /** @@ -426,20 +437,30 @@ object ClientBase { * classpath. In Hadoop 2.0, it's an array of Strings, and in 2.2+ it's a String. * So we need to use reflection to retrieve it. */ - def getDefaultMRApplicationClasspath(): Array[String] = { - try { + def getDefaultMRApplicationClasspath: Option[Seq[String]] = { + val triedDefault = Try[Seq[String]] { val field = classOf[MRJobConfig].getField("DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH") - if (field.getType == classOf[String]) { - StringUtils.getStrings(field.get(null).asInstanceOf[String]) + val value = if (field.getType == classOf[String]) { + StringUtils.getStrings(field.get(null).asInstanceOf[String]).toArray } else { field.get(null).asInstanceOf[Array[String]] } - } catch { - case err: NoSuchFieldError => null - case err: NoSuchFieldException => null + value.toSeq + } recoverWith { + case e: NoSuchFieldException => Success(Seq.empty[String]) } + + triedDefault match { + case f: Failure[_] => + logError("Unable to obtain the default MR Application classpath.", f.exception) + case s: Success[_] => + logDebug(s"Using the default MR application classpath: ${s.get.mkString(",")}") + } + + triedDefault.toOption } + /** * Returns the java command line argument for setting up log4j. If there is a log4j.properties * in the given local resources, it is used, otherwise the SPARK_LOG4J_CONF environment variable @@ -481,12 +502,14 @@ object ClientBase { def addClasspathEntry(path: String) = YarnSparkHadoopUtil.addToEnvironment(env, Environment.CLASSPATH.name, path, File.pathSeparator) /** Add entry to the classpath. 
Interpreted as a path relative to the working directory. */ - def addPwdClasspathEntry(entry: String) = addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry) + def addPwdClasspathEntry(entry: String) = + addClasspathEntry(Environment.PWD.$() + Path.SEPARATOR + entry) extraClassPath.foreach(addClasspathEntry) val cachedSecondaryJarLinks = sparkConf.getOption(CONF_SPARK_YARN_SECONDARY_JARS).getOrElse("").split(",") + .filter(_.nonEmpty) // Normally the users app.jar is last in case conflicts with spark jars if (sparkConf.get("spark.yarn.user.classpath.first", "false").toBoolean) { addPwdClasspathEntry(APP_JAR) diff --git a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala index e01ed5a57d697..039cf4f276119 100644 --- a/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala +++ b/yarn/common/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -112,7 +112,7 @@ private[spark] class YarnClientSchedulerBackend( override def stop() { super.stop() - client.stop() + client.stop logInfo("Stopped") } diff --git a/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala new file mode 100644 index 0000000000000..608c6e92624c6 --- /dev/null +++ b/yarn/common/src/test/scala/org/apache/spark/deploy/yarn/ClientBaseSuite.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
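Review note on the .filter(_.nonEmpty) added above: when CONF_SPARK_YARN_SECONDARY_JARS is unset, splitting the resulting empty string still yields one element, so without the filter a spurious empty classpath entry would be added. Illustrative check only:

// String.split never returns an empty array for an empty input.
val unset = ""                           // the secondary-jars property is not configured
val links = unset.split(",")             // Array("") -- one spurious empty entry
val filtered = links.filter(_.nonEmpty)  // Array()   -- nothing is added to the classpath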
+ */ + +package org.apache.spark.deploy.yarn + +import java.net.URI + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.mapreduce.MRJobConfig +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.api.ApplicationConstants.Environment + +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable.{ HashMap => MutableHashMap } +import scala.util.Try + + +class ClientBaseSuite extends FunSuite { + + test("default Yarn application classpath") { + ClientBase.getDefaultYarnApplicationClasspath should be(Some(Fixtures.knownDefYarnAppCP)) + } + + test("default MR application classpath") { + ClientBase.getDefaultMRApplicationClasspath should be(Some(Fixtures.knownDefMRAppCP)) + } + + test("resultant classpath for an application that defines a classpath for YARN") { + withAppConf(Fixtures.mapYARNAppConf) { conf => + val env = newEnv + ClientBase.populateHadoopClasspath(conf, env) + classpath(env) should be( + flatten(Fixtures.knownYARNAppCP, ClientBase.getDefaultMRApplicationClasspath)) + } + } + + test("resultant classpath for an application that defines a classpath for MR") { + withAppConf(Fixtures.mapMRAppConf) { conf => + val env = newEnv + ClientBase.populateHadoopClasspath(conf, env) + classpath(env) should be( + flatten(ClientBase.getDefaultYarnApplicationClasspath, Fixtures.knownMRAppCP)) + } + } + + test("resultant classpath for an application that defines both classpaths, YARN and MR") { + withAppConf(Fixtures.mapAppConf) { conf => + val env = newEnv + ClientBase.populateHadoopClasspath(conf, env) + classpath(env) should be(flatten(Fixtures.knownYARNAppCP, Fixtures.knownMRAppCP)) + } + } + + object Fixtures { + + val knownDefYarnAppCP: Seq[String] = + getFieldValue[Array[String], Seq[String]](classOf[YarnConfiguration], + "DEFAULT_YARN_APPLICATION_CLASSPATH", + Seq[String]())(a => a.toSeq) + + + val knownDefMRAppCP: Seq[String] = + getFieldValue[String, Seq[String]](classOf[MRJobConfig], + "DEFAULT_MAPREDUCE_APPLICATION_CLASSPATH", + Seq[String]())(a => a.split(",")) + + val knownYARNAppCP = Some(Seq("/known/yarn/path")) + + val knownMRAppCP = Some(Seq("/known/mr/path")) + + val mapMRAppConf = + Map("mapreduce.application.classpath" -> knownMRAppCP.map(_.mkString(":")).get) + + val mapYARNAppConf = + Map(YarnConfiguration.YARN_APPLICATION_CLASSPATH -> knownYARNAppCP.map(_.mkString(":")).get) + + val mapAppConf = mapYARNAppConf ++ mapMRAppConf + } + + def withAppConf(m: Map[String, String] = Map())(testCode: (Configuration) => Any) { + val conf = new Configuration + m.foreach { case (k, v) => conf.set(k, v, "ClientBaseSpec") } + testCode(conf) + } + + def newEnv = MutableHashMap[String, String]() + + def classpath(env: MutableHashMap[String, String]) = env(Environment.CLASSPATH.name).split(":|;") + + def flatten(a: Option[Seq[String]], b: Option[Seq[String]]) = (a ++ b).flatten.toArray + + def getFieldValue[A, B](clazz: Class[_], field: String, defaults: => B)(mapTo: A => B): B = + Try(clazz.getField(field)).map(_.get(null).asInstanceOf[A]).toOption.map(mapTo).getOrElse(defaults) + +} diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index c1dfe3f53b40b..33a60d978c586 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ 
b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -252,15 +252,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, try { logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished - // TODO: This is a bit ugly. Can we make it nicer? - // TODO: Handle container failure yarnAllocator.addResourceRequests(args.numExecutors) + yarnAllocator.allocateResources() // Exits the loop if the user thread exits. while (yarnAllocator.getNumExecutorsRunning < args.numExecutors && userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } + checkNumExecutorsFailed() + allocateMissingExecutor() yarnAllocator.allocateResources() ApplicationMaster.incrementAllocatorLoop(1) Thread.sleep(100) @@ -289,23 +286,31 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration, } } + private def allocateMissingExecutor() { + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - + yarnAllocator.getNumPendingAllocate + if (missingExecutorCount > 0) { + logInfo("Allocating %d containers to make up for (potentially) lost containers". + format(missingExecutorCount)) + yarnAllocator.addResourceRequests(missingExecutorCount) + } + } + + private def checkNumExecutorsFailed() { + if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of executor failures reached") + } + } + private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0) 0 else _sleepTime val t = new Thread { override def run() { while (userThread.isAlive) { - if (yarnAllocator.getNumExecutorsFailed >= maxNumExecutorFailures) { - finishApplicationMaster(FinalApplicationStatus.FAILED, - "max number of executor failures reached") - } - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - - yarnAllocator.getNumPendingAllocate - if (missingExecutorCount > 0) { - logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingExecutorCount)) - yarnAllocator.addResourceRequests(missingExecutorCount) - } + checkNumExecutorsFailed() + allocateMissingExecutor() sendProgress() Thread.sleep(sleepTime) } diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 1b6bfb42a5c1c..393edd1f2d670 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -25,7 +25,7 @@ import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl +import org.apache.hadoop.yarn.client.api.YarnClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, Records} @@ -37,7 +37,9 @@ import org.apache.spark.{Logging, SparkConf} * Version of [[org.apache.spark.deploy.yarn.ClientBase]] tailored to YARN's stable API. 
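Review note (illustrative): both the launch wait-loop and the reporter thread above now funnel through the same helpers, and the re-request count is simply the target minus what is running and what is already pending. A standalone sketch of that arithmetic (the patch itself only acts when the count is positive; clamping at zero here is a convenience):

// Containers to re-request so that running + pending + new requests returns
// to the requested executor count; zero when nothing has been lost.
def missingExecutorCount(target: Int, running: Int, pending: Int): Int =
  math.max(0, target - running - pending)

// missingExecutorCount(target = 10, running = 7, pending = 2) == 1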
*/ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: SparkConf) - extends YarnClientImpl with ClientBase with Logging { + extends ClientBase with Logging { + + val yarnClient = YarnClient.createYarnClient def this(clientArgs: ClientArguments, spConf: SparkConf) = this(clientArgs, new Configuration(), spConf) @@ -53,8 +55,8 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def runApp(): ApplicationId = { validateArgs() // Initialize and start the client service. - init(yarnConf) - start() + yarnClient.init(yarnConf) + yarnClient.start() // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). logClusterResourceDetails() @@ -63,7 +65,7 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa // interface). // Get a new client application. - val newApp = super.createApplication() + val newApp = yarnClient.createApplication() val newAppResponse = newApp.getNewApplicationResponse() val appId = newAppResponse.getApplicationId() @@ -99,11 +101,11 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa } def logClusterResourceDetails() { - val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics + val clusterMetrics: YarnClusterMetrics = yarnClient.getYarnClusterMetrics logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + clusterMetrics.getNumNodeManagers) - val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) + val queueInfo: QueueInfo = yarnClient.getQueueInfo(args.amQueue) logInfo( """Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s, queueApplicationCount = %s, queueChildQueueCount = %s""".format( queueInfo.getQueueName, @@ -132,15 +134,20 @@ class Client(clientArgs: ClientArguments, hadoopConf: Configuration, spConf: Spa def submitApp(appContext: ApplicationSubmissionContext) = { // Submit the application to the applications manager. logInfo("Submitting application to ASM") - super.submitApplication(appContext) + yarnClient.submitApplication(appContext) } + def getApplicationReport(appId: ApplicationId) = + yarnClient.getApplicationReport(appId) + + def stop = yarnClient.stop + def monitorApplication(appId: ApplicationId): Boolean = { val interval = sparkConf.getLong("spark.yarn.report.interval", 1000) while (true) { Thread.sleep(interval) - val report = super.getApplicationReport(appId) + val report = yarnClient.getApplicationReport(appId) logInfo("Application report from ASM: \n" + "\t application identifier: " + appId.toString() + "\n" + diff --git a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala index a4ce8766d347c..d93e5bb0225d5 100644 --- a/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala +++ b/yarn/stable/src/main/scala/org/apache/spark/deploy/yarn/ExecutorLauncher.scala @@ -200,17 +200,25 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp logInfo("Allocating " + args.numExecutors + " executors.") // Wait until all containers have finished - // TODO: This is a bit ugly. Can we make it nicer? 
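Review note: the Client above now holds a YarnClient instead of extending YarnClientImpl, so every former super.xxx call becomes an explicit delegation. A minimal sketch of that composition-over-inheritance shape, assuming hadoop-yarn-client on the classpath (ThinYarnClient is an invented name, not Spark code):

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.yarn.api.records.ApplicationId
import org.apache.hadoop.yarn.client.api.YarnClient
import org.apache.hadoop.yarn.conf.YarnConfiguration

// Own the client as a field and delegate to it explicitly.
class ThinYarnClient(hadoopConf: Configuration) {
  private val yarnClient = YarnClient.createYarnClient

  def start(): Unit = {
    yarnClient.init(new YarnConfiguration(hadoopConf))
    yarnClient.start()
  }

  def report(appId: ApplicationId) = yarnClient.getApplicationReport(appId)

  def stop(): Unit = yarnClient.stop()
}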
- // TODO: Handle container failure - yarnAllocator.addResourceRequests(args.numExecutors) + yarnAllocator.allocateResources() while ((yarnAllocator.getNumExecutorsRunning < args.numExecutors) && (!driverClosed)) { + allocateMissingExecutor() yarnAllocator.allocateResources() Thread.sleep(100) } logInfo("All executors have launched.") + } + private def allocateMissingExecutor() { + val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - + yarnAllocator.getNumPendingAllocate + if (missingExecutorCount > 0) { + logInfo("Allocating %d containers to make up for (potentially) lost containers". + format(missingExecutorCount)) + yarnAllocator.addResourceRequests(missingExecutorCount) + } } // TODO: We might want to extend this to allocate more containers in case they die ! @@ -220,13 +228,7 @@ class ExecutorLauncher(args: ApplicationMasterArguments, conf: Configuration, sp val t = new Thread { override def run() { while (!driverClosed) { - val missingExecutorCount = args.numExecutors - yarnAllocator.getNumExecutorsRunning - - yarnAllocator.getNumPendingAllocate - if (missingExecutorCount > 0) { - logInfo("Allocating %d containers to make up for (potentially) lost containers". - format(missingExecutorCount)) - yarnAllocator.addResourceRequests(missingExecutorCount) - } + allocateMissingExecutor() sendProgress() Thread.sleep(sleepTime) }
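Review note (illustrative): after this deduplication, the allocation wait-loops and reporter threads in ApplicationMaster and ExecutorLauncher share the same skeleton: check for too many failures (ApplicationMaster only), top up lost containers, report progress, sleep. A compact sketch of that loop with the callbacks passed in as placeholders (reporterLoop is an invented name):

// Generic shape of the loops above; the callbacks stand in for checkNumExecutorsFailed(),
// allocateMissingExecutor()/allocateResources() and sendProgress().
def reporterLoop(keepRunning: () => Boolean,
                 checkFailures: () => Unit,
                 topUpExecutors: () => Unit,
                 reportProgress: () => Unit,
                 sleepMillis: Long): Unit = {
  while (keepRunning()) {
    checkFailures()
    topUpExecutors()
    reportProgress()
    Thread.sleep(sleepMillis)
  }
}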