[SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker #24070
Conversation
Test build #103371 has finished for PR 24070 at commit
b92a4e6 to d02d341
Timings for DataFrame.toLocalIterator and RDD.toLocalIterator

These tests illustrate the slowdown caused by this change, comparing current master with this change. Wall clock time is measured to fully consume the local iterator, and the average of 5 runs is shown.

Test script:

```python
import time
from pyspark.sql import SparkSession

spark = SparkSession\
    .builder\
    .appName("toLocalIterator_timing")\
    .getOrCreate()

num = 1 << 22
numParts = 32


def run(df):
    print("Starting iterator:")
    start = time.time()
    count = 0
    for row in df.toLocalIterator():
        count += 1
    if count != num:
        raise RuntimeError("Expected {} but got {}".format(num, count))
    elapsed = time.time() - start
    print("completed in {}".format(elapsed))


run(spark.range(num, numPartitions=numParts))
run(spark.sparkContext.range(num, numSlices=numParts))
spark.stop()
```
I just want to highlight that the error that this fixes only kills the serving thread, and Spark can continue normal operation, although the error is pretty ugly and would lead users to think that something went terribly wrong. Since it's pretty common to not fully consume an iterator, e.g. taking a slice, I believe it is worth making this change. It is also possible that this change would be very beneficial because if the iterator is not fully consumed, it could save the triggering of unneeded jobs, whereas the behavior before eagerly queued jobs for all partitions. In this sense, the change here more closely follows the Scala behavior. I'm also not entirely sure why I'm seeing a speedup for the RDD toLocalIterator. When using 8 partitions instead of 32, I noticed a slowdown. I will try to run some more tests.
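For context, the "taking a slice" case looks roughly like the sketch below. This is a minimal illustration assembled for this discussion (the app name and sizes are arbitrary), not code from the PR.

```python
from itertools import islice

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("partial_toLocalIterator").getOrCreate()
df = spark.range(1 << 20, numPartitions=32)

# Consume only the first 10 rows; the local iterator is abandoned afterwards,
# which is exactly the case where the old code logged an ugly socket error
# from the JVM serving thread.
first_ten = list(islice(df.toLocalIterator(), 10))
print(first_ten)

spark.stop()
```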
```scala
// Collects a partition on each iteration
val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
  rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
```
Note: this is the same function as `collectPartition(p: Int)` in Scala, except here we do not want to flatten the collected arrays
For performance, as mentioned in my questions, would it make sense to use something like an iterator with a lookahead of, say, 1 partition (or X% of partitions) so we decrease the blocking time?
Yeah, that would have better performance, but it does say in the doc that max memory usage will be the largest partition. Going over that might cause problems for some people, no?
That's a good point @BryanCutler, we could implement the lookahead as a separate PR/JIRA and allow it to be turned on/off. I'd suggest this PR is more about fixing the behaviour of toLocalIterator memory-wise than the out-of-scope issue (although the out-of-scope issue is maybe more visible in the logs).
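To make the lookahead idea concrete, here is a rough Python sketch of prefetching one partition ahead. It is only an illustration of the option discussed here (prefetching was later split out to SPARK-27660), and `fetch_partition` is a hypothetical stand-in for whatever runs the collect job for a single partition.

```python
from concurrent.futures import ThreadPoolExecutor


def iter_with_lookahead(fetch_partition, num_partitions):
    """Yield rows partition by partition while the next partition is
    collected in the background (holds up to two partitions in memory)."""
    if num_partitions <= 0:
        return
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(fetch_partition, 0)
        for i in range(num_partitions):
            rows = future.result()
            if i + 1 < num_partitions:
                # Start collecting the next partition before yielding this one.
                future = pool.submit(fetch_partition, i + 1)
            for row in rows:
                yield row
```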
```python
    result.append(row)
    if i == 7:
        break
self.assertEqual(df.take(8), result)
```
Note: this would not have crashed before, it only generated the error, and I don't think it's possible to check if this error happened
Is the error a JVM error? If so, we could grab the stderr/stdout and look for the error message in the result there?
Yeah, that might work. I can give it a try..
Did anything come of this? It's optional so don't block on it.
Oh right, I forgot about this. Let me give it a shot now.
It ends up being a little complicated, maybe better to try as a followup
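For the record, one possible shape for such a check, sketched only to capture the idea: run a small PySpark script in a subprocess and scan its stderr for the old socket error. The error text, the script contents, and the assumption that the test environment can launch PySpark this way are all hypothetical, which is roughly the complication mentioned above.

```python
import subprocess
import sys
import textwrap

# Hypothetical child script: build a local iterator and abandon it early.
SCRIPT = textwrap.dedent("""
    from pyspark.sql import SparkSession
    spark = SparkSession.builder.appName("partial_iter_check").getOrCreate()
    it = spark.range(1 << 20, numPartitions=8).toLocalIterator()
    next(it)   # consume only part of the iterator
    del it
    spark.stop()
""")

proc = subprocess.run(
    [sys.executable, "-c", SCRIPT],
    stdout=subprocess.PIPE, stderr=subprocess.PIPE,
    universal_newlines=True, timeout=600)

# "Connection reset" is an assumed marker for the old serving-thread error;
# in local mode the JVM's log output is interleaved with the driver's stderr.
assert "Connection reset" not in proc.stderr, proc.stderr
```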
```python
rdd = self.sc.parallelize([1, 2, 3])
it = rdd.toLocalIterator()
sleep(5)
```
This sleep is unnecessary. `rdd.toLocalIterator` makes the socket connection and iterating starts reading from the pyspark serializer.
I like less sleeps in the code <3
```python
self.assertEqual([1, 2, 3], sorted(it))

rdd2 = rdd.repartition(1000)
it2 = rdd2.toLocalIterator()
sleep(5)
```
ditto
Ping @HyukjinKwon @holdenk @viirya for thoughts on this, thanks!
Test build #103382 has finished for PR 24070 at commit
gentle ping @HyukjinKwon @ueshin for thoughts on this fix
python/pyspark/rdd.py
Outdated
""" Create a synchronous local iterable over a socket """ | ||
|
||
def __init__(self, sock_info, serializer): | ||
(self.sockfile, self.sock) = _create_local_socket(sock_info) |
Do we need to store `sock`?
ah no, that was leftover from a previous revision. I'll remove that and clean up.
python/pyspark/rdd.py
Outdated
```diff
@@ -2386,7 +2425,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(sock_info, self._jrdd_deserializer)
+        return iter(_PyLocalIterable(sock_info, self._jrdd_deserializer))
```
How about making a method instead of exposing `_PyLocalIterable`?
yeah, good idea thanks!
Thanks @ueshin, I updated the code. I also made a change for the JVM to send a response to indicate if there are more partitions to read, instead of relying on closing the socket and catching an error, which could have masked real errors. This change did not make any significant difference to the timings above. Please take another look when you can.
Test build #104094 has finished for PR 24070 at commit
So before I jump into the details of this PR, from the design side I've got some questions I'd like to try and understand better:
I'm going to dig into the code, but I'd love to get a better understanding of the design as well.
Thanks for working on this issue, I'd love to get a better understanding of the design trade-offs, but I'm really excited to see this get fixed and also excited to see less `sleep(5)` in the code base :) :p
python/pyspark/rdd.py
Outdated
```python
# Finish consuming partition data stream
for _ in self._read_iter:
    pass
# Tell Java to stop sending data and close connection
```
Would it make sense to send this message before finishing consuming the data coming in?
```scala
// Write the next object and signal end of data for this iteration
val partitionArray = collectPartitionIter.next()
writeIteratorToStream(partitionArray.toIterator, out)
```
Would it make sense to use something which we can interrupt here instead if we get a message from Python (or if the socket closes) that we don't need to send more data?
```python
def _load_from_socket(sock_info, serializer):
    sockfile = _create_local_socket(sock_info)
    # The socket will be automatically closed when garbage-collected.
```
I guess that this comment should be moved above into `_create_local_socket`, too.
```
@@ -168,7 +168,42 @@ private[spark] object PythonRDD extends Logging {
  }

  def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
```
> It is also possible that this change would be very beneficial because if the iterator is not fully consumed, it could save the triggering of unneeded jobs, whereas the behavior before eagerly queued jobs for all partitions. In this sense, the change here more closely follows the Scala behavior.
Once the local iterator is out of scope on the Python side, will the remaining jobs still be triggered after the Scala side can't write into the closed connection?
No, the remaining jobs are not triggered. The Python iterator finishes consuming the data from the current job, then sends a command for the Scala iterator to stop.
What about the previous behavior? Would it trigger them? It looks like `toLocalIterator` won't trigger the job if we don't iterate over the data on a partition.
The previous behavior was that the Scala local iterator would advance as long as the `write` calls to the socket are not blocked. So this means when Python reads a batch (auto-batched elements) from the current partition, this will unblock the Scala call to write and could start a job to collect the next partition.
Once the local iterator on the Python side is out of scope and the iterator is not fully consumed, will it block the write call on the Scala side? It seems to me that it will, and we shouldn't see unneeded jobs triggered after that, should we?
The previous behavior is that when the iterator goes out of scope, the socket is eventually closed. This creates the error on the Scala side and the writing thread is terminated, so no more jobs are triggered, but the user sees this error.
Thanks @holdenk, I totally agree. I started off with a fix that would catch the error on the Java side and just allow the connection to be broken without showing the error. Let me summarize the trade-offs that I noticed:

Catch/Ignore error in Java:

Synchronized Protocol (this PR):
I went with the second option here because I definitely didn't want to mask potential real errors, and only triggering jobs for data that is requested, matches Scala and is good for some use cases. Also while we don't want performance to suck, I think people would use toLocalIterator to work with constrained memory (1 partition at a time) and do a regular collect if they want good performance. I'm not completely sure this is the best fix, so I'm happy to discuss :) Btw, from the JIRA for this, the user was calling |
Thanks for the additional context @BryanCutler, that really helps. I think supporting memory-constrained consumption of a large dataset is core to the goal of toLocalIterator, so while there is a performance penalty to this change, and we can work to minimize it, I think it's the right thing to do. I agree with you that solving it on the Python side seems like the right set of trade-offs. I think it might make sense to kick off the job for the next partition (e.g. lookahead of 1), but we should totally do that in a follow-up PR/JIRA as an optimization. What do you think?
Thanks for the additional context, I agree with the design, I think it is the correct set of trade-offs. I think it would be good to make a follow-up JIRA for trying to improve the performance though, if you agree.
Thanks @holdenk! I think having a lookahead of 1 partition makes sense as an option, and that should bring the performance back up to where it was before. Sounds good to me to do this as a follow-up also. So do you think this is OK to merge? Let me rebase and test once more.
Test build #104755 has finished for PR 24070 at commit
Provided the AppVeyor failure is unrelated, I think it's OK to merge. It would be nice to have a more explicit test for this, but capturing the JVM output isn't something we do elsewhere, so I'm OK with putting that off (although I'd like to see a JIRA for it maybe?). Thanks for working on this -- if you do make the follow-up issues, do CC me on them :)
Thanks @holdenk ! I was close to getting a better test that captures JVM output using |
test this please
Test build #104817 has finished for PR 24070 at commit
Sounds good @BryanCutler, if you're at Spark Summit this week we can catch up in person about the difficulty with the tests, otherwise happy to chat whenever.
I have another idea on how we can verify this behaviour change: parallelize a collection with 2 partitions, do a map which throws on the last element, then do a toLocalIterator; in the old version it should fail, but in the new version it should not. What do you think about that as a test?
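A rough sketch of that test idea, written against a hypothetical unittest-style PySpark test case with `self.sc` available (names are illustrative, not the test that was eventually added):

```python
def test_partial_local_iterator_does_not_fail(self):
    # The last element lives in the second partition and makes the map throw.
    def fail_on_last(x):
        if x == 3:
            raise RuntimeError("intentional failure on last element")
        return x

    rdd = self.sc.parallelize([1, 2, 3], 2).map(fail_on_last)
    it = rdd.toLocalIterator()

    # Only the first partition is consumed, so with this change no job is ever
    # run for the failing partition; previously the eager background writer
    # would trigger it and surface the failure.
    self.assertEqual(1, next(it))
```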
dee0dfc to 29b8ab6
Test build #105013 has finished for PR 24070 at commit
Test build #105012 has finished for PR 24070 at commit
Test build #105016 has finished for PR 24070 at commit
retest this please.
Test build #105025 has finished for PR 24070 at commit
I added the fix for https://issues.apache.org/jira/browse/SPARK-27548 here and updated the PR description. Basically when an error occurs on a worker, the collect job fails in the JVM and the error is sent to the Python iterable, which raises the error. If you could please have another look @holdenk @viirya @ueshin @HyukjinKwon and see if you agree with the approach. Thanks!
LGTM pending others' comments of course, although if no one else has anything to say let's go with it :)
```scala
  out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
  out.flush()
} catch {
  case e: SparkException =>
```
I think that's reasonable +1
Thanks @holdenk! I will merge this after the current tests pass if there are no further comments.
Test build #105174 has finished for PR 24070 at commit
retest this please
Test build #105178 has finished for PR 24070 at commit
merged to master, thanks all for reviewing
Sorry for my late response. I didn't check super closely but looks a-okay to me too.
I made https://issues.apache.org/jira/browse/SPARK-27660 to explore prefetching data
…arrow enabled

## What changes were proposed in this pull request?

Similar to apache#24070, we now propagate SparkExceptions that are encountered during the collect in the java process to the python process.

Fixes https://jira.apache.org/jira/browse/SPARK-27805

## How was this patch tested?

Added a new unit test

Closes apache#24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.

Authored-by: David Vogelbacher <dvogelbacher@palantir.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
…s out of scope and properly raise errors from worker

This fixes an error when a PySpark local iterator, for both RDD and DataFrames, goes out of scope and the connection is closed before fully consuming the iterator. The error occurs on the JVM in the serving thread, when Python closes the local socket while the JVM is writing to it. This usually happens when there is enough data to fill the socket read buffer, causing the write call to block.

Additionally, this fixes a problem when an error occurs in the Python worker and the collect job is cancelled with an exception. Previously, the Python driver was never notified of the error so the user could get a partial result (iteration until the error) and the application will continue. With this change, an error in the worker is sent to the Python iterator and is then raised.

The change here introduces a protocol for PySpark local iterators that works as follows:

1) The local socket connection is made when the iterator is created
2) When iterating, Python first sends a request for partition data as a non-zero integer
3) While the JVM local iterator over partitions has next, it triggers a job to collect the next partition
4) The JVM sends a nonzero response to indicate it has the next partition to send
5) The next partition is sent to Python and read by the PySpark deserializer
6) After sending the entire partition, an `END_OF_DATA_SECTION` is sent to Python which stops the deserializer and allows to make another request
7) When the JVM gets a request from Python but has already consumed its local iterator, it will send a zero response to Python and both will close the socket cleanly
8) If an error occurs in the worker, a negative response is sent to Python followed by the error message. Python will then raise a RuntimeError with the message, stopping iteration.
9) When the PySpark local iterator is garbage-collected, it will read any remaining data from the current partition (this is data that has already been collected) and send a request of zero to tell the JVM to stop collection jobs and close the connection.

Steps 1, 3, 5, 6 are the same as before. Step 8 was completely missing before because errors in the worker were never communicated back to Python. The other steps add synchronization to allow for a clean closing of the socket, with a small trade-off in performance for each partition. This is mainly because the JVM does not start collecting partition data until it receives a request to do so, where before it would eagerly write all data until the socket receive buffer is full.

Added new unit tests for DataFrame and RDD `toLocalIterator` and tested not fully consuming the iterator. Manual tests with Python 2.7 and 3.6.

Closes apache#24070 from BryanCutler/pyspark-toLocalIterator-clean-stop-SPARK-23961.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>
What changes were proposed in this pull request?
This fixes an error when a PySpark local iterator, for both RDD and DataFrames, goes out of scope and the connection is closed before fully consuming the iterator. The error occurs on the JVM in the serving thread, when Python closes the local socket while the JVM is writing to it. This usually happens when there is enough data to fill the socket read buffer, causing the write call to block.
Additionally, this fixes a problem when an error occurs in the Python worker and the collect job is cancelled with an exception. Previously, the Python driver was never notified of the error so the user could get a partial result (iteration until the error) and the application will continue. With this change, an error in the worker is sent to the Python iterator and is then raised.
The change here introduces a protocol for PySpark local iterators that works as follows:

1) The local socket connection is made when the iterator is created
2) When iterating, Python first sends a request for partition data as a non-zero integer
3) While the JVM local iterator over partitions has next, it triggers a job to collect the next partition
4) The JVM sends a nonzero response to indicate it has the next partition to send
5) The next partition is sent to Python and read by the PySpark deserializer
6) After sending the entire partition, an `END_OF_DATA_SECTION` is sent to Python which stops the deserializer and allows to make another request
7) When the JVM gets a request from Python but has already consumed its local iterator, it will send a zero response to Python and both will close the socket cleanly
8) If an error occurs in the worker, a negative response is sent to Python followed by the error message. Python will then raise a RuntimeError with the message, stopping iteration.
9) When the PySpark local iterator is garbage-collected, it will read any remaining data from the current partition (this is data that has already been collected) and send a request of zero to tell the JVM to stop collection jobs and close the connection.

Steps 1, 3, 5, 6 are the same as before. Step 8 was completely missing before because errors in the worker were never communicated back to Python. The other steps add synchronization to allow for a clean closing of the socket, with a small trade-off in performance for each partition. This is mainly because the JVM does not start collecting partition data until it receives a request to do so, where before it would eagerly write all data until the socket receive buffer is full.
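As an illustration only, the driver-side half of this request/response protocol could be sketched as below; the class, constant values, and framing details are assumptions for readability and are not the actual PySpark internals.

```python
import struct


class LocalIteratorSketch(object):
    """Illustrative driver-side loop for the protocol above."""

    REQUEST_PARTITION = 1   # step 2: ask the JVM for the next partition
    STOP_ITERATION = 0      # step 9: tell the JVM to stop collecting

    def __init__(self, sockfile, serializer):
        self.sockfile = sockfile        # buffered file over the local socket
        self.serializer = serializer    # deserializer for the row stream

    def __iter__(self):
        while True:
            # Step 2: request the next partition.
            self.sockfile.write(struct.pack("!i", self.REQUEST_PARTITION))
            self.sockfile.flush()
            # Steps 4/7/8: read the JVM's integer response.
            response = struct.unpack("!i", self.sockfile.read(4))[0]
            if response == 0:
                return  # JVM iterator exhausted; both sides close cleanly
            if response < 0:
                # Step 8: worker error; read a framed UTF-8 error message.
                length = struct.unpack("!i", self.sockfile.read(4))[0]
                raise RuntimeError(self.sockfile.read(length).decode("utf-8"))
            # Steps 5-6: read rows; the deserializer is assumed to stop at the
            # END_OF_DATA_SECTION marker so another request can be made.
            for row in self.serializer.load_stream(self.sockfile):
                yield row

    def __del__(self):
        # Step 9: on garbage collection tell the JVM to stop running collect
        # jobs (the real implementation also drains the current partition).
        try:
            self.sockfile.write(struct.pack("!i", self.STOP_ITERATION))
            self.sockfile.flush()
        except Exception:
            pass  # connection may already be closed
```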
How was this patch tested?
Added new unit tests for DataFrame and RDD `toLocalIterator` and tested not fully consuming the iterator. Manual tests with Python 2.7 and 3.6.