Fix the bug where `DataLoaderDispatcher` gets stuck in an infinite wait when the dataset is an `IterDataPipe` during multi-process training #1709

Conversation
The documentation is not available anymore as the PR was closed or merged.
```python
# We can safely pass because the default is -1
with suppress(Exception):
    length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
    self.remainder = length % self.total_batch_size
```
Is there a reason this has been moved into `__iter__` and not `__init__`?
I didn't go through the rest of this repo to figure out how `remainder` is used. However, since `self.reset()` will always set it to -1, it makes more sense (looking at `data_loader.py`) to follow `DataLoaderShard` and set the remainder in `__iter__`. Otherwise, it should be safe to remove this code completely.
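Roughly, the interaction I mean looks like this (a minimal sketch of the pattern, not the actual `accelerate` source; `reset()` here is assumed to behave as described above):

```python
from contextlib import suppress

class DataLoaderDispatcher:
    def __init__(self, dataset, total_batch_size):
        self.dataset = dataset
        self.total_batch_size = total_batch_size
        self.reset()  # assumption: reset() runs here too and wipes remainder

    def reset(self):
        # Re-initializes per-epoch iteration state.
        self.remainder = -1

    def __iter__(self):
        self.reset()  # any remainder computed in __init__ is lost here...
        # ...so, as in DataLoaderShard, it has to be (re)computed after the reset:
        with suppress(Exception):
            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
            self.remainder = length % self.total_batch_size
        yield from self.dataset  # stand-in for the real dispatch loop
```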
It's used for gradient accumulation. Here it makes sense to have it in `__init__`, as there's no need to calculate it twice.
Agreed
Well, I might be missing something, but in `__iter__`, `self.reset()` will set `self.remainder` to -1. So `self.remainder` won't be useful once we start iterating through the dataloader/dataset. Is that correct?
Looking more closely, it is also updated at the end of the iteration loop. @muellerzr do we actually need those lines?
While we investigate, I agree that it's safer to just copy this after the reset.
Hmmm... will look into this today.
This looks great, thanks! Left one comment on moving a chunk of code.
Thanks for your PR! Let's just keep the remainder logic in the init as mentioned by @muellerzr and we should be good to go.
```python
# We can safely pass because the default is -1
with suppress(Exception):
    length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
    self.remainder = length % self.total_batch_size
```
Agreed
Fix the bug where `DataLoaderDispatcher` gets stuck in an infinite wait when the dataset is an `IterDataPipe` during multi-process training.

In newer versions of PyTorch, the iterator of `DataLoader` will try to broadcast a shared seed across all distributed processes (see here). In the current implementation of `DataLoaderDispatcher`, the iterator is only created in the main process. This causes the training to hang when the dataset is an `IterDataPipe`.

One can try out the script below to see the effect.
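The original script is not included in this excerpt; what follows is a minimal sketch of the kind of reproduction the description refers to. It assumes torchdata's `IterableWrapper` as the `IterDataPipe` and accelerate's `dispatch_batches=True` (which selects `DataLoaderDispatcher`), and is meant to be launched with multiple processes, e.g. `accelerate launch --num_processes 2 repro.py`.

```python
# Hypothetical reproduction sketch (the PR's actual script is not shown here).
from torch.utils.data import DataLoader
from torchdata.datapipes.iter import IterableWrapper
from accelerate import Accelerator

def main():
    # dispatch_batches=True makes accelerate wrap the loader in DataLoaderDispatcher.
    accelerator = Accelerator(dispatch_batches=True)
    pipe = IterableWrapper(range(16))  # an IterDataPipe dataset
    loader = accelerator.prepare(DataLoader(pipe, batch_size=4))
    # Before the fix, this loop hangs: only the main process creates the inner
    # DataLoader iterator, which blocks broadcasting a shared seed that the
    # other ranks never participate in.
    for batch in loader:
        accelerator.print(batch)

if __name__ == "__main__":
    main()
```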