diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index fd908f1002e..95fceccb191 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -495,10 +495,6 @@ def __init__(self, dataset, split_batches: bool = False, skip_batches=0, _drop_l
         self.state = AcceleratorState()
         self._drop_last = _drop_last
         self.skip_batches = skip_batches
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
 
     def _fetch_batches(self, iterator):
         batches, batch = None, None
@@ -540,10 +536,14 @@ def _fetch_batches(self, iterator):
     def __iter__(self):
         self.reset()
         self.gradient_state._add_dataloader(self)
-        main_iterator = None
-        if self.state.process_index == 0:
-            # We only iterate through the DataLoader on process 0.
-            main_iterator = super().__iter__()
+        # We can safely pass because the default is -1
+        with suppress(Exception):
+            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+            self.remainder = length % self.total_batch_size
+        # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
+        # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
+        # But, we only iterate through the DataLoader on process 0.
+        main_iterator = super().__iter__()
         stop_iteration = False
         self._stop_iteration = False
         first_batch = None
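
For context on why `super().__iter__()` now runs on every rank: when the underlying dataset is a DataPipe, PyTorch's DataLoader performs a collective inside `__iter__` to share an RNG seed across distributed processes, so entering `__iter__` on process 0 alone leaves the other ranks waiting and deadlocks the group. The sketch below is a minimal toy illustration of that failure mode, not the accelerate implementation; `SeedSharingLoader` and `fetch_first_batch` are hypothetical names introduced here for illustration only.

```python
import torch
import torch.distributed as dist


class SeedSharingLoader:
    """Toy loader whose __iter__ performs a collective, mimicking the
    shared-seed broadcast PyTorch's DataLoader does for DataPipes."""

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        if dist.is_available() and dist.is_initialized():
            # Every rank must reach this broadcast, otherwise the group hangs.
            seed = torch.zeros(1, dtype=torch.int64)
            if dist.get_rank() == 0:
                seed = torch.randint(0, 2**31 - 1, (1,), dtype=torch.int64)
            dist.broadcast(seed, src=0)
            torch.manual_seed(int(seed.item()))
        return iter(self.data)


def fetch_first_batch(loader, rank):
    # Mirrors the fix in the diff: every rank creates the iterator (so the
    # collective inside __iter__ completes), but only rank 0 consumes data.
    iterator = iter(loader)      # safe: all ranks participate in the broadcast
    if rank == 0:
        return next(iterator)    # only rank 0 actually pulls batches
    return None
```

Run under `torchrun` with two or more processes: with the old pattern (calling `iter()` on rank 0 only), the broadcast inside `__iter__` never completes on the other ranks, which is why the diff moves iterator creation outside the `process_index == 0` check while still iterating only on process 0.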