Skip to content

Commit

Permalink
[SPARK-2875] [PySpark] [SQL] handle null in schemaRDD()
Browse files Browse the repository at this point in the history
Handle null in schemaRDD during converting them into Python.

Author: Davies Liu <davies.liu@gmail.com>

Closes apache#1802 from davies/json and squashes the following commits:

88e6b1f [Davies Liu] handle null in schemaRDD()
  • Loading branch information
davies authored and marmbrus committed Aug 6, 2014
1 parent 09f7e45 commit 4878911
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 11 deletions.
7 changes: 7 additions & 0 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1231,6 +1231,13 @@ def jsonRDD(self, rdd, schema=None):
... "field3.field5[0] as f3 from table3")
>>> srdd6.collect()
[Row(f1=u'row1', f2=None,...Row(f1=u'row3', f2=[], f3=None)]
>>> sqlCtx.jsonRDD(sc.parallelize(['{}',
... '{"key0": {"key1": "value1"}}'])).collect()
[Row(key0=None), Row(key0=Row(key1=u'value1'))]
>>> sqlCtx.jsonRDD(sc.parallelize(['{"key0": null}',
... '{"key0": {"key1": "value1"}}'])).collect()
[Row(key0=None), Row(key0=Row(key1=u'value1'))]
"""

def func(iterator):
Expand Down
27 changes: 16 additions & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -382,21 +382,26 @@ class SchemaRDD(
private[sql] def javaToPython: JavaRDD[Array[Byte]] = {
import scala.collection.Map

def toJava(obj: Any, dataType: DataType): Any = dataType match {
case struct: StructType => rowToArray(obj.asInstanceOf[Row], struct)
case array: ArrayType => obj match {
case seq: Seq[Any] => seq.map(x => toJava(x, array.elementType)).asJava
case list: JList[_] => list.map(x => toJava(x, array.elementType)).asJava
case arr if arr != null && arr.getClass.isArray =>
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
case other => other
}
case mt: MapType => obj.asInstanceOf[Map[_, _]].map {
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null

case (obj: Row, struct: StructType) => rowToArray(obj, struct)

case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
case (list: JList[_], array: ArrayType) =>
list.map(x => toJava(x, array.elementType)).asJava
case (arr, array: ArrayType) if arr.getClass.isArray =>
arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))

case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (k, toJava(v, mt.valueType)) // key should be primitive type
}.asJava

// Pyrolite can handle Timestamp
case other => obj
case (other, _) => other
}

def rowToArray(row: Row, structType: StructType): Array[Any] = {
val fields = structType.fields.map(field => field.dataType)
row.zip(fields).map {
Expand Down

0 comments on commit 4878911

Please sign in to comment.