
[SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance #22275

33 changes: 33 additions & 0 deletions python/pyspark/serializers.py
@@ -185,6 +185,39 @@ def loads(self, obj):
        raise NotImplementedError


class ArrowCollectSerializer(Serializer):
    """
    Deserialize a stream of batches followed by batch order information.
    """

    def __init__(self):
        self.serializer = ArrowStreamSerializer()

    def dump_stream(self, iterator, stream):
        return self.serializer.dump_stream(iterator, stream)

    def load_stream(self, stream):
        """
        Load a stream of un-ordered Arrow RecordBatches, where the last
        iteration will yield a list of indices to put the RecordBatches in
        the correct order.
        """
        # load the batches
        for batch in self.serializer.load_stream(stream):
            yield batch

        # load the batch order indices
        num = read_int(stream)
        batch_order = []
        for i in xrange(num):
            index = read_int(stream)
            batch_order.append(index)
        yield batch_order

    def __repr__(self):
        return "ArrowCollectSerializer(%s)" % self.serializer


class ArrowStreamSerializer(Serializer):
    """
    Serializes Arrow record batches as a stream.
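For orientation, the wire format that ArrowCollectSerializer.load_stream expects can be sketched with an illustrative Python writer (the real writer is the JVM code in Dataset.scala shown further below); the function name here is hypothetical, and the sketch assumes pyarrow plus the big-endian int framing used by read_int/write_int in this module.

import struct

import pyarrow as pa


def write_collect_stream(batches, batch_order, sock_file):
    """Hypothetical writer for the ArrowCollectSerializer wire format: an Arrow
    IPC stream of un-ordered record batches, then an int32 count, then one
    int32 order index per batch (all big-endian). Assumes at least one batch."""
    writer = pa.RecordBatchStreamWriter(sock_file, batches[0].schema)
    for batch in batches:
        writer.write_batch(batch)
    writer.close()  # writes the Arrow end-of-stream marker
    sock_file.write(struct.pack("!i", len(batch_order)))
    for index in batch_order:
        sock_file.write(struct.pack("!i", index))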
12 changes: 10 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@

 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2125,6 +2125,7 @@ def toPandas(self):
                from pyspark.sql.types import _check_dataframe_convert_date, \
                    _check_dataframe_localize_timestamps
                import pyarrow

                batches = self._collectAsArrow()
                if len(batches) > 0:
                    table = pyarrow.Table.from_batches(batches)
@@ -2183,7 +2184,14 @@ def _collectAsArrow(self):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.collectAsArrowToPython()
return list(_load_from_socket(sock_info, ArrowStreamSerializer()))

# Collect list of un-ordered batches where last element is a list of correct order indices
results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
batches = results[:-1]
batch_order = results[-1]

# Re-order the batch list using the correct order
return [batches[i] for i in batch_order]

##########################################################################################
# Pandas compatibility
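To make the index semantics in _collectAsArrow concrete, here is a tiny stand-alone illustration with made-up values (plain strings stand in for pyarrow.RecordBatch objects):

# Batches as they arrived from the JVM (out of order), plus the trailing index
# list: batch_order[j] is the arrival position of the batch that belongs at
# position j of the correct order.
batches = ["B2", "B0", "B1"]
batch_order = [1, 2, 0]
assert [batches[i] for i in batch_order] == ["B0", "B1", "B2"]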
22 changes: 22 additions & 0 deletions python/pyspark/sql/tests.py
@@ -4923,6 +4923,28 @@ def test_timestamp_dst(self):
        self.assertPandasEqual(pdf, df_from_python.toPandas())
        self.assertPandasEqual(pdf, df_from_pandas.toPandas())

    def test_toPandas_batch_order(self):

        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
        def run_test(num_records, num_parts, max_records):
            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
                self.assertPandasEqual(pdf, pdf_arrow)

        cases = [
            (1024, 512, 2),  # Try large num partitions for good chance of not collecting in order
            (512, 64, 2),    # Try medium num partitions to test out of order collection
            (64, 8, 2),      # Try small number of partitions to test out of order collection
            (64, 64, 1),     # Test single batch per partition
            (64, 1, 64),     # Test single partition, single batch
            (64, 1, 8),      # Test single partition, multiple batches
            (30, 7, 2),      # Test different sized partitions
        ]
Review thread on the test cases above:

Member Author (@BryanCutler):
@holdenk and @felixcheung, I didn't do a loop but chose a few different numbers of partitions to be a bit more sure that the partitions won't all end up in order. I also added some other cases with different partition/batch ratios. Let me know if you think we need more to be sure here.

Contributor:
I don't see how we're guaranteeing out-of-order from the JVM. Could we delay on one of the early partitions to guarantee out of order?

Member Author (@BryanCutler), Nov 6, 2018:
Yeah it's not a guarantee, but with a large num of partitions, it's a pretty slim chance they will all be in order. I can also add a case with some delay. My only concern is how big to make the delay to be sure it's enough without adding wasted time to the tests.

How about we keep the case with a large number of partitions and add a case with 100ms delay on the first partition?

Member Author (@BryanCutler):
@holdenk , I updated the tests, please take another look when you get a chance. Thanks!

Contributor:
I like the new tests; I think 0.1s on one of the partitions is enough.


        for case in cases:
            run_test(num_records=case[0], num_parts=case[1], max_records=case[2])


class EncryptionArrowTests(ArrowTests):

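For reference, the 100 ms delay discussed in the review thread above could be wired in along these lines; this is only a sketch (the function name is chosen here, not taken from the PR), and it assumes df is the DataFrame created inside run_test:

import time

def delay_first_part(partition_index, iterator):
    # Sleep only in partition 0 so later partitions likely finish first,
    # forcing the driver to receive Arrow batches out of order.
    if partition_index == 0:
        time.sleep(0.1)
    return iterator

# Assumes `df` is the DataFrame built in run_test above
df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()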
45 changes: 25 additions & 20 deletions sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,10 +17,11 @@

package org.apache.spark.sql

-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
 import java.sql.{Date, Timestamp}

 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.TypeTag
 import scala.util.control.NonFatal
@@ -3279,34 +3280,38 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone

     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length

-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1  // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0

-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python out of order
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
             }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
-            }
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
+            }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
           }
         }

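To see how the JVM-side sort of (partition index, batch index in that partition) keys produces the indices appended after the Arrow stream, here is a small stand-alone Python sketch of the same logic with a made-up arrival order (not part of the patch):

# (partition index, batch index within that partition) for each batch, in the
# order the batches were actually transferred to Python.
arrival_keys = [(2, 0), (0, 0), (0, 1), (1, 0)]

# Sorting the keys and emitting each batch's transfer position mirrors
# batchOrder.zipWithIndex.sortBy(_._1) in Dataset.scala: the j-th value written
# is the transfer index of the batch that belongs at position j of the correct order.
batch_order_indices = [transfer_idx for _, transfer_idx
                       in sorted(zip(arrival_keys, range(len(arrival_keys))))]
assert batch_order_indices == [1, 2, 3, 0]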