Skip to content

Commit

Permalink
Fix(StaticBatchSampler): fix statics of num_consumed_samples_in_epoch…
Browse files Browse the repository at this point in the history
… for ckpt (#5)
  • Loading branch information
li126com authored Jan 24, 2024
1 parent e6eb75f commit 14e938c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion internlm/data/batch_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,10 @@ def __iter__(self):
cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz
cur_batch_size = min(cur_batch_size, self.batch_size)
batch = indices[self.num_consumed_samples_in_epoch : self.num_consumed_samples_in_epoch + cur_batch_size]
yield batch
self.num_consumed_samples_in_epoch += len(batch) # Consider multiple processes.
self.batch_count += 1
yield batch

self.get_indices() # get a new round

def state_dict(self):
Expand Down

0 comments on commit 14e938c

Please sign in to comment.