add tests for null in RDD
Davies Liu committed Jan 29, 2015
1 parent 23b039a commit f257071
Showing 6 changed files with 52 additions and 69 deletions.
71 changes: 15 additions & 56 deletions core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -373,68 +373,27 @@ private[spark] object PythonRDD extends Logging {

   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {

-    def writeBytes(bytes: Array[Byte]) {
-      if (bytes == null) {
-        dataOut.writeInt(SpecialLengths.NULL)
-      } else {
-        dataOut.writeInt(bytes.length)
-        dataOut.write(bytes)
-      }
-    }
-
-    def writeString(str: String) {
-      if (str == null) {
-        dataOut.writeInt(SpecialLengths.NULL)
-      } else {
-        writeUTF(str, dataOut)
-      }
-    }
-
-    // The right way to implement this would be to use TypeTags to get the full
-    // type of T. Since I don't want to introduce breaking changes throughout the
-    // entire Spark API, I have to use this hacky approach:
-    if (iter.hasNext) {
-      val first = iter.next()
-      val newIter = Seq(first).iterator ++ iter
-      first match {
-        case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes)
-        case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach(writeString)
-        case stream: PortableDataStream =>
-          newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            writeBytes(stream.toArray())
-          }
-        case (key: String, stream: PortableDataStream) =>
-          newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
-            case (key, stream) =>
-              writeString(key)
-              writeBytes(stream.toArray())
-          }
-        case (key: String, value: String) =>
-          newIter.asInstanceOf[Iterator[(String, String)]].foreach {
-            case (key, value) =>
-              writeString(key)
-              writeString(value)
-          }
-        case (key: Array[Byte], value: Array[Byte]) =>
-          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
-            case (key, value) =>
-              writeBytes(key)
-              writeBytes(value)
-          }
-        // key is null
-        case (null, value: Array[Byte]) =>
-          newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
-            case (key, value) =>
-              writeBytes(key)
-              writeBytes(value)
-          }
-        case other =>
-          throw new SparkException("Unexpected element type " + other.getClass)
-      }
-    }
+    def write(obj: Any): Unit = obj match {
+      case null =>
+        dataOut.writeInt(SpecialLengths.NULL)
+      case arr: Array[Byte] =>
+        dataOut.writeInt(arr.length)
+        dataOut.write(arr)
+      case str: String =>
+        writeUTF(str, dataOut)
+      case stream: PortableDataStream =>
+        write(stream.toArray())
+      case (key, value) =>
+        write(key)
+        write(value)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+
+    iter.foreach(write)
   }

   /**
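For context on how these framed values are consumed on the Python side: each element is written as a 32-bit big-endian length followed by that many bytes, and a negative special length is a marker rather than a payload size. Below is a minimal reader sketch in Python; the marker values mirror SpecialLengths in pyspark/serializers.py but are treated as assumptions here, and only the NULL case matters for this commit.

    import struct

    # Assumed special length markers (mirroring SpecialLengths in pyspark/serializers.py).
    END_OF_DATA_SECTION = -1
    NULL = -5

    def read_int(stream):
        # DataOutputStream.writeInt on the JVM side writes a big-endian 32-bit signed int.
        data = stream.read(4)
        if len(data) < 4:
            raise EOFError
        return struct.unpack("!i", data)[0]

    def load_utf8_stream(stream):
        # Yield unicode strings, or None whenever the NULL marker is seen.
        while True:
            length = read_int(stream)
            if length == END_OF_DATA_SECTION:
                return
            if length == NULL:
                yield None
            else:
                yield stream.read(length).decode("utf-8")

This is roughly what UTF8Deserializer does when it turns the stream produced by writeIteratorToStream back into Python objects, which is why the new test_null_in_rdd below expects [u"a", None, u"b"].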
core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala

@@ -22,6 +22,7 @@ import java.io.{File, InputStream, IOException, OutputStream}
 import scala.collection.mutable.ArrayBuffer

 import org.apache.spark.SparkContext
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}

 private[spark] object PythonUtils {
   /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */
@@ -39,4 +40,8 @@ private[spark] object PythonUtils {
   def mergePythonPaths(paths: String*): String = {
     paths.filter(_ != "").mkString(File.pathSeparator)
   }
+
+  def generateRDDWithNull(sc: JavaSparkContext): JavaRDD[String] = {
+    sc.parallelize(List("a", null, "b"))
+  }
 }
core/src/test/scala/org/apache/spark/api/python/PythonRDDSuite.scala

@@ -23,11 +23,21 @@ import org.scalatest.FunSuite

 class PythonRDDSuite extends FunSuite {

-    test("Writing large strings to the worker") {
-        val input: List[String] = List("a"*100000)
-        val buffer = new DataOutputStream(new ByteArrayOutputStream)
-        PythonRDD.writeIteratorToStream(input.iterator, buffer)
-    }
+  test("Writing large strings to the worker") {
+    val input: List[String] = List("a"*100000)
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(input.iterator, buffer)
+  }

+  test("Handle nulls gracefully") {
+    val buffer = new DataOutputStream(new ByteArrayOutputStream)
+    PythonRDD.writeIteratorToStream(List("a", null).iterator, buffer)
+    PythonRDD.writeIteratorToStream(List(null, "a").iterator, buffer)
+    PythonRDD.writeIteratorToStream(List("a".getBytes, null).iterator, buffer)
+    PythonRDD.writeIteratorToStream(List(null, "a".getBytes).iterator, buffer)
+
+    PythonRDD.writeIteratorToStream(List((null, null), ("a", null), (null, "b")).iterator, buffer)
+    PythonRDD.writeIteratorToStream(
+      List((null, null), ("a".getBytes, null), (null, "b".getBytes)).iterator, buffer)
+  }
 }
2 changes: 2 additions & 0 deletions python/pyspark/serializers.py
@@ -134,6 +134,8 @@ def load_stream(self, stream):

     def _write_with_length(self, obj, stream):
         serialized = self.dumps(obj)
+        if serialized is None:
+            raise ValueError("serialized value should not be None")
         if len(serialized) > (1 << 31):
             raise ValueError("can not serialize object larger than 2G")
         write_int(len(serialized), stream)
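The added guard makes a serializer whose dumps() returns None fail with a clear message instead of the less obvious TypeError that len(None) would raise a line later. For orientation, here is a simplified sketch of the length-prefixed framing that _write_with_length performs; write_int below is an assumed stand-in for the helper of the same name in pyspark/serializers.py.

    import struct

    def write_int(value, stream):
        # Big-endian 32-bit signed int, matching java.io.DataInputStream on the JVM side.
        stream.write(struct.pack("!i", value))

    def write_with_length(serialized, stream):
        # Refuse a missing payload outright rather than failing later on len(None).
        if serialized is None:
            raise ValueError("serialized value should not be None")
        if len(serialized) > (1 << 31):
            raise ValueError("can not serialize object larger than 2G")
        write_int(len(serialized), stream)
        stream.write(serialized)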
13 changes: 7 additions & 6 deletions python/pyspark/streaming/kafka.py
@@ -33,7 +33,7 @@ def utf8_decoder(s):
 class KafkaUtils(object):

     @staticmethod
-    def createStream(ssc, zkQuorum, groupId, topics,
+    def createStream(ssc, zkQuorum, groupId, topics, kafkaParams={},
                      storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2,
                      keyDecoder=utf8_decoder, valueDecoder=utf8_decoder):
         """
@@ -44,22 +44,23 @@ def createStream(ssc, zkQuorum, groupId, topics,
         :param groupId: The group id for this consumer.
         :param topics: Dict of (topic_name -> numPartitions) to consume.
                        Each partition is consumed in its own thread.
+        :param kafkaParams: Additional params for Kafka
         :param storageLevel: RDD storage level.
-        :param keyDecoder: A function used to decode key
-        :param valueDecoder: A function used to decode value
+        :param keyDecoder: A function used to decode key (default is utf8_decoder)
+        :param valueDecoder: A function used to decode value (default is utf8_decoder)
         :return: A DStream object
         """
         java_import(ssc._jvm, "org.apache.spark.streaming.kafka.KafkaUtils")

-        param = {
+        kafkaParams.update({
             "zookeeper.connect": zkQuorum,
             "group.id": groupId,
             "zookeeper.connection.timeout.ms": "10000",
-        }
+        })
         if not isinstance(topics, dict):
             raise TypeError("topics should be dict")
         jtopics = MapConverter().convert(topics, ssc.sparkContext._gateway._gateway_client)
-        jparam = MapConverter().convert(param, ssc.sparkContext._gateway._gateway_client)
+        jparam = MapConverter().convert(kafkaParams, ssc.sparkContext._gateway._gateway_client)
         jlevel = ssc._sc._getJavaStorageLevel(storageLevel)

         def getClassByName(name):
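A hedged usage sketch of the new kafkaParams argument follows; the hosts, group id, topic, and extra consumer property are placeholders, and a running ZooKeeper/Kafka plus an existing SparkContext `sc` are assumed. Whatever dict is passed is merged with the zookeeper.connect, group.id, and connection-timeout entries that createStream sets itself.

    from pyspark.streaming import StreamingContext
    from pyspark.streaming.kafka import KafkaUtils, utf8_decoder

    # All connection details below are illustrative placeholders.
    ssc = StreamingContext(sc, 2)  # 2-second batches
    stream = KafkaUtils.createStream(
        ssc, "zk-host:2181", "example-group", {"example-topic": 1},
        kafkaParams={"auto.offset.reset": "smallest"},
        keyDecoder=utf8_decoder, valueDecoder=utf8_decoder)
    stream.pprint()
    ssc.start()

Note that createStream updates the dict it is given (including the shared default), so passing an explicit dict per call, as above, is the safer pattern.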
8 changes: 7 additions & 1 deletion python/pyspark/tests.py
@@ -46,9 +46,10 @@

 from pyspark.conf import SparkConf
 from pyspark.context import SparkContext
+from pyspark.rdd import RDD
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
-    CloudPickleSerializer, CompressedSerializer
+    CloudPickleSerializer, CompressedSerializer, UTF8Deserializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
     UserDefinedType, DoubleType
@@ -714,6 +715,11 @@ def test_sample(self):
         wr_s21 = rdd.sample(True, 0.4, 21).collect()
         self.assertNotEqual(set(wr_s11), set(wr_s21))

+    def test_null_in_rdd(self):
+        jrdd = self.sc._jvm.PythonUtils.generateRDDWithNull(self.sc._jsc)
+        rdd = RDD(jrdd, self.sc, UTF8Deserializer())
+        self.assertEqual([u"a", None, u"b"], rdd.collect())
+

 class ProfilerTests(PySparkTestCase):
