Remove duplicate code and check the remainder only once #1717

Merged 2 commits on Jul 12, 2023

Changes from 1 commit
src/accelerate/data_loader.py (17 changes: 8 additions & 9 deletions)
@@ -367,10 +367,10 @@ def __iter__(self):
             synchronize_rng_states(self.rng_types, self.synchronized_generator)
         self.reset()
         self.gradient_state._add_dataloader(self)
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
+        if self.remainder == -1:
+            with suppress(Exception):
+                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                self.remainder = length % self.total_batch_size
         dataloader_iter = super().__iter__()
         # We iterate one batch ahead to check when we are at the end
         try:
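For context, the added guard is the usual lazy-initialization pattern with -1 as a "not yet computed" sentinel: the remainder is worked out the first time `__iter__` runs and simply reused on every later epoch. Below is a minimal, self-contained sketch of that pattern; the `RemainderTracker` class and its constructor arguments are illustrative stand-ins, and only `total_dataset_length`, `total_batch_size`, and the `remainder == -1` sentinel come from the diff above.

```python
from contextlib import suppress


class RemainderTracker:
    """Illustrative stand-in for the dataloader's remainder bookkeeping."""

    def __init__(self, dataset, total_batch_size):
        self.dataset = dataset
        self.total_batch_size = total_batch_size
        self.remainder = -1  # sentinel meaning "not computed yet"

    def __iter__(self):
        # Compute the remainder only on the first epoch; later epochs skip this block.
        if self.remainder == -1:
            with suppress(Exception):
                # Some dataset wrappers expose `total_dataset_length` instead of `__len__`.
                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
                self.remainder = length % self.total_batch_size
        yield from self.dataset


loader = RemainderTracker(list(range(10)), total_batch_size=4)
list(loader)
assert loader.remainder == 2  # 10 % 4 == 2
```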
@@ -536,10 +536,10 @@ def _fetch_batches(self, iterator):
     def __iter__(self):
         self.reset()
         self.gradient_state._add_dataloader(self)
-        # We can safely pass because the default is -1
-        with suppress(Exception):
-            length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
-            self.remainder = length % self.total_batch_size
+        if self.remainder == -1:
+            with suppress(Exception):
+                length = getattr(self.dataset, "total_dataset_length", len(self.dataset))
+                self.remainder = length % self.total_batch_size
         # NOTE PyTorch DataLoader adds forward compatibilities for DataPipes, which broadcasts
         # shared seed to all dist processes. Thus, we need to create iterator for all dist processes.
         # But, we only iterate through the DataLoader on process 0.
@@ -591,7 +591,6 @@ def __iter__(self):
 
             if stop_iteration:
                 self.end_of_dataloader = True
-                self.remainder = observed_batch_size
            if batch_index >= self.skip_batches:
                 yield batch
             batch_index += 1
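Since the remainder is now computed up front from the dataset length, the end-of-iteration reassignment from `observed_batch_size` that this hunk removes was redundant in the common case: the last, short batch seen while iterating has exactly the size of the precomputed remainder. A small arithmetic check of that equivalence (the variable names here are illustrative, not the dataloader's own):

```python
length, total_batch_size = 10, 4

# Remainder precomputed once from the dataset length, as in the hunks above.
remainder = length % total_batch_size

# Size of the final, short batch actually observed while iterating.
observed_batch_size = length - (length // total_batch_size) * total_batch_size

assert remainder == observed_batch_size == 2
```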