[SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker

This fixes an error when a PySpark local iterator, for both RDDs and DataFrames, goes out of scope and the connection is closed before the iterator is fully consumed. The error occurs on the JVM in the serving thread when Python closes the local socket while the JVM is still writing to it. This usually happens when there is enough data to fill the socket read buffer, causing the write call to block.
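A minimal reproduction sketch of this failure mode, assuming a running SparkSession bound to `spark` (the row count is only illustrative; it just needs to be large enough to fill the socket buffer):

```python
# Hypothetical reproduction of the pre-fix issue; assumes a SparkSession `spark`.
df = spark.range(1 << 20, numPartitions=2)  # enough rows that the JVM write blocks

it = df.toLocalIterator()
next(it)   # consume only the first element
it = None  # iterator goes out of scope; the socket closes while the JVM is still writing
# Before this change, the JVM serving thread could fail with a connection error here.
```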

Additionally, this fixes a problem where an error occurs in the Python worker and the collect job is cancelled with an exception. Previously, the Python driver was never notified of the error, so the user could get a partial result (iteration up to the error) and the application would continue. With this change, an error in the worker is sent to the Python iterator and is then raised.
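A rough sketch of the new behavior, assuming an existing SparkContext bound to `sc` (the failing map function is only illustrative):

```python
# Sketch of the new error propagation; assumes a SparkContext `sc`.
def fail(x):
    raise RuntimeError("worker error")

rdd = sc.range(10).map(fail)
try:
    for _ in rdd.toLocalIterator():
        pass
except RuntimeError as e:
    # With this change, the worker failure is raised on the driver instead of
    # silently truncating the iteration.
    print(e)
```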

The change here introduces a protocol for PySpark local iterators that works as follows:

1) The local socket connection is made when the iterator is created
2) When iterating, Python first sends a request for partition data as a non-zero integer
3) While the JVM local iterator over partitions has next, it triggers a job to collect the next partition
4) The JVM sends a non-zero response to indicate it has the next partition to send
5) The next partition is sent to Python and read by the PySpark deserializer
6) After sending the entire partition, an `END_OF_DATA_SECTION` is sent to Python, which stops the deserializer and allows Python to make another request
7) When the JVM gets a request from Python but has already consumed its local iterator, it sends a zero response to Python and both sides close the socket cleanly
8) If an error occurs in the worker, a negative response is sent to Python followed by the error message. Python will then raise a RuntimeError with the message, stopping iteration.
9) When the PySpark local iterator is garbage-collected, it reads any remaining data from the current partition (data that has already been collected) and sends a request of zero to tell the JVM to stop collection jobs and close the connection.

Steps 1, 3, 5, and 6 are the same as before. Step 8 was completely missing before because errors in the worker were never communicated back to Python. The other steps add synchronization to allow a clean closing of the socket, with a small performance trade-off per partition. This is mainly because the JVM does not start collecting partition data until it receives a request to do so, whereas before it would eagerly write all data until the socket receive buffer was full.
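For illustration, a simplified sketch of the Python-side protocol loop described above. It assumes `read_int`, `write_int`, and `UTF8Deserializer` from `pyspark.serializers`, and it mirrors, but is not, the actual `_local_iterator_from_socket` added in `rdd.py` below:

```python
from pyspark.serializers import UTF8Deserializer, read_int, write_int


def iterate_partitions(sockfile, serializer):
    """Simplified sketch of the request/response loop (steps 2-9 above)."""
    while True:
        write_int(1, sockfile)       # step 2: request the next partition
        sockfile.flush()
        status = read_int(sockfile)  # steps 4, 7, 8: JVM response
        if status == 1:
            # steps 5-6: a partition follows, terminated by END_OF_DATA_SECTION
            for item in serializer.load_stream(sockfile):
                yield item
        elif status == -1:
            # step 8: worker error, the message follows
            raise RuntimeError(UTF8Deserializer().loads(sockfile))
        else:
            # step 7: zero response, the JVM iterator is exhausted
            return
```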

Added new unit tests for DataFrame and RDD `toLocalIterator`, including tests that do not fully consume the iterator. Manually tested with Python 2.7 and 3.6.

Closes apache#24070 from BryanCutler/pyspark-toLocalIterator-clean-stop-SPARK-23961.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
BryanCutler authored and Willi Raschkowski committed Jun 5, 2020
1 parent 67e1d41 commit 80eff91
Showing 5 changed files with 178 additions and 13 deletions.
57 changes: 56 additions & 1 deletion core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -169,8 +169,63 @@ private[spark] object PythonRDD extends Logging {
serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
}

/**
* A helper function to create a local RDD iterator and serve it via socket. Partitions are
* collected as separate jobs, by order of index. Partition data is first requested by a
* non-zero integer to start a collection job. The response is prefaced by an integer with 1
* meaning partition data will be served, 0 meaning the local iterator has been consumed,
* and -1 meaning an error occurred during collection. This function is used by
* pyspark.rdd._local_iterator_from_socket().
*
* @return 2-tuple (as a Java array) with the port number of a local socket which serves the
* data collected from these jobs, and the secret for authentication.
*/
def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
val (port, secret) = SocketAuthServer.setupOneConnectionServer(
authHelper, "serve toLocalIterator") { s =>
val out = new DataOutputStream(s.getOutputStream)
val in = new DataInputStream(s.getInputStream)
Utils.tryWithSafeFinally {

// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
}

// Read request for data and send next partition if nonzero
var complete = false
while (!complete && in.readInt() != 0) {
if (collectPartitionIter.hasNext) {
try {
// Attempt to collect the next partition
val partitionArray = collectPartitionIter.next()

// Send response there is a partition to read
out.writeInt(1)

// Write the next object and signal end of data for this iteration
writeIteratorToStream(partitionArray.toIterator, out)
out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
out.flush()
} catch {
case e: SparkException =>
// Send response that an error occurred followed by error message
out.writeInt(-1)
writeUTF(e.getMessage, out)
complete = true
}
} else {
// Send response there are no more partitions to read and close
out.writeInt(0)
complete = true
}
}
} {
out.close()
in.close()
}
}
Array(port, secret)
}

def readRDDFromFile(
66 changes: 60 additions & 6 deletions python/pyspark/rdd.py
@@ -39,9 +39,9 @@
from itertools import imap as map, ifilter as filter

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer
from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
UTF8Deserializer, pack_long, read_int, write_int
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -138,15 +138,69 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])


def _load_from_socket(sock_info, serializer):
def _create_local_socket(sock_info):
(sockfile, sock) = local_connect_and_auth(*sock_info)
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
# The RDD materialization time is unpredictable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)
return sockfile


def _load_from_socket(sock_info, serializer):
sockfile = _create_local_socket(sock_info)
# The socket will be automatically closed when garbage-collected.
return serializer.load_stream(sockfile)


def _local_iterator_from_socket(sock_info, serializer):

class PyLocalIterable(object):
""" Create a synchronous local iterable over a socket """

def __init__(self, _sock_info, _serializer):
self._sockfile = _create_local_socket(_sock_info)
self._serializer = _serializer
self._read_iter = iter([]) # Initialize as empty iterator
self._read_status = 1

def __iter__(self):
while self._read_status == 1:
# Request next partition data from Java
write_int(1, self._sockfile)
self._sockfile.flush()

# If response is 1 then there is a partition to read, if 0 then fully consumed
self._read_status = read_int(self._sockfile)
if self._read_status == 1:

# Load the partition data as a stream and read each item
self._read_iter = self._serializer.load_stream(self._sockfile)
for item in self._read_iter:
yield item

# An error occurred, read error message and raise it
elif self._read_status == -1:
error_msg = UTF8Deserializer().loads(self._sockfile)
raise RuntimeError("An error occurred while reading the next element from "
"toLocalIterator: {}".format(error_msg))

def __del__(self):
# If local iterator is not fully consumed,
if self._read_status == 1:
try:
# Finish consuming partition data stream
for _ in self._read_iter:
pass
# Tell Java to stop sending data and close connection
write_int(0, self._sockfile)
self._sockfile.flush()
except Exception:
# Ignore any errors, socket is automatically closed when garbage-collected
pass

return iter(PyLocalIterable(sock_info, serializer))


def ignore_unicode_prefix(f):
"""
Ignore the 'u' prefix of string in doc tests, to make it works
@@ -2382,7 +2436,7 @@ def toLocalIterator(self):
"""
with SCCallSiteSync(self.context) as css:
sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
return _load_from_socket(sock_info, self._jrdd_deserializer)
return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)

def barrier(self):
"""
4 changes: 2 additions & 2 deletions python/pyspark/sql/dataframe.py
@@ -28,7 +28,7 @@
import warnings

from pyspark import copy_func, since, _NoValue
from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
from pyspark.rdd import RDD, _load_from_socket, _local_iterator_from_socket, ignore_unicode_prefix
from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
UTF8Deserializer
from pyspark.storagelevel import StorageLevel
@@ -528,7 +528,7 @@ def toLocalIterator(self):
"""
with SCCallSiteSync(self._sc) as css:
sock_info = self._jdf.toPythonIterator()
return _load_from_socket(sock_info, BatchedSerializer(PickleSerializer()))
return _local_iterator_from_socket(sock_info, BatchedSerializer(PickleSerializer()))

@ignore_unicode_prefix
@since(1.3)
28 changes: 28 additions & 0 deletions python/pyspark/sql/tests/test_dataframe.py
@@ -677,6 +677,34 @@ def test_repr_behaviors(self):
self.assertEquals(None, df._repr_html_())
self.assertEquals(expected, df.__repr__())

def test_to_local_iterator(self):
df = self.spark.range(8, numPartitions=4)
expected = df.collect()
it = df.toLocalIterator()
self.assertEqual(expected, list(it))

# Test DataFrame with empty partition
df = self.spark.range(3, numPartitions=4)
it = df.toLocalIterator()
expected = df.collect()
self.assertEqual(expected, list(it))

def test_to_local_iterator_not_fully_consumed(self):
# SPARK-23961: toLocalIterator throws exception when not fully consumed
# Create a DataFrame large enough so that write to socket will eventually block
df = self.spark.range(1 << 20, numPartitions=2)
it = df.toLocalIterator()
self.assertEqual(df.take(1)[0], next(it))
with QuietTest(self.sc):
it = None # remove iterator from scope, socket is closed when cleaned up
# Make sure normal df operations still work
result = []
for i, row in enumerate(df.toLocalIterator()):
result.append(row)
if i == 7:
break
self.assertEqual(df.take(8), result)


class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
# These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
36 changes: 32 additions & 4 deletions python/pyspark/tests/test_rdd.py
@@ -32,6 +32,9 @@
xrange = range


global_func = lambda: "Hi"


class RDDTests(ReusedPySparkTestCase):

def test_range(self):
@@ -57,15 +60,12 @@ def test_sum(self):
self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())

def test_to_localiterator(self):
from time import sleep
rdd = self.sc.parallelize([1, 2, 3])
it = rdd.toLocalIterator()
sleep(5)
self.assertEqual([1, 2, 3], sorted(it))

rdd2 = rdd.repartition(1000)
it2 = rdd2.toLocalIterator()
sleep(5)
self.assertEqual([1, 2, 3], sorted(it2))

def test_save_as_textfile_with_unicode(self):
@@ -605,7 +605,7 @@ def test_distinct(self):

def test_external_group_by_key(self):
self.sc._conf.set("spark.python.worker.memory", "1m")
N = 200001
N = 2000001
kv = self.sc.parallelize(xrange(N)).map(lambda x: (x % 3, x))
gkv = kv.groupByKey().cache()
self.assertEqual(3, gkv.count())
@@ -726,6 +726,34 @@ def stopit(*x):
self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
seq_rdd.aggregate, 0, lambda *x: 1, stopit)

def test_to_local_iterator_failure(self):
# SPARK-27548 toLocalIterator task failure not propagated to Python driver

def fail(_):
raise RuntimeError("local iterator error")

rdd = self.sc.range(10).map(fail)

with self.assertRaisesRegexp(Exception, "local iterator error"):
for _ in rdd.toLocalIterator():
pass

def test_to_local_iterator_collects_single_partition(self):
# Test that partitions are not computed until requested by iteration

def fail_last(x):
if x == 9:
raise RuntimeError("This should not be hit")
return x

rdd = self.sc.range(12, numSlices=4).map(fail_last)
it = rdd.toLocalIterator()

# Only consume first 4 elements from partitions 1 and 2, this should not collect the last
# partition which would trigger the error
for i in range(4):
self.assertEqual(i, next(it))


if __name__ == "__main__":
import unittest
