Fix calculating num_available_samples (#11830)
* Fix calculating num_available_samples

Signed-off-by: Huy Vu <86480512+huvunvidia@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>

---------

Signed-off-by: Huy Vu <86480512+huvunvidia@users.noreply.github.com>
Signed-off-by: huvunvidia <huvunvidia@users.noreply.github.com>
Co-authored-by: huvunvidia <huvunvidia@users.noreply.github.com>
huvunvidia and huvunvidia authored Jan 13, 2025
1 parent e0c97aa commit abd4bf7
Showing 1 changed file with 8 additions and 4 deletions.
@@ -126,15 +126,14 @@ def __len__(self) -> int:
        When `rampup_batch_size` is enabled, the return value can be not exactly precise.
        """
-        num_available_samples: int = self.total_samples - self.consumed_samples
+        num_available_samples: int = self.total_samples - self.consumed_samples % self.total_samples
        if self.drop_last:
            return num_available_samples // self.global_batch_size
        else:
            return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size

    @abc.abstractmethod
-    def __iter__(self):
-        ...
+    def __iter__(self): ...
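
For context, a minimal standalone sketch of the arithmetic behind the fix (hypothetical numbers, not part of the commit): once consumed_samples grows past total_samples, the old expression goes negative, while the new expression first wraps consumed_samples back into the current epoch via the modulo (which binds tighter than the subtraction).

    # Sketch with made-up values: total_samples, global_batch_size, consumed_samples
    # are hypothetical, chosen only to show the difference between old and new math.
    total_samples = 1000
    global_batch_size = 32
    consumed_samples = 1250  # sampler resumed after consuming more than one epoch

    old_available = total_samples - consumed_samples                   # -250
    new_available = total_samples - consumed_samples % total_samples   # 1000 - 250 = 750

    # __len__ with drop_last=True would then report:
    print(old_available // global_batch_size)  # -8 (nonsensical negative length)
    print(new_available // global_batch_size)  # 23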


class MegatronPretrainingBatchSampler(BaseMegatronBatchSampler):
@@ -151,7 +150,12 @@ def __iter__(self):
            if len(batch) == self._global_batch_size:
                # start_idx, end_idx = self.get_start_end_idx()
                indices = [
-                    batch[i] for i in range(self.data_parallel_rank, self._global_batch_size, self.data_parallel_size,)
+                    batch[i]
+                    for i in range(
+                        self.data_parallel_rank,
+                        self._global_batch_size,
+                        self.data_parallel_size,
+                    )
                ]
                assert len(indices) == self._global_batch_size_on_this_data_parallel_rank
                yield indices
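
The reformatting above does not change behavior: each data-parallel rank still strides through the global batch, taking every data_parallel_size-th sample starting at its own rank. A small illustrative sketch with hypothetical values (not part of the commit):

    # Hypothetical setup: 8 samples per global batch split across 4 data-parallel ranks.
    global_batch_size = 8
    data_parallel_size = 4
    batch = [100, 101, 102, 103, 104, 105, 106, 107]  # sample indices of one global batch

    for data_parallel_rank in range(data_parallel_size):
        indices = [
            batch[i]
            for i in range(data_parallel_rank, global_batch_size, data_parallel_size)
        ]
        print(data_parallel_rank, indices)
    # 0 [100, 104]
    # 1 [101, 105]
    # 2 [102, 106]
    # 3 [103, 107]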
