Skip to content

Commit

Permalink
Update BaseMegatronSampler for compatibility with PTL's `_BatchProgress` (#11016)
Browse files Browse the repository at this point in the history

* Revert "[NeMo-UX] Use custom `BatchProgress` class which does not restore states (#10383)"

This reverts commit b5798de.

* make megatron sampler return the total number of batches in the dataset

Signed-off-by: ashors1 <ashors@nvidia.com>

---------

Signed-off-by: ashors1 <ashors@nvidia.com>
  • Loading branch information
ashors1 authored Oct 25, 2024
1 parent 3f68018 commit 19766a2
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 29 deletions.
7 changes: 3 additions & 4 deletions nemo/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,17 +286,16 @@ def __init__(
)

def __len__(self):
num_available_samples: int = self.total_samples - self.consumed_samples
if self.global_batch_size is not None:
if self.drop_last:
num_global_batches = num_available_samples // self.global_batch_size
num_global_batches = self.total_samples // self.global_batch_size
else:
num_global_batches = (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
num_global_batches = (self.total_samples + self.global_batch_size - 1) // self.global_batch_size
# return len of dataloader in terms of micro batches to avoid discrepancy between len of dataloader and
# num of batches fetched (as training step fetches in terms of micro batches)
return num_global_batches * (self.global_batch_size // self.micro_batch_times_data_parallel_size)
else:
return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1
return (self.total_samples - 1) // self.micro_batch_times_data_parallel_size + 1

@abc.abstractmethod
def __iter__(self): ...
Expand Down
8 changes: 0 additions & 8 deletions nemo/lightning/pytorch/strategies/fsdp_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@

from nemo.lightning import io
from nemo.lightning.pytorch.strategies.utils import (
_MegatronBatchProgress,
ckpt_to_dir,
create_checkpoint_io,
fix_progress_bar,
Expand Down Expand Up @@ -74,15 +73,13 @@ def __init__(
ckpt_load_optimizer: bool = True,
ckpt_save_optimizer: bool = True,
data_sampler=None,
overwrite_batch_progress: bool = True,
**kwargs,
):
super().__init__(auto_wrap_policy=auto_wrap_policy, state_dict_type=state_dict_type, **kwargs)

self.data_sampler = data_sampler
self.ckpt_load_optimizer = ckpt_load_optimizer
self.ckpt_save_optimizer = ckpt_save_optimizer
self.overwrite_batch_progress = overwrite_batch_progress

@override
def setup_environment(self) -> None:
Expand All @@ -95,11 +92,6 @@ def setup(self, trainer: pl.Trainer) -> None:
self.trainer = trainer
setup_data_sampler(self.trainer)
fix_progress_bar(trainer)

trainer_fn = trainer.state.fn
if trainer_fn == TrainerFn.FITTING and self.overwrite_batch_progress:
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()

super().setup(trainer)

def _get_loss_reduction(self, step_type: str):
Expand Down
7 changes: 0 additions & 7 deletions nemo/lightning/pytorch/strategies/megatron_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
from nemo.lightning.pytorch.callbacks import ModelTransform
from nemo.lightning.pytorch.strategies.utils import (
RestoreConfig,
_MegatronBatchProgress,
ckpt_to_dir,
create_checkpoint_io,
fix_progress_bar,
Expand Down Expand Up @@ -160,8 +159,6 @@ class MegatronStrategy(DDPStrategy, io.IOMixin):
that prints the metrics to stdout. Suitable for non-interactive settings.
progress_interval (int): How frequently to print progress to stdout. Only used when
replace_progress_bar is True.
overwrite_batch_progress (bool): Whether to overwrite _BatchProgress class used in PTL by default with
_MegatronBatchProgress. This should be True whenever you're using a Megatron-based dataset.
**kwargs: Additional keyword arguments.
Note:
Expand Down Expand Up @@ -204,7 +201,6 @@ def __init__(
replace_progress_bar: bool = True,
progress_interval: int = 1,
restore_config: Optional[RestoreConfig] = None,
overwrite_batch_progress: bool = True,
**kwargs,
) -> None:
super().__init__(
Expand Down Expand Up @@ -245,7 +241,6 @@ def __init__(

self.replace_progress_bar = replace_progress_bar
self.progress_interval = progress_interval
self.overwrite_batch_progress = overwrite_batch_progress

self.restore_config = restore_config

Expand Down Expand Up @@ -345,8 +340,6 @@ def setup(self, trainer: pl.Trainer) -> None:
self.configure_ddp()

trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer)
if self.overwrite_batch_progress:
trainer.fit_loop.epoch_loop.batch_progress = _MegatronBatchProgress()

import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD

Expand Down
10 changes: 0 additions & 10 deletions nemo/lightning/pytorch/strategies/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,10 @@
from megatron.core.dist_checkpointing.strategies.torch import sharded_tensor_to_torch_sharded_tensor
from megatron.core.transformer.utils import _get_extra_state_offsets
from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.loops.progress import _BatchProgress
from pytorch_lightning.plugins.io.wrapper import _WrappingCheckpointIO
from torch.distributed._sharded_tensor import ShardedTensor as TorchShardedTensor
from torch.distributed._tensor import DTensor, Replicate, Shard
from torch.distributed.device_mesh import DeviceMesh
from typing_extensions import override

from nemo.lightning import _strategy_lib
from nemo.lightning.io.pl import MegatronCheckpointIO
Expand All @@ -48,14 +46,6 @@ class RestoreConfig:
load_artifacts: bool = True


class _MegatronBatchProgress(_BatchProgress):
@override
def load_state_dict(self, state_dict: dict) -> None:
## in megatron, we want to start the batch progress over when
## restoring from a checkpoint
return


def setup_parallel_ranks(strategy: pl.strategies.Strategy):
from megatron.core.model_parallel_config import ModelParallelConfig

Expand Down

0 comments on commit 19766a2

Please sign in to comment.