Skip to content

Commit

Permalink
fix collect with UDT and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Davies Liu committed Jan 27, 2015
1 parent e971078 commit 6bf2b73
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 25 deletions.
29 changes: 18 additions & 11 deletions python/pyspark/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1973,7 +1973,7 @@ def collect(self):
[Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
"""
with SCCallSiteSync(self._sc) as css:
bytesInJava = self._jdf.collectToPython().iterator()
bytesInJava = self._jdf.javaToPython().collect().iterator()
cls = _create_cls(self.schema())
tempFile = NamedTemporaryFile(delete=False, dir=self._sc._temp_dir)
tempFile.close()
Expand All @@ -1997,14 +1997,14 @@ def take(self, num):
return self.limit(num).collect()

def map(self, f):
""" Return a new RDD by applying a function to each Row, it's a
shorthand for df.rdd.map()
"""
return self.rdd.map(f)

# Convert each object in the RDD to a Row with the right class
# for this DataFrame, so that fields can be accessed as attributes.
def mapPartitions(self, f, preservesPartitioning=False):
"""
Return a new RDD by applying a function to each partition of this RDD,
while tracking the index of the original partition.
Return a new RDD by applying a function to each partition.
>>> rdd = sc.parallelize([1, 2, 3, 4], 4)
>>> def f(iterator): yield 1
Expand All @@ -2013,21 +2013,28 @@ def mapPartitions(self, f, preservesPartitioning=False):
"""
return self.rdd.mapPartitions(f, preservesPartitioning)

# We override the default cache/persist/checkpoint behavior
# as we want to cache the underlying DataFrame object in the JVM,
# not the PythonRDD checkpointed by the super class
def cache(self):
""" Persist with the default storage level (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
self._jdf.cache()
return self

def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER):
""" Set the storage level to persist its values across operations
after the first time it is computed. This can only be used to assign
a new storage level if the RDD does not have a storage level set yet.
If no storage level is specified defaults to (C{MEMORY_ONLY_SER}).
"""
self.is_cached = True
javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
self._jdf.persist(javaStorageLevel)
return self

def unpersist(self, blocking=True):
""" Mark it as non-persistent, and remove all blocks for it from
memory and disk.
"""
self.is_cached = False
self._jdf.unpersist(blocking)
return self
Expand Down Expand Up @@ -2359,11 +2366,11 @@ def _scalaMethod(name):
""" Translate operators into methodName in Scala
For example:
>>> scalaMethod('+')
>>> _scalaMethod('+')
'$plus'
>>> scalaMethod('>=')
>>> _scalaMethod('>=')
'$greater$eq'
>>> scalaMethod('cast')
>>> _scalaMethod('cast')
'cast'
"""
return ''.join(SCALA_METHOD_MAPPINGS.get(c, c) for c in name)
Expand Down
6 changes: 3 additions & 3 deletions python/pyspark/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,8 +946,7 @@ def test_apply_schema_with_udt(self):
schema = StructType([StructField("label", DoubleType(), False),
StructField("point", ExamplePointUDT(), False)])
df = self.sqlCtx.applySchema(rdd, schema)
# TODO: test collect with UDT
point = df.rdd.first().point
point = df.head().point
self.assertEquals(point, ExamplePoint(1.0, 2.0))

def test_parquet_with_udt(self):
Expand Down Expand Up @@ -984,11 +983,12 @@ def test_column_select(self):
self.assertEqual([Row(value='1')], df.where(df.key == 1).select(df.value).collect())

def test_aggregator(self):
from pyspark.sql import Aggregator as Agg
df = self.df
g = df.groupBy()
self.assertEqual([99, 100], sorted(g.agg({'key': 'max', 'value': 'count'}).collect()[0]))
self.assertEqual([Row(**{"AVG(key#0)": 49.5})], g.mean().collect())
# TODO(davies): fix aggregators
from pyspark.sql import Aggregator as Agg
# self.assertEqual((0, '100'), tuple(g.agg(Agg.first(df.key), Agg.last(df.value)).first()))


Expand Down
12 changes: 1 addition & 11 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -590,17 +590,7 @@ class DataFrame protected[sql](
*/
protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val jrdd = this.rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
SerDeUtil.javaToPython(jrdd)
}
/**
* Serializes the Array[Row] returned by collect(), using the same format as javaToPython.
*/
protected[sql] def collectToPython: JList[Array[Byte]] = {
val fieldTypes = schema.fields.map(_.dataType)
val pickle = new Pickler
new ArrayList[Array[Byte]](collect().map { row =>
EvaluatePython.rowToArray(row, fieldTypes)
}.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
}
}

0 comments on commit 6bf2b73

Please sign in to comment.