
Fix the bug where DataLoaderDispatcher gets stuck in an infinite wait… #1709

Merged
merged 1 commit into huggingface:main on Jul 12, 2023

Conversation

yuxinyuan
Contributor

Fix the bug where DataLoaderDispatcher gets stuck in an infinite wait when the dataset is an IterDataPipe during multi-process training.

In newer versions of PyTorch, the DataLoader iterator tries to broadcast a shared seed across all distributed processes (see here). In the current implementation of DataLoaderDispatcher, the iterator is only created in the main process, which causes training to hang when the dataset is an IterDataPipe.

One can try out the script below to see the effect.

import accelerate
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.datapipes.iter import IterableWrapper


def print_rank_by_rank(*args, **kwargs):
    # Print from each process in turn, synchronizing between ranks so the
    # output is not interleaved.
    for rank in range(accelerator.num_processes):
        if rank == accelerator.process_index:
            print(*args, **kwargs)
        accelerator.wait_for_everyone()


if __name__ == "__main__":
    accelerator = accelerate.Accelerator()
    accelerator.print(accelerate.__version__)

    dp = IterableWrapper(range(21)).shuffle().map(lambda x: x + 1)

    loader = DataLoader(dp, batch_size=4, shuffle=True)
    # With an IterDataPipe dataset, prepare() wraps this in a DataLoaderDispatcher
    # (its type is printed below); iterating it then hangs under a multi-process launch.
    loader = accelerator.prepare(loader)

    accelerator.print(type(loader))

    print_rank_by_rank(list(loader))
    print_rank_by_rank(list(loader))
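
To make the failure mode concrete, here is a minimal sketch using plain torch.distributed rather than the accelerate internals (it assumes dist.init_process_group has already been called and is only meant to illustrate the pattern): a collective completes only when every rank enters it, which is exactly what goes wrong when the seed broadcast happens on the main process alone.

import torch
import torch.distributed as dist


def buggy_shared_seed(rank: int) -> torch.Tensor:
    # Mirrors the bug: only the main process reaches the broadcast (it happens
    # inside iter(DataLoader) for IterDataPipe datasets on newer PyTorch), so
    # rank 0 waits forever for peers that never join the collective.
    seed = torch.empty((), dtype=torch.int64)
    if rank == 0:
        seed.random_()
        dist.broadcast(seed, src=0)  # blocks indefinitely
    return seed


def fixed_shared_seed(rank: int) -> torch.Tensor:
    # The shape of the fix: every rank participates in the broadcast, so the
    # collective completes and all ranks end up with the same seed.
    seed = torch.empty((), dtype=torch.int64)
    if rank == 0:
        seed.random_()
    dist.broadcast(seed, src=0)
    return seed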


HuggingFaceDocBuilderDev commented Jul 12, 2023

The documentation is not available anymore as the PR was closed or merged.

Comment on lines -498 to -501
# We can safely pass because the default is -1
with suppress(Exception):
    length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
    self.remainder = length % self.total_batch_size
Collaborator

Is there a reason this has been moved into __iter__ and not __init__?

Contributor Author

I didn't go through the rest of this repo to figure out how remainder is used. However, since self.reset() will always set it to -1, it just makes more sense (looking at data_loader.py) to follow DataLoaderShard and set the remainder in __iter__. Otherwise, it should be safe to remove this code completely.

Collaborator
@muellerzr muellerzr Jul 12, 2023

It's used for gradient accumulation. Here it makes sense to keep it in __init__, as there's no need to calculate it twice.

Collaborator

Agreed

Contributor Author

Well, I might be missing something, but, in __iter__, self.reset() will set self.remainder to -1. So, self.remainder won't be useful once we start iterating through the dataloader/dataset. Is that correct?

Collaborator

Looking more, it is also updated at the end of the iter loop. @muellerzr do we actually need those lines?

But while we investigate, I agree that it's safer to just copy this after the reset.

Collaborator

Hmmm... will look into this today.
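
As an aside for readers following the thread, a toy sketch of the placement question (this is not the accelerate code; only the attribute names mirror the snippet quoted above): if the computation lives in __iter__, it only has an effect when it runs after self.reset(), since reset() wipes the value back to -1.

from contextlib import suppress


class ToyDispatcher:
    # Toy stand-in for the class under discussion; everything except the
    # attribute names is simplified away.
    total_batch_size = 8

    def __init__(self, dataset):
        self.dataset = dataset
        self.remainder = -1

    def reset(self):
        # As noted above, reset() restores the default of -1.
        self.remainder = -1

    def __iter__(self):
        self.reset()
        # Recomputing here, after reset(), keeps the value usable while
        # iterating; computing it before reset() would be overwritten.
        with suppress(Exception):
            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
            self.remainder = length % self.total_batch_size
        yield from self.dataset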

Collaborator
@muellerzr muellerzr left a comment

This looks great, thanks! Left one comment on moving a chunk of code.

@muellerzr muellerzr requested a review from sgugger July 12, 2023 09:56
Collaborator
@sgugger sgugger left a comment

Thanks for your PR! Let's just keep the remainder logic in the init as mentioned by @muellerzr and we should be good to go.
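
For contrast with the toy sketch above, here is the __init__ placement preferred in this review, again with toy names rather than the real accelerate signature: the remainder is computed once at construction time instead of on every call to __iter__.

from contextlib import suppress


class ToyDispatcherInitPlacement:
    # Same toy names as the sketch above; the remainder is computed once when
    # the object is built.
    total_batch_size = 8

    def __init__(self, dataset):
        self.dataset = dataset
        self.remainder = -1
        with suppress(Exception):
            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
            self.remainder = length % self.total_batch_size

    def __iter__(self):
        yield from self.dataset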


@sgugger sgugger merged commit 518c206 into huggingface:main Jul 12, 2023
24 checks passed