Skip to content

Commit

Permalink
Try with flipping order
Browse files Browse the repository at this point in the history
  • Loading branch information
muellerzr committed Mar 7, 2023
1 parent 2afa047 commit c56b897
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/accelerate/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,8 +548,8 @@ def __iter__(self):
batch = slice_tensors(batch, data_slice)

if stop_iteration:
self.gradient_state._set_remainder(observed_batch_size)
self.gradient_state._remove_dataloader(self)
self.gradient_state._set_remainder(observed_batch_size)
if batch_index >= self.skip_batches:
yield batch
batch_index += 1
Expand Down
2 changes: 1 addition & 1 deletion src/accelerate/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def __init__(self):
self.remainder = -1
self.active_dataloader = None
self.dataloader_references = [None]
self.remainder_references = [-1]

@property
def initialized(self) -> bool:
Expand All @@ -733,7 +734,6 @@ def _set_sync_gradients(self, sync_gradients):
def _set_end_of_dataloader(self, end_of_dataloader):
"Private function that sets whether the end of the current dataloader has been reached. Users should not have to call this."
self.end_of_dataloader = end_of_dataloader
self.remainder = -1

def _set_remainder(self, remainder):
"Private function that sets the number of remaining samples at the end of the dataloader. Users should not have to call this."
Expand Down
1 change: 0 additions & 1 deletion src/accelerate/test_utils/scripts/test_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,6 @@ def main():
if state.local_process_index == 0:
print("**Test `accumulate` gradient accumulation with dataloader break**")
test_dataloader_break()
accelerator.gradient_state._shared_state = {}
if state.distributed_type == DistributedType.NO:
if state.local_process_index == 0:
print("**Test NOOP `no_sync` context manager**")
Expand Down

0 comments on commit c56b897

Please sign in to comment.