diff --git a/nemo/io/__init__.py b/nemo/io/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/nemo/io/pl.py b/nemo/io/pl.py
new file mode 100644
index 000000000000..f6bf46557b43
--- /dev/null
+++ b/nemo/io/pl.py
@@ -0,0 +1,167 @@
+import logging
+from pathlib import Path
+from typing import Any, Callable, Dict, Optional, TypeVar, Union
+
+import lightning as L
+import torch
+from lightning.fabric.plugins.io.checkpoint_io import CheckpointIO
+from lightning.fabric.utilities.cloud_io import get_filesystem
+from lightning.fabric.utilities.types import _PATH
+from torch import nn
+from typing_extensions import override
+
+
+log = logging.getLogger(__name__)
+
+
+LightningModuleT = TypeVar("LightningModuleT", bound=L.LightningModule)
+ModuleT = TypeVar("ModuleT", bound=nn.Module)
+
+
+class MegatronCheckpointIO(CheckpointIO):
+    """CheckpointIO that saves and loads checkpoints with Megatron-Core distributed checkpointing
+    (``megatron.core.dist_checkpointing``), writing one sharded checkpoint directory per save.
+
+    .. warning:: This is an experimental feature.
+
+    """
+
+    @override
+    def save_checkpoint(self, checkpoint: Dict[str, Any], path: _PATH, storage_options: Optional[Any] = None) -> None:
+        """Save model/training states as a checkpoint file through state-dump and file-write.
+
+        Args:
+            checkpoint: dict containing model and trainer state
+            path: write-target path
+            storage_options: not used in ``MegatronCheckpointIO.save_checkpoint``
+
+        Raises
+        ------
+        TypeError:
+            If ``storage_options`` arg is passed in
+
+        """
+        from megatron.core import dist_checkpointing
+
+        if storage_options is not None:
+            raise TypeError(
+                "`Trainer.save_checkpoint(..., storage_options=...)` with `storage_options` arg"
+                f" is not supported for `{self.__class__.__name__}`. Please implement your custom `CheckpointIO`"
+                " to define how you'd like to use `storage_options`."
+            )
+        checkpoint_dir = ckpt_to_dir(path)
+        fs = get_filesystem(checkpoint_dir)
+        if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir):
+            log.info(f'Distributed checkpoint at path {checkpoint_dir} already exists, skipping saving')
+            return
+
+        fs.makedirs(checkpoint_dir, exist_ok=True)
+        dist_checkpointing.save(sharded_state_dict=checkpoint, checkpoint_dir=str(checkpoint_dir))
+
+    @override
+    def load_checkpoint(
+        self, path: _PATH, sharded_state_dict=None, map_location: Optional[Callable] = None
+    ) -> Dict[str, Any]:
+        """Loads a distributed (sharded) checkpoint via Megatron-Core ``dist_checkpointing``.
+
+        Args:
+            path: Path to the checkpoint directory
+            sharded_state_dict: sharded state dict describing how the checkpoint should be loaded
+            map_location: not supported; a ``ValueError`` is raised if it is passed
+
+        Returns: The loaded checkpoint.
+
+        Raises
+        ------
+        FileNotFoundError:
+            If ``path`` is not found by the ``fsspec`` filesystem
+
+        """
+        from megatron.core import dist_checkpointing
+
+        if map_location is not None:
+            raise ValueError("`map_location` argument is not supported for `MegatronCheckpointIO.load_checkpoint`.")
+
+        # Try to read the checkpoint at `path`. If not exist, do not restore checkpoint.
+        fs = get_filesystem(path)
+        if not fs.exists(path):
+            raise FileNotFoundError(f"Checkpoint file not found: {path}")
+        if not fs.isdir(path):
+            raise ValueError(f"Distributed checkpoints should be a directory. 
Found: {path}.") + + # return pl_load(path, map_location=map_location) + + checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path)) + checkpoint = _fix_tensors_device(checkpoint) + + return checkpoint + + @override + def remove_checkpoint(self, path: _PATH) -> None: + """Remove checkpoint file from the filesystem. + + Args: + path: Path to checkpoint + + """ + fs = get_filesystem(path) + if fs.exists(path): + fs.rm(path, recursive=True) + log.debug(f"Removed checkpoint: {path}") + + +def _fix_tensors_device(ckpt: Dict) -> Dict: + """Ensure checkpoint tensors are on the correct device.""" + assert torch.cuda.is_initialized(), (torch.cuda.is_available(), torch.cuda.is_initialized()) + cur_dev = torch.device("cuda", index=torch.cuda.current_device()) + + from megatron.core.dist_checkpointing.dict_utils import dict_list_map_outplace + + def _fix_device(t): + if isinstance(t, torch.Tensor) and t.is_cuda and t.device != cur_dev: + t = t.to(cur_dev) + return t + + return dict_list_map_outplace(_fix_device, ckpt) + + +def ckpt_to_dir(filepath: Union[str, Path]) -> Path: + """PTL considers checkpoints as .ckpt files. + This method removes the extension and returns a path + to be used as a directory for distributed checkpoints. + """ + filepath = Path(filepath) + + if not filepath.suffix == ".ckpt": + filepath = filepath.with_suffix(filepath.suffix + ".ckpt") + + # adding this assert because we will later remove directories based on the return value of this method + assert filepath.suffix == ".ckpt", f"filepath: {filepath} must have .ckpt extension" + + # create a new path whose name is the original filepath without the .ckpt extension + checkpoint_dir = filepath.with_name(filepath.stem) + + return checkpoint_dir + + +def is_distributed_ckpt(path) -> bool: + """Check if the given path corresponds to a distributed checkpoint directory. + + This function determines if the specified path is a directory that contains a distributed + checkpoint by checking the directory's metadata. + + Args: + path (Union[str, Path]): The path to check for being a distributed checkpoint. + + Returns + ------- + bool: True if the path is a distributed checkpoint directory, False otherwise. + + """ + from megatron.core import dist_checkpointing + + checkpoint_dir = ckpt_to_dir(path) + fs = get_filesystem(checkpoint_dir) + if fs.isdir(checkpoint_dir) and dist_checkpointing.check_is_distributed_checkpoint(checkpoint_dir): + return True + + return False diff --git a/nemo/lightning/_strategy_lib.py b/nemo/lightning/_strategy_lib.py new file mode 100644 index 000000000000..e3f5f146ff12 --- /dev/null +++ b/nemo/lightning/_strategy_lib.py @@ -0,0 +1,438 @@ +import itertools +import os +from collections import defaultdict +from contextlib import contextmanager +from typing import TYPE_CHECKING, Any, Dict, Generator, Optional, Protocol, TypeVar + +import torch +from torch import nn + +NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE" + + +if TYPE_CHECKING: + from lightning.fabric.utilities.types import Optimizable + from megatron.core.model_parallel_config import ModelParallelConfig + + +class SharedStateDictProtocol(Protocol): + def sharded_state_dict(self, prefix=""): + ... + + +def init_parallel_ranks( + world_size: int, global_rank: int, local_rank: int, parallel_config: "ModelParallelConfig", seed=1234, fp8=False, +) -> None: + """ + Initializes the parallel ranks for distributed training. 
+ + This function sets up the parallel ranks based on the provided world size, global rank, local rank, + and parallel configuration. It also sets the seed for random number generation and determines whether + to use fp8 precision. + + Args: + world_size (int): The total number of processes participating in the distributed training. + global_rank (int): The rank of the current process in the distributed training setup. + local_rank (int): The rank of the current process within its machine. + parallel_config (ModelParallelConfig): The configuration object containing settings for model parallelism. + seed (int, optional): The seed for random number generation. Defaults to 1234. + fp8 (bool, optional): Whether to use fp8 precision for model parameters. Defaults to False. + """ + from nemo.collections.nlp.modules.common.megatron.megatron_init import initialize_model_parallel_for_nemo + from nemo.utils import AppState + + app_state = AppState() + + if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true": + init_world_size = app_state.tensor_model_parallel_size * app_state.pipeline_model_parallel_size + init_global_rank = app_state.global_rank + init_local_rank = app_state.local_rank + else: + init_world_size = world_size + init_global_rank = global_rank + init_local_rank = local_rank + + initialize_model_parallel_for_nemo( + world_size=init_world_size, + global_rank=init_global_rank, + local_rank=init_local_rank, + tensor_model_parallel_size=parallel_config.tensor_model_parallel_size, + pipeline_model_parallel_size=parallel_config.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=parallel_config.virtual_pipeline_model_parallel_size, + seed=seed, + pipeline_model_parallel_split_rank=getattr(parallel_config, "pipeline_model_parallel_split_rank", None), + use_fp8=fp8, + init_mpi_proc_group=getattr(parallel_config, "ub_tp_comm_overlap", False), + # apex_transformer_log_level=self.cfg.get('apex_transformer_log_level', 30), + ) + + +def init_model_parallel(model: Optional[nn.Module] = None) -> None: + """Initializes Megatron-LM model parallel if using model parallelism.""" + import torch.distributed + from megatron.core import parallel_state + + from nemo.utils import AppState + + app_state = AppState() + + # we initialize megatron-lm model parallel and data parallel groups + # after initializing DDP with PTL. 
+    if app_state.model_parallel_size is not None:
+        # destroy groups in case they have already been created
+        # this happens with multiple calls to trainer.test for example
+        parallel_state.destroy_model_parallel()
+        if torch.distributed.is_initialized():
+            parallel_state.initialize_model_parallel(
+                tensor_model_parallel_size=app_state.tensor_model_parallel_size,
+                pipeline_model_parallel_size=app_state.pipeline_model_parallel_size,
+                virtual_pipeline_model_parallel_size=app_state.virtual_pipeline_model_parallel_size,
+                pipeline_model_parallel_split_rank=app_state.pipeline_model_parallel_split_rank,
+            )
+
+            # assert that the app_state tp and pp ranks match the ranks assigned by model parallel init
+            assert app_state.tensor_model_parallel_rank == parallel_state.get_tensor_model_parallel_rank()
+            assert app_state.pipeline_model_parallel_rank == parallel_state.get_pipeline_model_parallel_rank()
+
+            app_state.tensor_model_parallel_group = parallel_state.get_tensor_model_parallel_group()
+            app_state.data_parallel_group = parallel_state.get_data_parallel_group()
+            app_state.data_parallel_rank = parallel_state.get_data_parallel_rank()
+            app_state.data_parallel_size = parallel_state.get_data_parallel_world_size()
+            app_state.pipeline_model_parallel_group = parallel_state.get_pipeline_model_parallel_group()
+
+            # create MPI process group for UCX-based communication APIs
+            if app_state.init_mpi_proc_group:
+                torch.distributed.new_group(backend="mpi")
+
+    if model:
+        # Set TP group
+        # Deep iterate but skip self to avoid infinite recursion.
+        for index, child in enumerate(model.modules()):
+            if index == 0:
+                continue
+            if hasattr(child, "set_tensor_parallel_group"):
+                tp_group = parallel_state.get_tensor_model_parallel_group()
+                child.set_tensor_parallel_group(tp_group)
+
+
+@contextmanager
+def megatron_lazy_init_context(config) -> Generator[None, None, None]:
+    def monkey_patched(c):
+        return {"device": "meta"}
+
+    from megatron.core.transformer.custom_layers import transformer_engine as _te
+
+    original = _te._get_extra_te_kwargs  # noqa: SLF001
+    _te._get_extra_te_kwargs = monkey_patched  # noqa: SLF001
+
+    _orig_perform_initialization = config.perform_initialization
+    _orig_use_cpu_initialization = config.use_cpu_initialization
+
+    config.perform_initialization = False
+    config.use_cpu_initialization = True
+
+    yield
+
+    _te._get_extra_te_kwargs = original  # noqa: SLF001
+    config.perform_initialization = _orig_perform_initialization
+    config.use_cpu_initialization = _orig_use_cpu_initialization
+
+
+@contextmanager
+def megatron_cpu_init_context(config) -> Generator[None, None, None]:
+    _orig_use_cpu_initialization = config.use_cpu_initialization
+
+    config.use_cpu_initialization = True
+
+    yield
+
+    config.use_cpu_initialization = _orig_use_cpu_initialization
+
+
+ModelT = TypeVar("ModelT", bound=nn.Module)
+
+
+class GradScaler(torch.cuda.amp.GradScaler):
+    """
+    Gradient scaler for model-parallel inf checks. Infs in gradients are checked across tensor-parallel
+    ranks in (1) the optimizer step and (2) the gradient scaler update.
+ + """ + + def __init__( + self, + init_scale=2.0 ** 16, + growth_factor=2.0, + backoff_factor=0.5, + growth_interval=2000, + enabled=True, + hysteresis=1, + ): + super().__init__( + init_scale=init_scale, + growth_factor=growth_factor, + backoff_factor=backoff_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + self.optimizer_update_skipped: Optional[bool] = None + self.hysteresis = hysteresis + self._hysteresis_tracker = self.hysteresis + + def _unscale_grads_(self, optimizer, *args): + if getattr(optimizer, "_custom_amp_unscale_grads", False): + return optimizer.unscale_grads(*args) + else: + return super()._unscale_grads_(optimizer, *args) + + def _maybe_opt_step(self, optimizer, optimizer_state, *args, **kwargs): + from megatron.core import parallel_state + + retval = None + found_inf = torch.cuda.FloatTensor([sum(v.item() for v in optimizer_state["found_inf_per_device"].values())]) + + # Update across all model parallel instances. + torch.distributed.all_reduce( + found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + ) + + if found_inf.item() == 0: + retval = optimizer.step(*args, **kwargs) + self.optimizer_update_skipped = False + else: + self.optimizer_update_skipped = True + return retval + + def update(self, new_scale=None): + """ + Updates to native grad scaler update function. + 1. Check inf across model-parallel ranks. + 2. Update hysteresis tracker. + 3. Apply hysteresis to grad scale update. + """ + from megatron.core import parallel_state + + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor with" " requires_grad=False." + ) + assert isinstance(new_scale, torch.cuda.FloatTensor), reason # type: ignore[attr-defined] + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + + # Update across all model parallel instances. + torch.distributed.all_reduce( + found_inf_combined, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + ) + + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf = found_infs[i] + # Update across all model parallel instances. + torch.distributed.all_reduce( + found_inf, op=torch.distributed.ReduceOp.MAX, group=parallel_state.get_model_parallel_group(), + ) + found_inf_combined += found_inf + + if found_inf_combined > 0: + self._hysteresis_tracker -= 1 + if self._hysteresis_tracker <= 0: + # When hysteresis becomes zero, follow the native grad scale update rule. 
+ # Increase scale and reset growth tracker + torch._amp_update_scale_( # noqa: SLF001 + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + else: + # Only reset the growth tracker when hysteresis is larger than zero + _growth_tracker.fill_(0.0) + else: + # When no inf found, follow the native grad scale update rule. + # Increment growth_tracker, update scale when growth tracker reaches the interval, and + # reset the hysteresis tracker. + torch._amp_update_scale_( # noqa: SLF001 + _scale, + _growth_tracker, + found_inf_combined, + self._growth_factor, + self._backoff_factor, + self._growth_interval, + ) + self._hysteresis_tracker = self.hysteresis + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict( + torch.cuda.amp.grad_scaler._refresh_per_optimizer_state # noqa: SLF001 + ) + + def state_dict(self): + """ + Add hysteresis_tracker to the native functions' state_dict. + """ + return ( + { + "scale": self.get_scale(), + "growth_factor": self._growth_factor, + "backoff_factor": self._backoff_factor, + "growth_interval": self._growth_interval, + "_growth_tracker": self._get_growth_tracker(), + "_hysteresis_tracker": self._hysteresis_tracker, + } + if self._enabled + else {} + ) + + def load_state_dict(self, state_dict): + """ + Load hysteresis_tracker in addition to the state dict of the native function. + """ + if not self._enabled: + return + + if len(state_dict) == 0: + raise RuntimeError( + "The source state dict is empty, possibly because it was saved " + "from a disabled instance of GradScaler." + ) + + self._init_scale = state_dict["scale"] + if self._scale is not None: + self._scale.fill_(state_dict["scale"]) + self._growth_factor = state_dict["growth_factor"] + self._backoff_factor = state_dict["backoff_factor"] + self._growth_interval = state_dict["growth_interval"] + self._init_growth_tracker = state_dict["_growth_tracker"] + if self._growth_tracker is not None: + self._growth_tracker.fill_(state_dict["_growth_tracker"]) + if "_hysterisis_tracker" in state_dict: + self._hysteresis_tracker = state_dict["_hysterisis_tracker"] + else: + self._hysteresis_tracker = 1 + + +def enable_nvidia_optimizations() -> None: + """These optimizations are present in NVIDIA NGC PyTorch Containers.""" + # NVIDIA container version check + nvidia_torch_version = os.getenv("NVIDIA_PYTORCH_VERSION", None) + if nvidia_torch_version is not None: + try: + NVIDIA_TORCH_MAJOR = int(nvidia_torch_version.split(".")[0]) + except Exception: + NVIDIA_TORCH_MAJOR = 0 + try: + NVIDIA_TORCH_MINOR = int(nvidia_torch_version.split(".")[1]) + except Exception: + NVIDIA_TORCH_MINOR = 0 + + # NVFUSER available starting with 21.11 + if NVIDIA_TORCH_MAJOR >= 21 or (NVIDIA_TORCH_MAJOR == 21 and NVIDIA_TORCH_MINOR >= 11): + # NVFUSER + torch._C._jit_set_profiling_executor(True) # noqa: SLF001 + torch._C._jit_set_profiling_mode(True) # noqa: SLF001 + torch._C._jit_override_can_fuse_on_cpu(False) # noqa: SLF001 + torch._C._jit_override_can_fuse_on_gpu(False) # noqa: SLF001 + torch._C._jit_set_texpr_fuser_enabled(False) # noqa: SLF001 + # torch._C._jit_set_nvfuser_enabled(True) + torch._C._debug_set_autodiff_subgraph_inlining(False) # noqa: SLF001 + else: + # Not a Nvidia container. 
NVFUSER dependency check is left to the user
+        pass
+
+
+def optimizer_sharded_state_dict(model: SharedStateDictProtocol, optimizer: "Optimizable") -> Dict[str, torch.Tensor]:
+    """
+    Sharded state dictionary for a MainParamsOptimizerWrapper.
+    Used to save and load the optimizer state when training with distributed checkpointing.
+
+    Returns
+    -------
+    dict
+        The sharded state dictionary for the optimizer.
+
+    Raises
+    ------
+    ValueError
+        If a parameter ID does not match any model sharded parameter.
+    """
+    from megatron.core.dist_checkpointing.optimizer import (
+        get_param_id_to_sharded_param_map,
+        make_sharded_optimizer_tensor,
+        optim_state_to_sharding_state,
+    )
+
+    from nemo.core.optim import MainParamsOptimizerWrapper
+    from nemo.core.optim.optimizers import init_optimizer_states
+
+    model_sharded_state_dict = model.sharded_state_dict()
+
+    # remove _extra_state
+    model_sharded_state_dict = {
+        key: value for key, value in model_sharded_state_dict.items() if not key.endswith("_extra_state")
+    }
+
+    if hasattr(optimizer, "sharded_state_dict"):
+        return optimizer.sharded_state_dict(model_sharded_state_dict)
+
+    if not isinstance(optimizer, MainParamsOptimizerWrapper):
+        # Regular optimizer, e.g. Adam or FusedAdam
+        init_optimizer_states(optimizer)
+        optimizer_state_dict = optimizer.state_dict()
+        id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+            model_sharded_state_dict=model_sharded_state_dict,
+            optim_params_iter=itertools.chain.from_iterable(g['params'] for g in optimizer.param_groups),
+        )
+        optim_state_to_sharding_state(optimizer_state_dict, id_to_sharded_param_map)
+        return optimizer_state_dict
+
+    optimizer_state_dict: Dict[str, Any] = optimizer.state_dict()
+
+    id_to_sharded_param_map = get_param_id_to_sharded_param_map(
+        model_sharded_state_dict=model_sharded_state_dict,
+        optim_params_iter=itertools.chain.from_iterable(g for g in optimizer.float16_groups),
+    )
+
+    # Convert fp32_from_fp16_params
+    assert len(optimizer_state_dict["fp32_from_fp16_params"]) == len(optimizer_state_dict["optimizer"]["param_groups"])
+
+    def get_safe(param_id):
+        try:
+            return id_to_sharded_param_map[param_id]
+        except KeyError as e:
+            raise ValueError(f"Param id {param_id} does not match any model sharded param") from e
+
+    optimizer_state_dict["fp32_from_fp16_params"] = [
+        [
+            make_sharded_optimizer_tensor(get_safe(param_id), fp32_param, prefix="optimizer.state.fp32_param")
+            for param_id, fp32_param in zip(state_group["params"], fp32_group)
+        ]
+        for fp32_group, state_group in zip(
+            optimizer_state_dict["fp32_from_fp16_params"], optimizer_state_dict["optimizer"]["param_groups"],
+        )
+    ]
+
+    # Convert state
+    optim_state_to_sharding_state(optimizer_state_dict["optimizer"], id_to_sharded_param_map)
+
+    return optimizer_state_dict
diff --git a/nemo/lightning/pytorch/__init__.py b/nemo/lightning/pytorch/__init__.py
new file mode 100644
index 000000000000..e69de29bb2d1
diff --git a/nemo/lightning/pytorch/callbacks/__init__.py b/nemo/lightning/pytorch/callbacks/__init__.py
new file mode 100644
index 000000000000..fcceedeb7090
--- /dev/null
+++ b/nemo/lightning/pytorch/callbacks/__init__.py
@@ -0,0 +1,3 @@
+from nemo.lightning.pytorch.callbacks.progress import MegatronProgressBar
+
+__all__ = ["MegatronProgressBar"]
diff --git a/nemo/lightning/pytorch/callbacks/progress.py b/nemo/lightning/pytorch/callbacks/progress.py
new file mode 100644
index 000000000000..9d4d9b385da8
--- /dev/null
+++ b/nemo/lightning/pytorch/callbacks/progress.py
@@ -0,0 +1,67 @@
+from pytorch_lightning.callbacks.progress
import TQDMProgressBar +from pytorch_lightning.callbacks.progress.tqdm_progress import _update_n + + +class MegatronProgressBar(TQDMProgressBar): + """ + Add MegatronProgressBar to remove 's/it' and display progress per step instead of per microbatch + for megatron models. + """ + + def get_current_epoch_step(self, trainer) -> int: + """ + Get the value of step within an epoch. + """ + return max( + trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.current.completed, + trainer.fit_loop.epoch_loop.manual_optimization.optim_step_progress.current.completed, + ) + + def init_train_tqdm(self): + """ + Override bar_format to not have 's/it'. + """ + self.bar = super().init_train_tqdm() + self.bar.bar_format = "{desc}: {percentage:3.0f}%|{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}{postfix}]" + return self.bar + + def on_train_epoch_start(self, trainer, *_): + if trainer.max_steps > 0 and (trainer.ckpt_path is not None): + # while resuming from a ckpt use trainer.max_steps as the total for progress bar as trainer.num_training_batches + # is truncated to max_steps - step being resumed at + num_training_batches = trainer.max_steps + else: + num_training_batches = trainer.num_training_batches + + # from nemo.utils import AppState + # app_state = AppState() + # app_state. + + num_training_batches = num_training_batches // calculate_data_parallel_groups() + + self.train_progress_bar.reset(num_training_batches) + self.train_progress_bar.initial = 0 + self.train_progress_bar.set_description(f"Epoch {trainer.current_epoch}") + + def on_train_batch_end(self, trainer, pl_module, *_, **__): + """ + Override parent class on_train_batch_end to update progress bar per global batch instead of per microbatch. + """ + n = self.get_current_epoch_step(trainer) + if self._should_update(n, self.train_progress_bar.total): + _update_n(self.train_progress_bar, n) + self.train_progress_bar.set_postfix(self.get_metrics(trainer, pl_module)) + + +def calculate_data_parallel_groups() -> int: + from nemo.utils import AppState + + app_state = AppState() + + pipeline_model_parallel_size = app_state.pipeline_model_parallel_size + tensor_model_parallel_size = app_state.tensor_model_parallel_size + + world_size = app_state.world_size + data_parallel_group_len = world_size // (pipeline_model_parallel_size * tensor_model_parallel_size) + + return world_size // data_parallel_group_len diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py new file mode 100644 index 000000000000..0fa386cb45ef --- /dev/null +++ b/nemo/lightning/pytorch/strategies.py @@ -0,0 +1,502 @@ +import functools +import logging +import shutil +from collections import OrderedDict +from contextlib import ExitStack +from pathlib import Path +from typing import Any, ContextManager, Dict, List, Mapping, Optional, TypeVar, Union, cast + +import lightning.pytorch as pl +import torch +import torch.distributed +from lightning.fabric.plugins import CheckpointIO, ClusterEnvironment +from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.pytorch.accelerators import CPUAccelerator +from lightning.pytorch.callbacks.progress import TQDMProgressBar +from lightning.pytorch.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop +from lightning.pytorch.loops.fetchers import _DataLoaderIterDataFetcher +from lightning.pytorch.overrides.distributed import _sync_module_states +from lightning.pytorch.plugins.io.wrapper import _WrappingCheckpointIO +from 
lightning.pytorch.strategies.ddp import DDPStrategy +from lightning.pytorch.trainer.states import RunningStage, TrainerFn +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.types import STEP_OUTPUT +from torch import nn +from torch.distributed.algorithms.ddp_comm_hooks.debugging_hooks import noop_hook +from torch.nn.parallel import DistributedDataParallel +from torch.utils.data import DataLoader +from typing_extensions import override + +from nemo.io.pl import MegatronCheckpointIO +from nemo.lightning import _strategy_lib +from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel, _ModuleStepFunction +from nemo.lightning.pytorch.callbacks import MegatronProgressBar + +ConfigT = TypeVar("ConfigT") + + +class MegatronStrategy(DDPStrategy): + """Megatron plugin for Pytorch Lightning. + + Args: + no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2 + with FP32 gradient accumulation. + """ + + trainer: pl.Trainer + + def __init__( + self, + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + virtual_pipeline_model_parallel_size: Optional[int] = None, + sequence_parallel: bool = False, + # data_sampler: Optional[DataSampler] = None, + parallel_devices: Optional[List[torch.device]] = None, + cluster_environment=None, # TODO: Add type-hint + checkpoint_io=None, # TODO: Add type-hint + no_ddp_communication_hook: bool = True, + find_unused_parameters: bool = False, + lazy_init: bool = False, + **kwargs, + ) -> None: + super().__init__( + parallel_devices, + cluster_environment, + checkpoint_io, + find_unused_parameters=find_unused_parameters, + **kwargs, + ) + self.no_ddp_communication_hook = no_ddp_communication_hook + self.megatron_callbacks = CallbackConnector() + # self.data_sampler: Optional[DataSampler] = data_sampler + self.tensor_model_parallel_size = tensor_model_parallel_size + self.pipeline_model_parallel_size = pipeline_model_parallel_size + self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size + self.sequence_parallel = sequence_parallel + self.lazy_init = lazy_init + + # used in NVIDIA NGC PyTorch containers + _strategy_lib.enable_nvidia_optimizations() + + @override + def connect(self, model: pl.LightningModule) -> None: + super().connect(model) + + # Right now mcore sub-classes ModelParellelConfig, we should remove that + # Given Lightning's structure it would be better if parallelism is a different object + # Since then it can be passed to the Strategy + + from megatron.core.transformer.transformer_config import TransformerConfig + + has_mcore_config = isinstance(getattr(model, "config", None), TransformerConfig) + if has_mcore_config and is_overridden("configure_model", model): + config: TransformerConfig = model.config + config.tensor_model_parallel_size = self.tensor_model_parallel_size + config.pipeline_model_parallel_size = self.pipeline_model_parallel_size + config.virtual_pipeline_model_parallel_size = self.virtual_pipeline_model_parallel_size + config.sequence_parallel = self.sequence_parallel + self._mcore_config = config + + @override + def setup(self, trainer: pl.Trainer) -> None: + assert self.accelerator is not None + self.accelerator.setup(trainer) + self.trainer = trainer + + # move the model to the correct device + # self.model_to_device() + + # skip wrapping the model if we are not fitting as no gradients need to be exchanged + trainer_fn = trainer.state.fn + + if trainer_fn == TrainerFn.FITTING and 
self._layer_sync: + assert self.model is not None + self.model = self._layer_sync.apply(self.model) + + datamodule = getattr(trainer, "datamodule", None) + if not self.data_sampler and hasattr(datamodule, "data_sampler"): + self.data_sampler = datamodule.data_sampler + self.data_sampler.setup(self.cluster_environment.global_rank()) + + if self.data_sampler: + self.data_sampler.connect(trainer) + + self._fix_progress_bar(trainer) + self.setup_megatron_parallel(trainer) + self.setup_precision_plugin() + + if trainer.num_sanity_val_steps > 1 and self.pipeline_model_parallel_size > 1: + # TODO: log here + trainer.num_sanity_val_steps = 0 + + for loop in [fit_loop, evaluation_loop, prediction_loop]: + loop._select_data_fetcher = _data_fetcher_wrapper(loop._select_data_fetcher) # noqa: SLF001 + + if trainer_fn == TrainerFn.FITTING: + # TODO: Make sure we don't always wrap the model in data-parallel + # See: https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/parts/nlp_overrides.py#L215-L217 + + # do not wrap with DDP if not fitting as there's no gradients to reduce + self.configure_ddp() + + trainer.fit_loop.epoch_loop.automatic_optimization = _MegatronAutomaticOptimization(trainer) + + # set up optimizers after the wrapped module has been moved to the device + self.setup_optimizers(trainer) + if hasattr(self.precision_plugin, "convert_optimizer"): + _optimizers = [*self.optimizers] + _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0]) + self.optimizers = _optimizers + + _optimizers_to_device(self.optimizers, self.root_device) + + import torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook as post_localSGD + + if isinstance(self._ddp_comm_state, post_localSGD.PostLocalSGDState): + self._enable_model_averaging() + else: + # we need to manually synchronize the module's states since we aren't using the DDP wrapper + assert self.model is not None + _sync_module_states(self.model) + + @override + def setup_distributed(self) -> None: + self._setup_parallel_ranks() + super().setup_distributed() + + from megatron.core import parallel_state + from nemo.utils import AppState + + # init model parallel if needed + if not parallel_state.model_parallel_is_initialized(): + app_state = AppState() + + if app_state.model_parallel_size is not None: + _strategy_lib.init_model_parallel(self.model) + + if self.data_sampler: + assert isinstance(self.cluster_environment, ClusterEnvironment), "Cluster environment not initialized" + self.data_sampler.setup(self.cluster_environment.global_rank()) + + @override + def process_dataloader(self, dataloader: DataLoader) -> DataLoader: + if self.data_sampler: + return self.data_sampler.transform_dataloader(dataloader) + + return dataloader + + def setup_megatron_parallel(self, trainer: pl.Trainer) -> None: + assert self.model is not None, "Model is not set" + + self.megatron_parallel = MegatronParallel( + self.model, + precision_plugin=self.precision_plugin, + vp_size=self.virtual_pipeline_model_parallel_size, + cpu=isinstance(trainer.accelerator, CPUAccelerator), + ) + self.model = self.megatron_parallel + self.model.trainer = trainer + + if hasattr(self.precision_plugin, "convert_module"): + self.model = self.precision_plugin.convert_module(self.model) + self.model.callbacks.add(getattr(trainer, "callbacks")) + + if self.data_sampler: + self.model.callbacks.add(self.data_sampler) + + datamodule = getattr(trainer, "datamodule", None) + if datamodule: + self.model.callbacks.add(datamodule) + + @override + def configure_ddp(self) -> 
None: + logging.debug(f"{self.__class__.__name__}: configuring MegatronParallel") + self.model = self._setup_model(self.model) + self._register_ddp_hooks() + + @override + def _setup_model(self, model: nn.Module) -> DistributedDataParallel: + """Only called when we need to wrap the model for pytorch's ddp.""" + from megatron.core import parallel_state + from nemo.utils import AppState + + app_state = AppState() + if app_state.model_parallel_size is not None: + self._ddp_kwargs["process_group"] = parallel_state.get_data_parallel_group() + + dist_data_parallel: DistributedDataParallel = super()._setup_model(model) + if self.no_ddp_communication_hook: + # When using custom gradient accumulation and allreduce, disable + # DDP communication hook that works on the gradient bucket. + # Instead, use the custom gradient function and communication hook, + # which is defined in the master optimizer wrapper. + dist_data_parallel.require_backward_grad_sync = False + dist_data_parallel.register_comm_hook(None, noop_hook) + + return dist_data_parallel + + def _setup_parallel_ranks(self) -> None: + self.set_world_ranks() + env = cast(ClusterEnvironment, self.cluster_environment) + + _strategy_lib.init_parallel_ranks(env.world_size(), env.global_rank(), env.local_rank(), self.parallelism) + + @override + def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.lightning_module is not None + assert self.model is not None + kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "training") + + with self.precision_plugin.train_step_context(): # TODO: Do we need this? + return self.model(dataloader_iter, *args, **kwargs) + + @override + def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.lightning_module is not None + assert self.model is not None + kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation") + + with self.precision_plugin.val_step_context(): # TODO: Do we need this? + return self.model(dataloader_iter, *args, **kwargs) + + @override + def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.lightning_module is not None + assert self.model is not None + kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "test") + + with self.precision_plugin.test_step_context(): # TODO: Do we need this? + return self.model(dataloader_iter, *args, **kwargs) + + @override + def predict_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.lightning_module is not None + assert self.model is not None + kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "predict") + + with self.precision_plugin.predict_step_context(): # TODO: Do we need this? 
+ return self.model(dataloader_iter, *args, **kwargs) + + @override + def teardown(self) -> None: + super().teardown() + + @override + def model_sharded_context(self) -> ContextManager: + if self.lazy_init and hasattr(self, "_mcore_config"): + stack = ExitStack() + stack.enter_context(_strategy_lib.megatron_lazy_init_context(self._mcore_config)) + return stack + + return super().model_sharded_context() + + def _update_step_kwargs(self, dataloader_iter, kwargs, step_name: str): + if "data_step" not in kwargs: + kwargs["data_step"] = self._get_data_step(step_name) + if "forward_step" not in kwargs: + kwargs["forward_step"] = self._get_forward_step(step_name) + if "loss_reduction" not in kwargs: + kwargs["loss_reduction"] = self._get_loss_reduction(step_name) + kwargs.update(self._data_config_kwargs(dataloader_iter)) + + return kwargs + + def _fix_progress_bar(self, trainer: pl.Trainer) -> None: + callbacks: List[pl.Callback] = cast(List[pl.Callback], getattr(trainer, "callbacks")) + contains_megatron_progress, contains_progress = False, False + for callback in callbacks: + if isinstance(callback, MegatronProgressBar): + contains_megatron_progress = True + if callback.__class__ == TQDMProgressBar: + contains_progress = True + if not contains_megatron_progress and contains_progress: + for callback in callbacks: + if isinstance(callback, TQDMProgressBar): + callback.__class__ = MegatronProgressBar + break + + def optimizer_sharded_state_dict(self): + """ + Sharded state dictionary for an MainParamsOptimizerWrapper. + Used to save and load the optimizer state when training with distributed_checkpoint. + + Returns + ------- + dict: The sharded state dictionary for the optimizer + Raises: + ValueError: If a parameter ID does not match any model sharded parameter. + """ + # TODO: Fix when MainParamsOptimizerWrapper is not used + + optimizer = self.lightning_module.optimizers(use_pl_optimizer=False) + + return _strategy_lib.optimizer_sharded_state_dict(self.megatron_parallel, optimizer) + + @override + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: Union[str, Path], storage_options: Optional[Any] = None + ) -> None: + checkpoint['state_dict'] = OrderedDict([]) # remove device state_dict + checkpoint['sharded_state_dict'] = self.megatron_parallel.sharded_state_dict() + if self.trainer.state.fn == TrainerFn.FITTING: + checkpoint['optimizer_states'] = [self.optimizer_sharded_state_dict()] + + self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) + + @override + def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: + """PTL method which we override to integrate distributed checkpoints for model parallel models. + In order to load distributed checkpoints we need to provide the sharded_state_dict to + the distributed load function. We get the sharded_state_dict from self.lightning_module + which makes it convenient to have the loading logic happen at the strategy level. 
+ """ + torch.cuda.empty_cache() + + # After dist_checkpointing.load, sharded tensors will be replaced with tensors + sharded_state_dict = {} + sharded_state_dict["state_dict"] = self.megatron_parallel.sharded_state_dict() + + # if self.trainer.state.fn == TrainerFn.FITTING: + # if self.lightning_module.optimizers(use_pl_optimizer=False): + # sharded_state_dict["optimizer_states"] = [self.optimizer_sharded_state_dict()] + + checkpoint = self.checkpoint_io.load_checkpoint(checkpoint_path, sharded_state_dict=sharded_state_dict) + + return checkpoint + + def remove_checkpoint(self, filepath: Union[str, Path]) -> None: + if self.is_global_zero: + shutil.rmtree(ckpt_to_dir(filepath)) + + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + assert self.megatron_parallel is not None + from megatron.core import mpu + + for index, module in enumerate(self.megatron_parallel): + if mpu.get_virtual_pipeline_model_parallel_world_size() is not None: + checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] + else: + checkpoint_state_dict = checkpoint['state_dict'] + # checkpoint_state_dict has "model." but module does not so we need to remove it when loading + checkpoint_state_dict = { + key.replace('model.', ''): checkpoint_state_dict.pop(key) for key in list(checkpoint_state_dict.keys()) + } + module.load_state_dict(checkpoint_state_dict, strict=strict) + + @property + @override + def checkpoint_io(self) -> CheckpointIO: + if self._checkpoint_io is None: + self._checkpoint_io = MegatronCheckpointIO() + elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): + self._checkpoint_io.checkpoint_io = MegatronCheckpointIO() + + return self._checkpoint_io + + def _get_data_step(self, step_type: str) -> Optional[_ModuleStepFunction]: + for fn_name in [f"{step_type}_data_step", "data_step"]: + if hasattr(self.lightning_module, fn_name): + return _ModuleStepFunction(fn_name) + + return None + + def _get_forward_step(self, step_type: str) -> Optional[_ModuleStepFunction]: + from megatron.core import mpu + + if mpu.is_pipeline_last_stage(): + if not hasattr(self.lightning_module, f"{step_type}_step"): + raise ValueError(f"LightningModule does not have {step_type}_step method") + + return _ModuleStepFunction(f"{step_type}_step", includes_self=True) + + for fn_name in [f"{step_type}_forward_step", "forward_step"]: + if hasattr(self.lightning_module, fn_name): + return _ModuleStepFunction(fn_name, includes_self=True) + + return None + + def _get_loss_reduction(self, step_type: str) -> Optional[_ModuleStepFunction]: + for fn_name in [f"{step_type}_loss_reduction", "loss_reduction"]: + if hasattr(self.lightning_module, fn_name): + return _ModuleStepFunction(fn_name, is_property=True) + + return None + + def _data_config_kwargs(self, dataloader_iter) -> Dict[str, Any]: + if not hasattr(dataloader_iter, "data_config") and self.data_sampler: + if hasattr(self.data_sampler, "megatron_data_kwargs"): + return self.data_sampler.megatron_data_kwargs + + return {} + + @property + def distributed_sampler_kwargs(self) -> Dict[str, Any]: + from nemo.utils import AppState + + app_state = AppState() + if app_state.model_parallel_size is not None: + # When using model parallel, data parallel groups are non-trivial and they + # correspond to the logical GPUs. This means that the GPUs that form a + # single logical GPU all need to get the same batch of data. 
+ distributed_sampler_kwargs = dict( + num_replicas=app_state.data_parallel_size, rank=app_state.data_parallel_rank + ) + return distributed_sampler_kwargs + + else: + return super().distributed_sampler_kwargs + + @property + def restore_checkpoint_after_setup(self) -> bool: + """Needs to be True for distributed checkpointing because + we require the model to have configured the optimizer before + deserializing the checkpoint. + """ + return True + + @property + def parallelism(self): + from megatron.core.model_parallel_config import ModelParallelConfig + + return ModelParallelConfig( + tensor_model_parallel_size=self.tensor_model_parallel_size, + pipeline_model_parallel_size=self.pipeline_model_parallel_size, + virtual_pipeline_model_parallel_size=self.virtual_pipeline_model_parallel_size, + sequence_parallel=self.sequence_parallel, + ) + + +def ckpt_to_dir(filepath: Union[str, Path]) -> Path: + """PTL considers checkpoints as .ckpt files. + This method removes the extension and returns a path + to be used as a directory for distributed checkpoints. + """ + filepath = Path(filepath) + + if filepath.suffix == ".ckpt": + return filepath.with_name(filepath.stem) + + return filepath + + +def _data_fetcher_wrapper(fn): + @functools.wraps(fn) + def wrapped(trainer: pl.Trainer, stage: RunningStage): + if isinstance(trainer.strategy, MegatronStrategy): + return _DataLoaderIterDataFetcher() + + return fn(trainer, stage) + + return wrapped + + +class _MegatronAutomaticOptimization(_AutomaticOptimization): + """ + Custom loop for automatic optimization, tailored to work with a specific training_step + implementation that involves custom data preparation, forward pass, and loss reduction steps. + """ + + def __init__(self, trainer: "pl.Trainer") -> None: + super().__init__(trainer) + self._skip_backward = True # megatron will do the backward pass diff --git a/tests/lightning/test_megatron_parallel.py b/tests/lightning/test_megatron_parallel.py index cac568747331..06e614d48251 100644 --- a/tests/lightning/test_megatron_parallel.py +++ b/tests/lightning/test_megatron_parallel.py @@ -52,8 +52,8 @@ def mock_loss_reduction(self, mocker): def test_init_with_defaults(self, mocker, mock_pipeline): """Test __init__ with default parameters.""" - mocker.patch('megatron.core.mpu.get_pipeline_model_parallel_world_size', return_value=1) - mocker.patch('megatron.core.mpu.model_parallel_is_initialized', return_value=False) + mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1) + mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False) megatron_parallel = mp.MegatronParallel(pipeline=mock_pipeline) @@ -76,8 +76,8 @@ def test_init_with_defaults(self, mocker, mock_pipeline): # mock_loss_reduction # ): # """Test __init__ with custom parameters.""" - # mocker.patch('megatron.core.mpu.get_pipeline_model_parallel_world_size', return_value=1) - # mocker.patch('megatron.core.mpu.model_parallel_is_initialized', return_value=False) + # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=1) + # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=False) # # megatron_parallel = mp.MegatronParallel( # pipeline=mock_pipeline, @@ -99,20 +99,20 @@ def test_init_with_defaults(self, mocker, mock_pipeline): # def test_init_with_virtual_pipeline(self, mocker, mock_pipeline): # """Test __init__ with virtual pipeline model parallel world size.""" # 
mocker.patch('torch.distributed.get_rank', return_value=1) - # mocker.patch('megatron.core.mpu.get_tensor_model_parallel_group', return_value=1) - # mocker.patch('megatron.core.mpu.get_pipeline_model_parallel_group', return_value=1) - # mocker.patch('megatron.core.mpu.get_pipeline_model_parallel_world_size', return_value=2) - # mocker.patch('megatron.core.mpu.model_parallel_is_initialized', return_value=True) - # mocker.patch('megatron.core.mpu.set_virtual_pipeline_model_parallel_world_size') - # mocker.patch('megatron.core.mpu.set_virtual_pipeline_model_parallel_rank') + # mocker.patch('megatron.core.parallel_state.get_tensor_model_parallel_group', return_value=1) + # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_group', return_value=1) + # mocker.patch('megatron.core.parallel_state.get_pipeline_model_parallel_world_size', return_value=2) + # mocker.patch('megatron.core.parallel_state.model_parallel_is_initialized', return_value=True) + # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size') + # mocker.patch('megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank') # mocker.patch('nemo_ext.lightning._strategy_lib.init_lightning_module', return_value=mock_pipeline) # megatron_parallel = mp.MegatronParallel(mock_pipeline, vp_size=2) # assert len(megatron_parallel.pipeline) == 2 # assert all(isinstance(mod, nn.Module) for mod in megatron_parallel.pipeline) - # megatron.core.mpu.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2) - # assert megatron.core.mpu.set_virtual_pipeline_model_parallel_rank.call_count == 1 + # megatron.core.parallel_state.set_virtual_pipeline_model_parallel_world_size.assert_called_once_with(2) + # assert megatron.core.parallel_state.set_virtual_pipeline_model_parallel_rank.call_count == 1 class TestCallbackConnector: diff --git a/tests/lightning/test_strategy_lib.py b/tests/lightning/test_strategy_lib.py new file mode 100644 index 000000000000..96f5f2920bcf --- /dev/null +++ b/tests/lightning/test_strategy_lib.py @@ -0,0 +1,211 @@ +from unittest.mock import ANY, MagicMock, patch + +from torch import nn + +from nemo.lightning import _strategy_lib # , DataConfig + + +class Identity(nn.Identity): + def __init__(self): + super().__init__() + + +class WithCopy(nn.Identity): + def copy(self): + return WithCopy() + + +@patch('nemo.collections.nlp.modules.common.megatron.megatron_init.initialize_model_parallel_for_nemo') +def test_init_parallel_ranks(mock_initialize_model_parallel) -> None: + from nemo.utils import AppState + + app_state = AppState() + + app_state.tensor_model_parallel_size = 2 + app_state.pipeline_model_parallel_size = 3 + app_state.global_rank = 1 + app_state.local_rank = 0 + + mock_parallel_config = MagicMock() + mock_parallel_config.tensor_model_parallel_size = 2 + mock_parallel_config.pipeline_model_parallel_size = 3 + mock_parallel_config.virtual_pipeline_model_parallel_size = 4 + mock_parallel_config.ub_tp_comm_overlap = False + mock_parallel_config.pipeline_model_parallel_split_rank = None + + _strategy_lib.init_parallel_ranks( + world_size=2, global_rank=1, local_rank=0, parallel_config=mock_parallel_config, seed=1234, fp8=False, + ) + mock_initialize_model_parallel.assert_called_once_with( + world_size=2, + global_rank=1, + local_rank=0, + tensor_model_parallel_size=2, + pipeline_model_parallel_size=3, + virtual_pipeline_model_parallel_size=4, + seed=1234, + pipeline_model_parallel_split_rank=None, + use_fp8=False, + init_mpi_proc_group=False, + ) + 
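For context, the two helpers exercised by these tests are meant to be driven in sequence by the strategy at setup time. The snippet below is an illustrative sketch only, not part of this diff: it assumes a two-process job with `torch.distributed` already initialized, and the parallel sizes are placeholders.

# Illustrative sketch, not part of the diff: how a strategy is expected to drive the two
# helpers under test. Assumes torch.distributed is already initialized for a 2-process job;
# the parallel sizes below are placeholders.
from megatron.core.model_parallel_config import ModelParallelConfig

from nemo.lightning import _strategy_lib

parallel_config = ModelParallelConfig(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=1,
    virtual_pipeline_model_parallel_size=None,
)

_strategy_lib.init_parallel_ranks(
    world_size=2, global_rank=0, local_rank=0, parallel_config=parallel_config, seed=1234, fp8=False
)
_strategy_lib.init_model_parallel()  # builds the Megatron parallel groups once torch.distributed is up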
+ +@patch('torch.distributed.is_initialized', return_value=True) +@patch('megatron.core.parallel_state') +def test_init_model_parallel(mock_mpu, *args): + from nemo.utils import AppState + + app_state = AppState() + app_state.model_parallel_size = 1 + app_state.tensor_model_parallel_size = 2 + app_state.pipeline_model_parallel_size = 1 + app_state.pipeline_model_parallel_split_rank = None + app_state.init_mpi_proc_group = False + app_state.tensor_model_parallel_rank = 2 + app_state.pipeline_model_parallel_rank = 0 + + _mpu_tp_2(mock_mpu) + _strategy_lib.init_model_parallel(nn.Identity()) + + mock_mpu.initialize_model_parallel.assert_called_once_with( + tensor_model_parallel_size=2, + pipeline_model_parallel_size=1, + virtual_pipeline_model_parallel_size=None, + pipeline_model_parallel_split_rank=None, + ) + + +# TODO @chcui uncomment after DataConfig is merged +# @patch('nemo.lightning._strategy_lib.DataLoader', return_value=MagicMock()) +# @patch('megatron.core.parallel_state') +# def test_process_dataloader(mock_mpu, mock_dataloader) -> None: +# mock_dataloader_instance = MagicMock() +# mock_dataloader_instance.dataset = [1, 2, 3] +# mock_dataloader_instance.num_workers = 4 +# mock_dataloader_instance.pin_memory = True +# mock_dataloader_instance.persistent_workers = False +# +# data_config = DataConfig(256) +# data_config.micro_batch_size = 2 +# data_config.global_batch_size = 6 +# data_config.rampup_batch_size = 3 +# +# mock_mpu.get_data_parallel_rank.return_value = 0 +# mock_mpu.get_data_parallel_world_size.return_value = 1 +# +# out = _strategy_lib.process_dataloader(mock_dataloader_instance, data_config) +# assert isinstance(out.batch_sampler, MagicMock) +# mock_dataloader.assert_called_once_with( +# mock_dataloader_instance.dataset, +# batch_sampler=ANY, +# num_workers=4, +# pin_memory=True, +# persistent_workers=False, +# collate_fn=ANY +# ) + + +# @patch('nemo.lightning._strategy_lib.init_parallel_ranks') +# @patch('megatron.core.parallel_state') +# def test_setup_megatron_parallel_with_trainer(mock_mpu, mock_init_parallel_ranks) -> None: +# _mpu_tp_2(mock_mpu) +# mock_trainer = MagicMock(spec=pl.Trainer) +# mock_trainer.strategy = MegatronStrategy( +# ModelParallelConfig(tensor_model_parallel_size=2), +# DataConfig(256), +# ) +# mock_trainer.world_size = 2 +# mock_trainer.local_rank = 0 +# mock_trainer.global_rank = 1 + +# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity()) +# mock_init_parallel_ranks.assert_called_once() +# assert isinstance(result, LightningMegatronParallel) +# assert len(result) == 1 + +# # Test with function +# assert len(_strategy_lib.setup_megatron_parallel(mock_trainer, lambda: nn.Identity())) == 1 + + +# @patch('nemo.lightning._strategy_lib.init_parallel_ranks') +# @patch('megatron.core.parallel_state') +# def test_setup_megatron_parallel_virtual_pipelining(mock_mpu, mock_init_parallel_ranks) -> None: +# vp_size = 4 +# _mpu_tp_2(mock_mpu) +# mock_mpu.get_pipeline_model_parallel_world_size.return_value = 4 +# mock_trainer = MagicMock(spec=pl.Trainer) +# mock_trainer.strategy = MegatronStrategy( +# ModelParallelConfig( +# virtual_pipeline_model_parallel_size=vp_size, +# tensor_model_parallel_size=2, +# ), +# DataConfig(256), +# ) +# mock_trainer.world_size = 8 +# mock_trainer.local_rank = 0 +# mock_trainer.global_rank = 1 + +# result = _strategy_lib.setup_megatron_parallel(mock_trainer, Identity()) +# mock_init_parallel_ranks.assert_called_once() +# assert len(result) == vp_size + +# # Test with function +# assert 
len(_strategy_lib.setup_megatron_parallel(mock_trainer, lambda: nn.Identity())) == vp_size + +# # Test with a module with a copy method +# assert len(_strategy_lib.setup_megatron_parallel(mock_trainer, WithCopy())) == vp_size + +# with pytest.raises( +# ValueError, +# match="Model does not have a copy method. Please implement this or " + +# "pass in a function that returns the model" +# ): +# _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity()) + + +# @patch('nemo.lightning._strategy_lib.init_parallel_ranks') +# @patch('megatron.core.parallel_state') +# def test_setup_megatron_parallel_with_fabric(mock_mpu, mock_init_parallel_ranks) -> None: +# _mpu_tp_2(mock_mpu) +# mock_trainer = MagicMock(spec=fl.Fabric) +# mock_trainer.strategy = FabricMegatronStrategy( +# ModelParallelConfig(tensor_model_parallel_size=2), +# DataConfig(256), +# ) +# mock_trainer.world_size = 2 +# mock_trainer.local_rank = 0 +# mock_trainer.global_rank = 1 + +# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity()) + +# mock_init_parallel_ranks.assert_called_once() +# assert isinstance(result, MegatronParallel) +# assert len(result) == 1 + + +# @patch('nemo.lightning._strategy_lib.init_parallel_ranks') +# @patch('megatron.core.parallel_state') +# def test_setup_megatron_parallel_with_strategy(mock_mpu, mock_init_parallel_ranks) -> None: +# _mpu_tp_2(mock_mpu) +# mock_trainer = MagicMock(spec=FabricMegatronStrategy) +# mock_trainer.configure_mock( +# parallelism=ModelParallelConfig(tensor_model_parallel_size=2), +# data_config=DataConfig(256), +# world_size=2, +# local_rank=0, +# global_rank=1 +# ) + +# result = _strategy_lib.setup_megatron_parallel(mock_trainer, nn.Identity()) + +# mock_init_parallel_ranks.assert_called_once() +# assert isinstance(result, MegatronParallel) +# assert len(result) == 1 + + +def _mpu_tp_2(mock_mpu) -> None: + mock_mpu.get_tensor_model_parallel_rank.return_value = 2 + mock_mpu.get_pipeline_model_parallel_rank.return_value = 0 + mock_mpu.get_pipeline_model_parallel_world_size.return_value = 1 + mock_mpu.get_pipeline_model_parallel_group.return_value = 0 + mock_mpu.get_tensor_model_parallel_group.return_value = 1
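
Taken together, the pieces added in this diff are meant to be wired into a Lightning ``Trainer`` roughly as follows. This is a minimal sketch rather than a supported entry point: ``MyGPTModule`` and ``MyDataModule`` are hypothetical stand-ins for a LightningModule/DataModule that follow the contract the strategy expects (a Megatron ``TransformerConfig`` exposed as ``.config``, ``configure_model``, ``training_step``, ...), and the parallel sizes are placeholders.

# Illustrative usage sketch, not part of the diff.
import lightning.pytorch as pl

from nemo.io.pl import MegatronCheckpointIO
from nemo.lightning.pytorch.callbacks import MegatronProgressBar
from nemo.lightning.pytorch.strategies import MegatronStrategy

trainer = pl.Trainer(
    devices=2,
    accelerator="gpu",
    max_steps=100,
    strategy=MegatronStrategy(
        tensor_model_parallel_size=2,
        pipeline_model_parallel_size=1,
        sequence_parallel=False,
    ),
    plugins=[MegatronCheckpointIO()],  # optional: the strategy falls back to this IO by default
    callbacks=[MegatronProgressBar()],
)

# trainer.fit(MyGPTModule(), datamodule=MyDataModule())  # hypothetical module/datamodule
# Checkpoints are written as Megatron distributed-checkpoint directories; a ".ckpt" path is
# mapped to a directory by ckpt_to_dir, e.g. "last.ckpt" -> "last/".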