From 4280d04a32c69024bd200e407275a123b8373035 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Thu, 29 Jan 2015 17:25:31 -0800
Subject: [PATCH] address comments

---
 .../org/apache/spark/api/python/PythonRDD.scala    |  3 ---
 .../apache/spark/api/python/PythonRDDSuite.scala   | 15 ++++++++-------
 python/pyspark/tests.py                            |  4 +++-
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a13a1b923c5f6..3308f155ccf2e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -376,19 +376,16 @@ private[spark] object PythonRDD extends Logging {
     def write(obj: Any): Unit = obj match {
       case null =>
         dataOut.writeInt(SpecialLengths.NULL)
-
       case arr: Array[Byte] =>
         dataOut.writeInt(arr.length)
         dataOut.write(arr)
       case str: String =>
         writeUTF(str, dataOut)
-
       case stream: PortableDataStream =>
         write(stream.toArray())
       case (key, value) =>
         write(key)
         write(value)
-
       case other =>
         throw new SparkException("Unexpected element type " + other.getClass)
     }
diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
index f4cf02977033e..c63d834f9048b 100644
--- a/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala
@@ -31,13 +31,14 @@ class PythonRDDSuite extends FunSuite {
 
   test("Handle nulls gracefully") {
     val buffer = new DataOutputStream(new ByteArrayOutputStream)
-    PythonRDD.writeIteratorToStream(List("a", null).iterator, buffer)
-    PythonRDD.writeIteratorToStream(List(null, "a").iterator, buffer)
-    PythonRDD.writeIteratorToStream(List("a".getBytes, null).iterator, buffer)
-    PythonRDD.writeIteratorToStream(List(null, "a".getBytes).iterator, buffer)
-
-    PythonRDD.writeIteratorToStream(List((null, null), ("a", null), (null, "b")).iterator, buffer)
+    // Should not have NPE when writing an Iterator with null in it
+    // Correctness will be tested in Python
+    PythonRDD.writeIteratorToStream(Iterator("a", null), buffer)
+    PythonRDD.writeIteratorToStream(Iterator(null, "a"), buffer)
+    PythonRDD.writeIteratorToStream(Iterator("a".getBytes, null), buffer)
+    PythonRDD.writeIteratorToStream(Iterator(null, "a".getBytes), buffer)
+    PythonRDD.writeIteratorToStream(Iterator((null, null), ("a", null), (null, "b")), buffer)
     PythonRDD.writeIteratorToStream(
-      List((null, null), ("a".getBytes, null), (null, "b".getBytes)).iterator, buffer)
+      Iterator((null, null), ("a".getBytes, null), (null, "b".getBytes)), buffer)
   }
 }
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index d001b39d615de..9f07bd49d5fd8 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,7 @@
 from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer
+    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer, NoOpSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -720,6 +720,8 @@ def test_null_in_rdd(self):
         jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
         rdd = RDD(jrdd, self.sc, UTF8Deserializer())
         self.assertEqual([u"a", None, u"b"], rdd.collect())
+        rdd = RDD(jrdd, self.sc, NoOpSerializer())
+        self.assertEqual(["a", None, "b"], rdd.collect())
 
     def test_multiple_python_java_RDD_conversions(self):
         # Regression test for SPARK-5361
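
Note on the wire format this patch exercises: writeIteratorToStream length-prefixes each
element it sends to the Python worker, and a null element is written as a special negative
length (SpecialLengths.NULL) with no payload, which is what lets the PySpark deserializers
map it back to None. The standalone Scala sketch below illustrates that framing;
NullFramingSketch and NULL_SENTINEL are hypothetical names introduced here, and the sentinel
value is illustrative rather than quoted from Spark's sources.

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Minimal sketch of the length-prefixed framing used by PythonRDD's write;
// the object name and sentinel value are assumptions, not Spark's constants.
object NullFramingSketch {
  // Stand-in for SpecialLengths.NULL: a negative "length" marks a null element.
  val NULL_SENTINEL: Int = -5

  def write(dataOut: DataOutputStream, obj: Any): Unit = obj match {
    case null =>
      // No payload follows the sentinel, so nulls are never dereferenced.
      dataOut.writeInt(NULL_SENTINEL)
    case arr: Array[Byte] =>
      dataOut.writeInt(arr.length)
      dataOut.write(arr)
    case str: String =>
      val bytes = str.getBytes("UTF-8")
      dataOut.writeInt(bytes.length)
      dataOut.write(bytes)
    case (key, value) =>
      // A pair is framed as two consecutive elements, either of which may be null.
      write(dataOut, key)
      write(dataOut, value)
    case other =>
      throw new IllegalArgumentException("Unexpected element type " + other.getClass)
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream
    val dataOut = new DataOutputStream(buffer)
    Iterator("a", null, "b").foreach(write(dataOut, _))
    dataOut.flush()
    // "a" -> 4 + 1 bytes, null -> 4 bytes, "b" -> 4 + 1 bytes: 14 in total
    println(s"wrote ${buffer.size} bytes")
  }
}

On the Python side, the deserializer reads the length first and yields None when it sees the
null sentinel; that round trip is what the new UTF8Deserializer and NoOpSerializer assertions
in test_null_in_rdd pin down.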