
[SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker #24070

Conversation

BryanCutler
Member

@BryanCutler BryanCutler commented Mar 12, 2019

What changes were proposed in this pull request?

This fixes an error when a PySpark local iterator, for both RDD and DataFrames, goes out of scope and the connection is closed before fully consuming the iterator. The error occurs on the JVM in the serving thread, when Python closes the local socket while the JVM is writing to it. This usually happens when there is enough data to fill the socket read buffer, causing the write call to block.

Additionally, this fixes a problem when an error occurs in the Python worker and the collect job is cancelled with an exception. Previously, the Python driver was never notified of the error so the user could get a partial result (iteration until the error) and the application will continue. With this change, an error in the worker is sent to the Python iterator and is then raised.

The change here introduces a protocol for PySpark local iterators that works as follows:

  1. The local socket connection is made when the iterator is created
  2. When iterating, Python first sends a request for partition data as a non-zero integer
  3. While the JVM local iterator over partitions has next, it triggers a job to collect the next partition
  4. The JVM sends a nonzero response to indicate it has the next partition to send
  5. The next partition is sent to Python and read by the PySpark deserializer
  6. After sending the entire partition, an END_OF_DATA_SECTION is sent to Python, which stops the deserializer and allows Python to make another request
  7. When the JVM gets a request from Python but has already consumed its local iterator, it will send a zero response to Python and both will close the socket cleanly
  8. If an error occurs in the worker, a negative response is sent to Python followed by the error message. Python will then raise a RuntimeError with the message, stopping iteration.
  9. When the PySpark local iterator is garbage-collected, it will read any remaining data from the current partition (this is data that has already been collected) and send a request of zero to tell the JVM to stop collection jobs and close the connection.

Steps 1, 3, 5, 6 are the same as before. Step 8 was completely missing before because errors in the worker were never communicated back to Python. The other steps add synchronization to allow for a clean closing of the socket, with a small trade-off in performance for each partition. This is mainly because the JVM does not start collecting partition data until it receives a request to do so, where before it would eagerly write all data until the socket receive buffer is full.
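To make the handshake concrete, here is a minimal Python-side sketch of one round of the protocol described above. It is illustrative only: the helper names (`_read_int`, `_write_int`, `_poll_next_partition`) and the length-prefixed error framing are assumptions, not the exact code in this PR.

```python
import struct


def _read_int(stream):
    # Read a big-endian 4-byte signed integer from the socket file.
    return struct.unpack("!i", stream.read(4))[0]


def _write_int(stream, value):
    stream.write(struct.pack("!i", value))


def _poll_next_partition(sockfile, serializer):
    """One request/response round of the local-iterator protocol (sketch)."""
    # Step 2: request the next partition with a non-zero integer.
    _write_int(sockfile, 1)
    sockfile.flush()

    # Steps 4/7/8: read the JVM's response code.
    response = _read_int(sockfile)
    if response == 0:
        # Step 7: the JVM iterator is exhausted; both sides close cleanly.
        return None
    if response < 0:
        # Step 8: a negative response is followed by the error message
        # (modeled here as a length-prefixed UTF-8 string).
        length = _read_int(sockfile)
        raise RuntimeError(sockfile.read(length).decode("utf-8"))

    # Steps 5/6: deserialize rows until END_OF_DATA_SECTION ends the stream.
    return serializer.load_stream(sockfile)
```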

How was this patch tested?

Added new unit tests for DataFrame and RDD toLocalIterator and tested not fully consuming the iterator. Manual tests with Python 2.7 and 3.6.
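For reference, a partial-consumption test along these lines could look like the sketch below. The class and fixture names are illustrative, not the exact tests added by this PR; it only shows the scenario being exercised (consume a few rows, then let the iterator go out of scope).

```python
import unittest

from pyspark.sql import SparkSession


class ToLocalIteratorPartialConsumptionTest(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.master("local[2]").getOrCreate()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

    def test_partial_consumption_closes_cleanly(self):
        rdd = self.spark.sparkContext.range(100, numSlices=10)
        it = rdd.toLocalIterator()
        # Take only the first few elements; the iterator is then discarded,
        # which should close the connection cleanly instead of raising the
        # old JVM-side socket error.
        taken = [next(it) for _ in range(5)]
        self.assertEqual(taken, [0, 1, 2, 3, 4])


if __name__ == "__main__":
    unittest.main()
```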

@SparkQA

SparkQA commented Mar 12, 2019

Test build #103371 has finished for PR 24070 at commit b92a4e6.

  • This patch fails to build.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler BryanCutler force-pushed the pyspark-toLocalIterator-clean-stop-SPARK-23961 branch from b92a4e6 to d02d341 Compare March 12, 2019 18:34
@BryanCutler
Member Author

Timings for DataFrame.toLocalIterator and RDD.toLocalIterator

These tests illustrate the slowdown caused by this change, comparing current master with this PR. Wall-clock time to fully consume the local iterator is measured, and the average of 5 runs is shown (in seconds):

|         | DataFrame   | RDD         |
|---------|-------------|-------------|
| master  | 10.26016583 | 4.354181528 |
| this PR | 12.14033799 | 3.823320436 |

Test Script

import time
from pyspark.sql import SparkSession

spark = SparkSession\
        .builder\
        .appName("toLocalIterator_timing")\
        .getOrCreate()

num = 1 << 22
numParts = 32

def run(df):
  print("Starting iterator:")
  start = time.time()

  count = 0
  for row in df.toLocalIterator():
    count += 1

  if count != num:
    raise RuntimeError("Expected {} but got {}".format(num, count))

  elapsed = time.time() - start
  print("completed in {}".format(elapsed))

run(spark.range(num, numPartitions=numParts))
run(spark.sparkContext.range(num, numSlices=numParts))

spark.stop()

@BryanCutler
Member Author

BryanCutler commented Mar 12, 2019

I just want to highlight that the error this fixes only kills the serving thread, and Spark can continue normal operation, although the error is pretty ugly and would lead users to think that something went terribly wrong. Since it's pretty common not to fully consume an iterator, e.g. when taking a slice, I believe it is worth making this change.

It is also possible that this change is quite beneficial: if the iterator is not fully consumed, it can avoid triggering unneeded jobs, whereas the previous behavior eagerly queued jobs for all partitions. In this sense, the change here more closely follows the Scala behavior.

I'm also not entirely sure why I'm seeing a speedup for the RDD toLocalIterator. When using 8 partitions instead of 32, I noticed a slowdown. I will try to run some more tests.


// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
  rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
Member Author

Note: this is the same function as collectPartition(p: Int) in Scala, except here we do not want to flatten the collected arrays
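For intuition, a rough PySpark analogue of that per-partition collection (illustrative only; this is not the code in the PR, and the example RDD is made up):

```python
from pyspark import SparkContext

sc = SparkContext.getOrCreate()
rdd = sc.parallelize(range(10), 4)

# Run one job per partition; each job returns that partition's elements as a
# single list, and the lists are kept separate rather than flattened.
per_partition = [
    sc.runJob(rdd, lambda it: [list(it)], partitions=[i])[0]
    for i in range(rdd.getNumPartitions())
]
print(per_partition)  # e.g. [[0, 1], [2, 3, 4], [5, 6], [7, 8, 9]]
```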

Contributor

For performance, as mentioned in my questions, would it make sense to use something like an iterator with a lookahead of, say, 1 partition (or X% of partitions) so we decrease the blocking time?

Member Author

Yeah, that would have better performance, but it does say in the doc that max memory usage will be the largest partition. Going over that might cause problems for some people, no?

Contributor

That's a good point @BryanCutler, we could implement the lookahead as a separate PR/JIRA and allow it to be turned on/off. I'd suggest this PR is more about fixing the behaviour of toLocalIterator memory-wise than the out-of-scope issue (although the out-of-scope issue is maybe more visible in the logs).
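As a rough sketch of the lookahead idea being discussed (not part of this PR), a single background thread could keep at most one extra partition buffered while the current one is consumed; the helper name and structure here are assumptions:

```python
import queue
import threading


def prefetch_one_partition(partition_iters):
    """Wrap an iterator of collected partitions with a lookahead of 1.

    `partition_iters` yields one fully collected partition at a time; a
    background thread keeps at most one extra partition buffered, trading
    bounded extra memory for less blocking between partitions.
    """
    buf = queue.Queue(maxsize=1)
    done = object()

    def producer():
        for part in partition_iters:
            buf.put(part)  # blocks while one partition is already buffered
        buf.put(done)

    threading.Thread(target=producer, daemon=True).start()

    while True:
        part = buf.get()
        if part is done:
            return
        for row in part:
            yield row
```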

python/pyspark/rdd.py (outdated, resolved)
    result.append(row)
    if i == 7:
        break
self.assertEqual(df.take(8), result)
Member Author

Note: this would not have crashed before, it only generated the error, and I don't think it's possible to check whether this error happened.

Contributor

Is the error a JVM error? If so, we could grab the stderr/stdout and look for the error message there.

Member Author

Yeah, that might work. I can give it a try.

Contributor

Did anything come of this? It's optional so don't block on it.

Member Author

Oh right, I forgot about this. Let me give it a shot now.

Member Author

It ends up being a little complicated; maybe better to try it as a follow-up.

rdd = self.sc.parallelize([1, 2, 3])
it = rdd.toLocalIterator()
sleep(5)
Member Author

This sleep is unnecessary. rdd.toLocalIterator makes the socket connection, and iterating starts reading from the PySpark serializer.

Contributor

I like less sleeps in the code <3

self.assertEqual([1, 2, 3], sorted(it))

rdd2 = rdd.repartition(1000)
it2 = rdd2.toLocalIterator()
sleep(5)
Member Author

ditto

@BryanCutler
Member Author

Ping @HyukjinKwon @holdenk @viirya for thoughts on this, thanks!

@SparkQA

SparkQA commented Mar 12, 2019

Test build #103382 has finished for PR 24070 at commit d02d341.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler
Member Author

gentle ping @HyukjinKwon @ueshin for thoughts on this fix

""" Create a synchronous local iterable over a socket """

def __init__(self, sock_info, serializer):
(self.sockfile, self.sock) = _create_local_socket(sock_info)
Member

Do we need to store sock?

Member Author

ah no, that was leftover from a previous revision. I'll remove that and clean up.

@@ -2386,7 +2425,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(sock_info, self._jrdd_deserializer)
+        return iter(_PyLocalIterable(sock_info, self._jrdd_deserializer))
Member

How about making a method instead of exposing _PyLocalIterable?

Member Author

Yeah, good idea, thanks!
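The suggested shape, roughly: hide the iterable class inside a module-level helper so callers only ever see a plain iterator. This is a sketch only; the helper name is illustrative, the protocol handling is elided, and `_create_local_socket` is the module-private helper shown elsewhere in this diff.

```python
def _local_iterator_from_socket(sock_info, serializer):
    """Return a plain iterator over the served partitions (sketch)."""

    class PyLocalIterable(object):
        """ Create a synchronous local iterable over a socket """

        def __init__(self, _sock_info, _serializer):
            self._sockfile = _create_local_socket(_sock_info)
            self._serializer = _serializer

        def __iter__(self):
            # Request and deserialize one partition at a time, following the
            # protocol steps in the PR description (elided in this sketch).
            return iter(())

    return iter(PyLocalIterable(sock_info, serializer))
```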

@BryanCutler
Member Author

Thanks @ueshin , I updated the code. I also made a change for the JVM to send a response to indicate if there are more partitions to read instead of relying on closing the socket and catching an error, which could have masked real errors. This change did not have any significant difference to the timings above. Please take another look when you can.

@SparkQA

SparkQA commented Mar 30, 2019

Test build #104094 has finished for PR 24070 at commit 35e8730.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds the following public classes (experimental):
  • class PyLocalIterable(object):

@holdenk
Contributor

holdenk commented Apr 1, 2019

So before I jump into the details of this PR, from the design side I've got some questions I'd like to try and understand better:

  1. This description mentions that it may save "unneeded work" by not eagerly queueing jobs that aren't needed, but in the situation where the jobs are needed it seems like we're adding unnecessary synchronization that could slow this down.
  • Have you tried the benchmarks with a slow per-partition computation to see if there is more impact there?
  • Is this slowdown "worth it"? 20% on DataFrame is nothing to walk away from if we can avoid it.
  • If we're worried about the driver program falling over from buffering too much data in the JVM, can we maybe queue the next partition while we are serving the current one to decrease the overhead?
  2. Are we solving this problem on the "right" side? Is it common for the local iterator to go out of scope before being consumed?
  • If it is a relatively infrequent occurrence, would it possibly make sense to just make the Java code accept that the Python socket may go away? I think that might keep the "happy" (or more frequent) path at reasonable performance. On the other hand, if folks are using toLocalIterator to partially consume the data and this is the expected path, then using exception handling might not make sense.

I'm going to dig into the code, but I'd love to get a better understanding of the design as well.

Contributor

@holdenk holdenk left a comment

Thanks for working on this issue. I'd love to get a better understanding of the design trade-offs, but I'm really excited to see this get fixed and also excited to see fewer sleep(5) calls in the code base :) :p


// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
  rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
Contributor

For performance, as mentioned in my questions, would it make sense to use something like an iterator with a lookahead of, say, 1 partition (or X% of partitions) so we decrease the blocking time?

        # Finish consuming partition data stream
        for _ in self._read_iter:
            pass
        # Tell Java to stop sending data and close connection
Contributor

Would it make sense to send this message before finishing consuming the data coming in?


// Write the next object and signal end of data for this iteration
val partitionArray = collectPartitionIter.next()
writeIteratorToStream(partitionArray.toIterator, out)
Contributor

Would it make sense to use something we can interrupt here instead, if we get a message from Python (or the socket closes) indicating that we don't need to send more data?

    result.append(row)
    if i == 7:
        break
self.assertEqual(df.take(8), result)
Contributor

Is the error a JVM error? If so, we could grab the stderr/stdout and look for the error message there.



def _load_from_socket(sock_info, serializer):
    sockfile = _create_local_socket(sock_info)
    # The socket will be automatically closed when garbage-collected.
Member

I guess this comment should be moved up into _create_local_socket, too.

@@ -168,7 +168,42 @@ private[spark] object PythonRDD extends Logging {
  }

  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
Member

It is also possible that this change is quite beneficial: if the iterator is not fully consumed, it can avoid triggering unneeded jobs, whereas the previous behavior eagerly queued jobs for all partitions. In this sense, the change here more closely follows the Scala behavior.

Once the local iterator is out of scope on the Python side, will remaining jobs still be triggered after the Scala side can no longer write into the closed connection?

Member Author

No, the remaining jobs are not triggered. The Python iterator finishes consuming the data from the current job, then sends a command for the Scala iterator to stop.

Member

What about the previous behavior? Would it trigger them? It looks like toLocalIterator won't trigger a job if we don't iterate the data on a partition.

Member Author

@BryanCutler BryanCutler Apr 5, 2019

The previous behavior was that the Scala local iterator would advance as long as the write calls to the socket were not blocked. This means that when Python reads a batch (auto-batched elements) from the current partition, it unblocks the Scala write call, which could start a job to collect the next partition.

Member

Once the local iterator on the Python side is out of scope and the iterator is not fully consumed, will it block the write call on the Scala side? It seems to me that it will, so we shouldn't see unneeded jobs triggered after that, should we?

Member Author

The previous behavior was that when the iterator went out of scope, the socket was eventually closed. This created the error on the Scala side and terminated the writing thread, so no more jobs were triggered, but the user saw this error.

@BryanCutler
Copy link
Member Author

BryanCutler commented Apr 5, 2019

Are we solving this problem on the "right" side? Is it common for the local iterator to go out of scope before being consumed?
If it is a relatively infrequent occurrence, would it possibly make sense to just make the Java code accept that the Python socket may go away?

Thanks @holdenk , I totally agree. I started off with a fix that would catch the error on the Java side and just allow the connection to be broken without showing the error. Let me summarize the trade-offs that I noticed:

Catch/Ignore error in Java:

  • No change in the protocol to Python, keeps things the same
  • This is the fastest way to iterate through all elements
  • Ignoring the error might hide an actual error, but this is over a local socket so once the connection is made I'm not sure how likely an error is
  • While Python is still consuming data from one partition, other jobs will be started to collect the next partitions until the write to the socket blocks again. This is nice if all the data is used, but bad if it's not.

Synchronized Protocol (this PR):

  • Allows for the iterator to go out of scope and close the connection cleanly, without needing to catch/ignore errors.
  • Jobs are only triggered when the next element is needed. So only after the last element of a partition is consumed, the next partition is collected. This matches the behavior of Scala toLocalIterator.
  • Performance is a little slower because of synchronization and no longer eagerly collecting the next partition
  • Makes the communication between Python and the JVM more complicated

I went with the second option here because I definitely didn't want to mask potentially real errors, and triggering jobs only for data that is requested matches Scala and is good for some use cases. Also, while we don't want performance to suck, I think people would use toLocalIterator to work with constrained memory (1 partition at a time) and do a regular collect if they want good performance.

I'm not completely sure this is the best fix, so I'm happy to discuss :)

Btw, from the JIRA for this, the user was calling itertools.islice(..) to take a slice of the local iterator, which wasn't fully consuming it. I would agree it is probably more common to consume the entire iterator though.
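That scenario looks roughly like the snippet below (illustrative; the DataFrame is made up): slicing the local iterator leaves it partially consumed, which previously produced the ugly JVM-side error once the iterator was garbage-collected.

```python
import itertools

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(1 << 20)

# Take only the first 10 rows; the remaining partitions are never requested,
# and discarding the iterator now closes the connection cleanly.
first_ten = list(itertools.islice(df.toLocalIterator(), 10))
```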

@holdenk
Copy link
Contributor

holdenk commented Apr 15, 2019

Thanks for the additional context @BryanCutler that really helps.

I think supporting memory-constrained consumption of a large dataset is core to the goal of toLocalIterator, so while there is a performance penalty from this change, and we can work to minimize it, I think it's the right thing to do.

I agree with you; solving it on the Python side seems like the right set of trade-offs. I think it might make sense to kick off the job for the next partition (e.g. a lookahead of 1), but we should totally do that in a follow-up PR/JIRA as an optimization. What do you think?

Contributor

@holdenk holdenk left a comment

Thanks for the additional context, I agree with the design I think it is the correct set of trade-offs. I think it would be good to make a follow up JIRA for trying to improve the performance though if you agree.


// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
  rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
Contributor

That's a good point @BryanCutler, we could implement the lookahead as a separate PR/JIRA and allow it to be turned on/off. I'd suggest this PR is more about fixing the behaviour of toLocalIterator memory-wise than the out-of-scope issue (although the out-of-scope issue is maybe more visible in the logs).

@BryanCutler
Member Author

Thanks @holdenk ! I think having a lookahead of 1 partition makes sense as an option, and that should bring the performance back up to where it was before. Sounds good to me to do this as a followup also. So do you think this is ok to merge? Let me rebase and test once more.

@SparkQA

SparkQA commented Apr 19, 2019

Test build #104755 has finished for PR 24070 at commit e745595.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@holdenk
Contributor

holdenk commented Apr 19, 2019

Provided the AppVeyor failure is unrelated, I think it's OK to merge. It would be nice to have a more explicit test for this, but capturing the JVM output isn't something we do elsewhere, so I'm OK with putting that off (although I'd like to see a JIRA for it, maybe?). Thanks for working on this -- if you do make the follow-up issues, do CC me on them :)

@BryanCutler
Member Author

It would be nice to have a more explicit test for this, but capturing the JVM output isn't something we do elsewhere, so I'm OK with putting that off (although I'd like to see a JIRA for it, maybe?)

Thanks @holdenk ! I was close to getting a better test that captures JVM output using launch_gateway with popen_kwargs, but there are a couple of tricky things about doing it this way, so I think it would be better to try it as a followup where we can discuss more.

@BryanCutler
Member Author

test this please

@SparkQA

SparkQA commented Apr 23, 2019

Test build #104817 has finished for PR 24070 at commit e745595.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@holdenk
Contributor

holdenk commented Apr 23, 2019

Sounds good @BryanCutler, if you're at Spark Summit this week we can catch up in person about the difficulty with the tests; otherwise happy to chat whenever.

@holdenk
Contributor

holdenk commented Apr 23, 2019

I have another idea for how we can verify this behaviour change: parallelize a collection with 2 partitions, do a map which throws on the last element, then do a toLocalIterator; in the old version it should fail, but in the new version it should not. What do you think about that as a test?
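A hedged sketch of that test idea (assuming self.sc is the test SparkContext, as in the other snippets here; the exact method name and assertions are illustrative):

```python
def test_to_local_iterator_does_not_collect_failing_partition(self):
    def fail_on_last(x):
        if x == 3:
            raise RuntimeError("intentional failure on the last element")
        return x

    # Two partitions: [0, 1] and [2, 3]; only the second partition fails.
    rdd = self.sc.parallelize([0, 1, 2, 3], 2).map(fail_on_last)
    it = rdd.toLocalIterator()

    # Consume only the first partition. With this PR the failing partition is
    # never requested, so no error should surface; previously the eager
    # collection of later partitions could trigger the failure.
    self.assertEqual([next(it), next(it)], [0, 1])
```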

@BryanCutler BryanCutler force-pushed the pyspark-toLocalIterator-clean-stop-SPARK-23961 branch from dee0dfc to 29b8ab6 Compare April 30, 2019 01:00
@SparkQA

SparkQA commented Apr 30, 2019

Test build #105013 has finished for PR 24070 at commit dee0dfc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 30, 2019

Test build #105012 has finished for PR 24070 at commit 1e0e5f1.

  • This patch passes all tests.
  • This patch does not merge cleanly.
  • This patch adds no public classes.

@SparkQA

SparkQA commented Apr 30, 2019

Test build #105016 has finished for PR 24070 at commit 29b8ab6.

  • This patch fails due to an unknown error code, -9.
  • This patch merges cleanly.
  • This patch adds no public classes.

@viirya
Member

viirya commented Apr 30, 2019

retest this please.

@SparkQA

SparkQA commented Apr 30, 2019

Test build #105025 has finished for PR 24070 at commit 29b8ab6.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler BryanCutler changed the title [SPARK-23961][PYTHON] Fix error when toLocalIterator goes out of scope [SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope Apr 30, 2019
@BryanCutler BryanCutler changed the title [SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope [SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker Apr 30, 2019
@BryanCutler
Member Author

I added the fix for https://issues.apache.org/jira/browse/SPARK-27548 here and updated the PR description. Basically when an error occurs on a worker, the collect job fails in the JVM and the error is sent to the Python iterable, which raises the error. If you could please have another look @holdenk @viirya @ueshin @HyukjinKwon and see if you agree with the approach. Thanks!
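In user code, that means a worker failure now surfaces during iteration instead of silently truncating the results; roughly (illustrative example, with a made-up DataFrame):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.range(10)

rows_seen = 0
try:
    for row in df.toLocalIterator():
        rows_seen += 1  # user logic would go here
except RuntimeError as e:
    # With this change, a failure in the Python worker during collection is
    # propagated here rather than ending the iteration with a partial result.
    print("collect job failed: {}".format(e))
```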

Contributor

@holdenk holdenk left a comment

LGTM pending others' comments of course, although if no one else has anything to say let's go with it :)

  out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
  out.flush()
} catch {
  case e: SparkException =>
Contributor

I think that's reasonable +1

@BryanCutler
Member Author

Thanks @holdenk ! I will merge this after the current tests pass if there are no further comments.

@SparkQA

SparkQA commented May 6, 2019

Test build #105174 has finished for PR 24070 at commit 4f842dc.

  • This patch fails Spark unit tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler
Member Author

retest this please

@SparkQA

SparkQA commented May 7, 2019

Test build #105178 has finished for PR 24070 at commit 4f842dc.

  • This patch passes all tests.
  • This patch merges cleanly.
  • This patch adds no public classes.

@BryanCutler BryanCutler deleted the pyspark-toLocalIterator-clean-stop-SPARK-23961 branch May 7, 2019 21:52
@BryanCutler
Member Author

merged to master, thanks all for reviewing

@HyukjinKwon
Member

Sorry for my late response. I didn't check super closely but looks a-okay to me too.

@BryanCutler
Member Author

I made https://issues.apache.org/jira/browse/SPARK-27660 to explore prefetching data

emanuelebardelli pushed a commit to emanuelebardelli/spark that referenced this pull request Jun 15, 2019
…arrow enabled

## What changes were proposed in this pull request?
Similar to apache#24070, we now propagate SparkExceptions that are encountered during the collect in the java process to the python process.

Fixes https://jira.apache.org/jira/browse/SPARK-27805

## How was this patch tested?
Added a new unit test

Closes apache#24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.

Authored-by: David Vogelbacher <dvogelbacher@palantir.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
dvogelbacher added a commit to palantir/spark that referenced this pull request Nov 25, 2019
…arrow enabled

rshkv pushed a commit to palantir/spark that referenced this pull request Jun 4, 2020
…s out of scope and properly raise errors from worker

rshkv pushed a commit to palantir/spark that referenced this pull request Jun 5, 2020
…s out of scope and properly raise errors from worker
