diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 426fd01266a..10c56cee3e8 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -367,10 +367,10 @@ def __iter__(self):
             synchronize_rng_states(self.rng_types, self.synchronized_generator)
         self.reset()
         self.gradient_state._add_dataloader(self)
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
+        if self.remainder == -1:
+            with suppress(Exception):
+                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                self.remainder = length % self.total_batch_size
         dataloader_iter = super().__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
@@ -537,10 +537,10 @@ def __iter__(self):
         self.reset()
         self.gradient_state._add_dataloader(self)
         main_iterator = None
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
+        if self.remainder == -1:
+            with suppress(Exception):
+                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                self.remainder = length % self.total_batch_size
         if is_torch_version(">=", "2.0.1"):
             # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
             # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
@@ -595,7 +595,6 @@ def __iter__(self):
 
             if stop_iteration:
                 self.end_of_dataloader = True
-                self.remainder = observed_batch_size
             if batch_index >= self.skip_batches:
                 yield batch
             batch_index += 1
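
For reference, a minimal sketch of the behavior this patch establishes. The class below is a hypothetical stand-in, not the real DataLoaderShard or DataLoaderDispatcher: remainder starts at the sentinel -1, is computed once on the first pass, and the new `if self.remainder == -1:` guard (together with dropping the `self.remainder = observed_batch_size` assignment) keeps later iterations from overwriting the stored value.

from contextlib import suppress


class RemainderSketch:
    # Hypothetical wrapper used only to illustrate the guarded remainder computation.
    def __init__(self, dataset, total_batch_size):
        self.dataset = dataset
        self.total_batch_size = total_batch_size
        self.remainder = -1  # sentinel meaning "not computed yet"

    def __iter__(self):
        # Same shape as the patched code: compute only while the sentinel is in place,
        # so a second pass over the dataloader cannot change the stored remainder.
        if self.remainder == -1:
            with suppress(Exception):
                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
                self.remainder = length % self.total_batch_size
        yield from self.dataset


loader = RemainderSketch(list(range(10)), total_batch_size=4)
list(loader)
print(loader.remainder)  # 2 (10 % 4), computed on the first pass
list(loader)
print(loader.remainder)  # still 2: the guard prevents recomputation on later passes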