diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 0fb7e195b34c4..f430a33db1e4a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -49,8 +49,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
 
   override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
 
-  override val classTag: ClassTag[(K, V)] =
-    implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
+  override val classTag: ClassTag[(K, V)] = rdd.elementClassTag
 
   import JavaPairRDD._
 
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 82527fe663848..57bde8d85f1a8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -78,9 +78,7 @@ private[spark] class PythonRDD[T: ClassTag](
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeToStream(elem, dataOut)
-        }
+        PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
         dataOut.flush()
         worker.shutdownOutput()
       } catch {
@@ -206,20 +204,43 @@ private[spark] object PythonRDD {
     JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
   }
 
-  def writeToStream(elem: Any, dataOut: DataOutputStream) {
-    elem match {
-      case bytes: Array[Byte] =>
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      case pair: (Array[Byte], Array[Byte]) =>
-        dataOut.writeInt(pair._1.length)
-        dataOut.write(pair._1)
-        dataOut.writeInt(pair._2.length)
-        dataOut.write(pair._2)
-      case str: String =>
-        dataOut.writeUTF(str)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
+    if (iter.hasNext) {
+      val first = iter.next()
+      val newIter = Seq(first).iterator ++ iter
+      first match {
+        case arr: Array[Byte] =>
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach { bytes =>
+            dataOut.writeInt(bytes.length)
+            dataOut.write(bytes)
+          }
+        case string: String =>
+          newIter.asInstanceOf[Iterator[String]].foreach { str =>
+            dataOut.writeUTF(str)
+          }
+        case pair: Tuple2[_, _] =>
+          pair._1 match {
+            case bytePair: Array[Byte] =>
+              newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
+                dataOut.writeInt(pair._1.length)
+                dataOut.write(pair._1)
+                dataOut.writeInt(pair._2.length)
+                dataOut.write(pair._2)
+              }
+            case stringPair: String =>
+              newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
+                dataOut.writeUTF(pair._1)
+                dataOut.writeUTF(pair._2)
+              }
+            case other =>
+              throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+          }
+        case other =>
+          throw new SparkException("Unexpected element type " + first.getClass)
+      }
     }
   }
 
@@ -230,9 +251,7 @@ private[spark] object PythonRDD {
 
   def writeToFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
-    for (item <- items) {
-      writeToStream(item, file)
-    }
+    writeIteratorToStream(items, file)
     file.close()
   }
 
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7acb6eaf10931..acd1ca5676209 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -152,6 +152,22 @@ def test_save_as_textfile_with_unicode(self):
         raw_contents = ''.join(input(glob(tempFile.name + "/part-0000*")))
         self.assertEqual(x, unicode(raw_contents.strip(), "utf-8"))
 
+    def test_transforming_cartesian_result(self):
+        # Regression test for SPARK-1034
+        rdd1 = self.sc.parallelize([1, 2])
+        rdd2 = self.sc.parallelize([3, 4])
+        cart = rdd1.cartesian(rdd2)
+        result = cart.map(lambda (x, y): x + y).collect()
+
+    def test_cartesian_on_textfile(self):
+        # Regression test for
+        path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+        a = self.sc.textFile(path)
+        result = a.cartesian(a).collect()
+        (x, y) = result[0]
+        self.assertEqual("Hello World!", x.strip())
+        self.assertEqual("Hello World!", y.strip())
+
 
 class TestIO(PySparkTestCase):
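Note (not part of the patch): a minimal, standalone sketch of the type-erasure behavior that the new writeIteratorToStream works around. The names ErasureDemo and describe are hypothetical and only mirror the shape of the old writeToStream dispatch. Under erasure, the old `case pair: (Array[Byte], Array[Byte])` pattern only checks that the value is a Tuple2, so (String, String) pairs such as those produced by cartesian() over a text file matched it and then failed when their elements were treated as byte arrays; that is why the new code inspects the runtime class of the first element instead.

```scala
object ErasureDemo {
  // Mirrors the shape of the old writeToStream dispatch. The tuple's type
  // parameters are erased at runtime, so the first case matches ANY Tuple2.
  def describe(elem: Any): String = elem match {
    case pair: (Array[Byte], Array[Byte]) =>
      // Treating pair._1 as Array[Byte] here (e.g. reading pair._1.length)
      // would throw ClassCastException when the elements are really Strings.
      "matched as a byte-array pair"
    case str: String =>
      "matched as a string"
    case other =>
      "unexpected element type " + other.getClass
  }

  def main(args: Array[String]): Unit = {
    // A (String, String) pair, like one produced by cartesian() on a text
    // file, still falls into the byte-array-pair case:
    println(describe(("Hello", "World")))  // prints "matched as a byte-array pair"
  }
}
```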