Fix the bug where DataLoaderDispatcher gets stuck in an infinite wait when the dataset is an IterDataPipe during multi-process training. (#1709)

Co-authored-by: YU Xinyuan <yuxinyuan02@corp.netease.com>
yuxinyuan and YU Xinyuan authored Jul 12, 2023
1 parent 65b5c2c commit 518c206
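
For context, the infinite wait comes from a collective operation that PyTorch's DataLoader performs while building an iterator over a DataPipe (the shared-seed broadcast mentioned in the new in-code NOTE further down). A minimal, hypothetical sketch of the failure pattern, assuming an already-initialized torch.distributed process group; the real PyTorch internals differ in names and details:

import torch
import torch.distributed as dist

def create_iterator_like(rank: int) -> int:
    # Stand-in for DataLoader.__iter__ on a DataPipe dataset: it performs a
    # shared-seed broadcast, a collective that every rank must enter.
    seed = torch.empty(1, dtype=torch.int64)
    if rank == 0:
        seed.random_()
    dist.broadcast(seed, src=0)  # blocks until ALL ranks reach this call
    return int(seed.item())

def pre_fix_pattern(rank: int):
    # Before this commit, only process 0 created the iterator, so process 0
    # entered the broadcast alone and waited forever when world_size > 1.
    if rank == 0:
        create_iterator_like(rank)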
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions src/accelerate/data_loader.py
@@ -495,10 +495,6 @@ def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_l
         self.state = AcceleratorState()
         self._drop_last = _drop_last
         self.skip_batches = skip_batches
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
 
     def _fetch_batches(self, iterator):
         batches, batch = None, None
@@ -540,10 +536,14 @@ def _fetch_batches(self, iterator):
     def __iter__(self):
         self.reset()
         self.gradient_state._add_dataloader(self)
-        main_iterator = None
-        if self.state.process_index == 0:
-            # We only iterate through the DataLoader on process 0.
-            main_iterator = super().__iter__()
+        # We can safely pass because the default is -1
+        with suppress(Exception):
+            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+            self.remainder = length % self.total_batch_size
+        # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
+        # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
+        # But, we only iterate through the DataLoader on process 0.
+        main_iterator = super().__iter__()
         stop_iteration = False
         self._stop_iteration = False
         first_batch = None
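
The length probe that moved into __iter__ stays wrapped in suppress(Exception) because an IterDataPipe often cannot report a length, in which case the remainder keeps its default of -1. A small sketch of that fallback (MyPipe and the standalone variable names are hypothetical; the attribute lookup mirrors the diff):

from contextlib import suppress
from torch.utils.data import IterableDataset

class MyPipe(IterableDataset):
    # Deliberately defines no __len__, like many iterable pipes.
    def __iter__(self):
        yield from range(10)

dataset = MyPipe()
total_batch_size = 4
remainder = -1  # default when the length cannot be determined
with suppress(Exception):
    # len(dataset) is evaluated eagerly as the getattr default, so any failure
    # (here: TypeError, since MyPipe has no __len__) is swallowed and remainder stays -1.
    length = getattr(dataset, "total_dataset_length", len(dataset))
    remainder = length % total_batch_size
print(remainder)  # -1 for this pipe; a sized dataset would give length % total_batch_size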
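The NOTE above is the heart of the fix: every process now calls super().__iter__() so the DataPipe seed broadcast can complete, while process 0 remains the only one that actually pulls batches and hands the other processes their share. A rough, hedged sketch of that shape (simplified; the real class also slices each batch so every process keeps only its own part):

import torch.distributed as dist

def dispatcher_style_iter(dataloader, rank: int):
    it = iter(dataloader)              # created on ALL ranks: lets the collective finish
    while True:
        payload = [None]
        if rank == 0:
            try:
                payload = [next(it)]   # only rank 0 reads from the underlying loader
            except StopIteration:
                payload = [None]
        # Rank 0 shares the batch (or the stop signal) with every other rank.
        dist.broadcast_object_list(payload, src=0)
        if payload[0] is None:
            break
        yield payload[0]               # the real class yields each rank's slice of this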
