Commit

address comments
Davies Liu committed Jan 28, 2015
1 parent 97386b3 commit a74da87
Showing 2 changed files with 19 additions and 18 deletions.
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala (35 changes: 18 additions and 17 deletions)
@@ -372,10 +372,8 @@ private[spark] object PythonRDD extends Logging {
   }
 
   def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
-    // The right way to implement this would be to use TypeTags to get the full
-    // type of T. Since I don't want to introduce breaking changes throughout the
-    // entire Spark API, I have to use this hacky approach:
-    def write(bytes: Array[Byte]) {
+
+    def writeBytes(bytes: Array[Byte]) {
       if (bytes == null) {
         dataOut.writeInt(SpecialLengths.NULL)
       } else {
@@ -384,50 +382,53 @@ private[spark] object PythonRDD extends Logging {
       }
     }
 
-    def writeS(str: String) {
+    def writeString(str: String) {
       if (str == null) {
         dataOut.writeInt(SpecialLengths.NULL)
       } else {
         writeUTF(str, dataOut)
       }
     }
 
+    // The right way to implement this would be to use TypeTags to get the full
+    // type of T. Since I don't want to introduce breaking changes throughout the
+    // entire Spark API, I have to use this hacky approach:
     if (iter.hasNext) {
       val first = iter.next()
       val newIter = Seq(first).iterator ++ iter
       first match {
         case arr: Array[Byte] =>
-          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(write)
+          newIter.asInstanceOf[Iterator[Array[Byte]]].foreach(writeBytes)
         case string: String =>
-          newIter.asInstanceOf[Iterator[String]].foreach(writeS)
+          newIter.asInstanceOf[Iterator[String]].foreach(writeString)
         case stream: PortableDataStream =>
           newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
-            write(stream.toArray())
+            writeBytes(stream.toArray())
           }
         case (key: String, stream: PortableDataStream) =>
           newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
             case (key, stream) =>
-              writeS(key)
-              write(stream.toArray())
+              writeString(key)
+              writeBytes(stream.toArray())
           }
         case (key: String, value: String) =>
           newIter.asInstanceOf[Iterator[(String, String)]].foreach {
             case (key, value) =>
-              writeS(key)
-              writeS(value)
+              writeString(key)
+              writeString(value)
           }
         case (key: Array[Byte], value: Array[Byte]) =>
           newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
             case (key, value) =>
-              write(key)
-              write(value)
+              writeBytes(key)
+              writeBytes(value)
           }
         // key is null
-        case (null, v:Array[Byte]) =>
+        case (null, value: Array[Byte]) =>
           newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
             case (key, value) =>
-              write(key)
-              write(value)
+              writeBytes(key)
+              writeBytes(value)
           }
 
         case other =>
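For context, here is a minimal standalone sketch of the length-prefixed framing that the renamed writeBytes/writeString helpers perform: each record is written as a 4-byte length followed by the payload, with a sentinel written in place of the length for null values. The object name, the NULL_MARKER constant, and its value are assumptions for illustration only; the real sentinel is SpecialLengths.NULL in PythonRDD.scala, and the real string path goes through writeUTF.

import java.io.{ByteArrayOutputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object FramingSketch {
  // Assumed sentinel value for illustration; Spark defines the real one in SpecialLengths.
  val NULL_MARKER = -1

  def writeBytes(bytes: Array[Byte], dataOut: DataOutputStream): Unit = {
    if (bytes == null) {
      dataOut.writeInt(NULL_MARKER)      // null marker written instead of a length
    } else {
      dataOut.writeInt(bytes.length)     // 4-byte length prefix
      dataOut.write(bytes)               // raw payload
    }
  }

  def writeString(str: String, dataOut: DataOutputStream): Unit = {
    if (str == null) {
      dataOut.writeInt(NULL_MARKER)
    } else {
      writeBytes(str.getBytes(StandardCharsets.UTF_8), dataOut)
    }
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    writeString("spark", out)            // 4-byte length + 5 payload bytes
    writeString(null, out)               // 4-byte null marker
    out.flush()
    println(buf.size())                  // prints 13
  }
}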
python/pyspark/serializers.py (2 changes: 1 addition and 1 deletion)
@@ -487,7 +487,7 @@ def loads(self, stream):
         length = read_int(stream)
         if length == SpecialLengths.END_OF_DATA_SECTION:
             raise EOFError
-        if length == SpecialLengths.NULL:
+        elif length == SpecialLengths.NULL:
             return None
         s = stream.read(length)
         return s.decode("utf-8") if self.use_unicode else s
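A rough sketch of the matching read side, written in Scala for symmetry with the writer sketch above (the commit's actual reader is UTF8Deserializer.loads in python/pyspark/serializers.py): the reader inspects the length prefix for sentinel values before treating it as a payload length. The sentinel constants and object name below are assumptions for illustration, not the values Spark defines in SpecialLengths.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream, EOFException}
import java.nio.charset.StandardCharsets

object ReaderSketch {
  // Assumed sentinel values for illustration only.
  val END_OF_DATA_SECTION = -1
  val NULL_MARKER = -2

  def readString(dataIn: DataInputStream): String = {
    val length = dataIn.readInt()
    if (length == END_OF_DATA_SECTION) {
      throw new EOFException()                 // no more records in this section
    } else if (length == NULL_MARKER) {
      null                                     // the branch the elif above returns from
    } else {
      val bytes = new Array[Byte](length)
      dataIn.readFully(bytes)                  // read exactly `length` payload bytes
      new String(bytes, StandardCharsets.UTF_8)
    }
  }

  def main(args: Array[String]): Unit = {
    val buf = new ByteArrayOutputStream()
    val out = new DataOutputStream(buf)
    out.writeInt(5)                                        // length prefix for "spark"
    out.write("spark".getBytes(StandardCharsets.UTF_8))
    out.writeInt(NULL_MARKER)                              // a null record
    val in = new DataInputStream(new ByteArrayInputStream(buf.toByteArray))
    println(readString(in))                                // spark
    println(readString(in))                                // null
  }
}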
