clean up distributed env setup and support multi-device testing (#535)
Clean up `distributed_model_parallel_state` and support multi-device
testing through `torch.multiprocessing.spawn`.

---------

Signed-off-by: sichu <sichu@nvidia.com>
sichu2023 authored Feb 11, 2025
1 parent 9cec09f commit 1188119
Showing 5 changed files with 169 additions and 150 deletions.
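
For orientation, a minimal sketch of how a test might use the simplified context manager after this change. This is a hypothetical example, not part of the commit; it assumes the utilities are importable as `bionemo.testing.megatron_parallel_state_utils` and that a CUDA device is available.

```python
# Hypothetical usage sketch; test name and assertions are illustrative only.
import torch
import torch.distributed

from bionemo.testing import megatron_parallel_state_utils  # assumed import path


def test_single_device_all_reduce():
    # Enters a clean torch.distributed + Megatron parallel state (rank 0, world size 1)
    # and tears everything down on exit, regardless of the test outcome.
    with megatron_parallel_state_utils.distributed_model_parallel_state(seed=42):
        x = torch.ones(4, device="cuda")
        torch.distributed.all_reduce(x)  # no-op at world size 1
        assert torch.equal(x, torch.ones(4, device="cuda"))
```
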
@@ -42,7 +42,7 @@ def test_nemo2_conversion_equivalent_8m_bf16(tmp_path):
model_tag = "facebook/esm2_t6_8M_UR50D"
module = biobert_lightning_module(config=ESM2Config())
io.import_ckpt(module, f"hf://{model_tag}", tmp_path / "nemo_checkpoint")
- with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"):
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
assert_model_equivalence(tmp_path / "nemo_checkpoint", model_tag, precision="bf16")


@@ -179,7 +179,7 @@ def test_esm2_loss(dummy_protein_dataset, dummy_parquet_train_val_inputs):
def test_model_equivalence_with_huggingface_8m(precision):
model_tag = "facebook/esm2_t6_8M_UR50D"
ckpt_path = load("esm2/8m:2.0")
- with megatron_parallel_state_utils.distributed_model_parallel_state(precision=precision):
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
assert_model_equivalence(ckpt_path, model_tag, precision=precision)


@@ -195,7 +195,7 @@ def test_model_equivalence_with_huggingface_650m():
def test_model_equivalence_with_huggingface_650m_bf16():
model_tag = "facebook/esm2_t33_650M_UR50D"
ckpt_path = load("esm2/650m:2.0")
- with megatron_parallel_state_utils.distributed_model_parallel_state(precision="bf16"):
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
assert_model_equivalence(ckpt_path, model_tag, precision="bf16")


@@ -167,7 +167,7 @@ def forward(self, x):
# TODO rewrite unittest and potentially LightningPassthroughPredictionMixin
@pytest.mark.xfail(reason="MegatronStrategy no longer has '_get_loss_reduction' attribute")
def test_mixin_strategy_contract_get_loss_reduction():
- with megatron_parallel_state_utils.clean_parallel_state_context():
+ with megatron_parallel_state_utils.distributed_model_parallel_state():
strategy = nl.MegatronStrategy(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
@@ -32,29 +32,30 @@ def my_test():

import os
from contextlib import contextmanager
from typing import Any, Iterator, Optional, Sequence
from typing import Any, Optional, Sequence
from unittest import mock
from unittest.mock import MagicMock

import lightning.pytorch as pl
import megatron.core.num_microbatches_calculator
import torch
import torch.distributed
import torch.multiprocessing.spawn
from megatron.core import parallel_state
from megatron.core.tensor_parallel import random as tp_random
from nemo import lightning as nl
from nemo.utils import logging
from pytest import MonkeyPatch
from torch.testing._internal.distributed.fake_pg import FakeStore

from bionemo.core.utils.dtypes import PrecisionTypes


__all__: Sequence[str] = (
"clean_parallel_state_context",
"clean_up_distributed_and_parallel_states",
"distributed_model_parallel_state",
"mock_distributed_parallel_state",
)

DEFAULT_MASTER_ADDR = "localhost"
DEFAULT_MASTER_PORT = "29500"
DEFAULT_NCCL_TIMEOUT = "30"  # in seconds


def _reset_microbatch_calculator():
"""Resets _GLOBAL_NUM_MICROBATCHES_CALCULATOR in megatron which is used in NeMo to initilised model parallel in
@@ -63,139 +64,66 @@ def _reset_microbatch_calculator():
megatron.core.num_microbatches_calculator._GLOBAL_NUM_MICROBATCHES_CALCULATOR = None


def _dummy() -> None:
return


def _teardown_apex_megatron_cuda():
"""Cleans GPU allocation and model and data parallel settings after usage of a model:
- sets the global variables related to model and data parallelism to None in Apex and Megatron:.
- releases all unoccupied cached GPU memory currently held by the caching CUDA allocator, see torch.cuda.empty_cache
""" # noqa: D205, D415
torch.cuda.empty_cache()
def clean_up_distributed_and_parallel_states():
"""Clean up parallel states, torch.distributed and torch cuda cache."""
_reset_microbatch_calculator()
parallel_state.destroy_model_parallel()


def _initialize_distributed_parallel_state(
devices: int = 1,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
pipeline_model_parallel_split_rank: int = 0,
context_parallel_size: int = 1,
interactive: bool = False,
precision: PrecisionTypes = "fp32",
) -> pl.Trainer | None:
trainer = None
# initialize pytorch DDP
# if not interactive and not torch.distributed.is_initialized():
if not torch.distributed.is_initialized():
logging.info("pytorch DDP is not initialized. Initializing with pytorch-lightning...")
trainer = pl.Trainer(
devices=devices,
strategy="ddp" if not interactive else "auto",
num_nodes=1,
# plugins=nl.MegatronMixedPrecision(
# precision=precision,
# params_dtype=get_autocast_dtype(precision),
# pipeline_dtype=get_autocast_dtype(precision),
# autocast_enabled=False,
# ),
)

if trainer.strategy.launcher is not None:
trainer.strategy.launcher.launch(_dummy, trainer=trainer)
trainer.strategy.setup_environment()

if not interactive and parallel_state.is_unitialized():
logging.info("Megatron DDP is not initialized. Initializing...")
parallel_state.initialize_model_parallel(
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
context_parallel_size=context_parallel_size,
)

return trainer
parallel_state.destroy_model_parallel() # destroy parallel state before distributed
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
torch.cuda.empty_cache()


@contextmanager
def clean_parallel_state_context() -> Iterator[None]:
"""Puts you into a clean parallel state, and again tears it down at the end."""
try:
_teardown_apex_megatron_cuda()
yield
except Exception as e:
# TODO (@skothenhill) verify this is a problem and that this is a solution. Had issues with keyboard interrupts being ignored inside context manager.
raise Exception from e
finally:
_teardown_apex_megatron_cuda()
def distributed_model_parallel_state(
seed: int = 42,
rank: int = 0,
world_size: int = 1,
backend: str = "nccl",
**initialize_model_parallel_kwargs,
):
"""Context manager for torch distributed and parallel state testing.
Args:
seed (int): random seed to be passed into tensor_parallel.random (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py). default to 42.
rank (int): global rank of the current cuda device. default to 0.
world_size (int): world size or number of devices. default to 1.
backend (str): backend to torch.distributed.init_process_group. default to 'nccl'.
**initialize_model_parallel_kwargs: kwargs to be passed into initialize_model_parallel (https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py).
"""
with MonkeyPatch.context() as context:
initial_states = None
try:
clean_up_distributed_and_parallel_states()

# distributed and parallel state set up
if not os.environ.get("MASTER_ADDR", None):
context.setenv("MASTER_ADDR", DEFAULT_MASTER_ADDR)
if not os.environ.get("MASTER_PORT", None):
context.setenv("MASTER_PORT", DEFAULT_MASTER_PORT)
if not os.environ.get("NCCL_TIMEOUT", None):
context.setenv("NCCL_TIMEOUT", DEFAULT_NCCL_TIMEOUT)
context.setenv("RANK", str(rank))

torch.distributed.init_process_group(backend=backend, world_size=world_size)
parallel_state.initialize_model_parallel(**initialize_model_parallel_kwargs)

# tensor parallel random seed set up
# do not call torch.cuda.manual_seed after so!
if tp_random.get_cuda_rng_tracker().is_initialized():
initial_states = tp_random.get_cuda_rng_tracker().get_states()
if seed is not None:
tp_random.model_parallel_cuda_manual_seed(seed)

@contextmanager
def distributed_model_parallel_state(
seed: Optional[int] = 42,
devices: int = 1,
tensor_model_parallel_size: int = 1,
pipeline_model_parallel_size: int = 1,
pipeline_model_parallel_split_rank: int = 0,
context_parallel_size: int = 1,
interactive: bool = False,
precision: PrecisionTypes = "fp32",
) -> Iterator[None]:
"""Context manager for handling creating and cleaning up distributed model parallel state for tests.
Use like:
with distributed_model_parallel_state():
# your test code here
# After the block your state is cleaned up.
""" # noqa: D205
initial_states: Optional[Any] = None
trainer: pl.Trainer | None = None

try:
_teardown_apex_megatron_cuda()
trainer = _initialize_distributed_parallel_state(
devices=devices,
tensor_model_parallel_size=tensor_model_parallel_size,
pipeline_model_parallel_size=pipeline_model_parallel_size,
pipeline_model_parallel_split_rank=pipeline_model_parallel_split_rank,
context_parallel_size=context_parallel_size,
interactive=interactive,
precision=precision,
)
# Our goal is to set required state on entry, and then restore current state on exit for the RNGs.
# there are two possibilities that are handled below:
# 1. If the RNG state is not initialized, we need to set it up and then
# unset it on exit to restore the current state. We track that this is the case when `initial_states` is `None`.
# 2. If the RNG state is initialized, we need to track this state and reset it on exit to be what it was on entry.
# We track that this is the case when `initial_states` is not `None`.
if tp_random.get_cuda_rng_tracker().is_initialized():
initial_states = tp_random.get_cuda_rng_tracker().get_states()
if seed is not None:
# Set the seed if provided, this case is valid whether or not the RNG had state previously.
# on exit the RNG state will be restored to what it was on entry.
tp_random.model_parallel_cuda_manual_seed(seed)
else:
# This is the case where the RNG state is not initialized and no seed was provided.
# We need to raise an error in this case, as we cannot restore the RNG state on exit and we need a seed
# to initialize the RNG state to. This only happens if the user overrides the default seed and sets it
# to None, and additionally if the RNG state was not initialized externally, as there is a default seed of 42.
if initial_states is None:
raise ValueError(
"You must provide a seed if the initial parallel state is unset. "
"Either provide a seed or leave the default seed (rather setting to None) "
"or initialize the RNG state externally."
)
yield
finally:
if initial_states is not None:
tp_random.get_cuda_rng_tracker().set_states(initial_states)
else:
# Reset to the unset state
tp_random.get_cuda_rng_tracker().reset()
_teardown_apex_megatron_cuda()
if trainer is not None:
nl.teardown(trainer)
yield
finally:
# restore/unset tensor parallel random seed
if initial_states is not None:
tp_random.get_cuda_rng_tracker().set_states(initial_states)
else:
# Reset to the unset state
tp_random.get_cuda_rng_tracker().reset()

clean_up_distributed_and_parallel_states()


@contextmanager
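
The commit message points to multi-device testing via `torch.multiprocessing.spawn`. Below is a hedged sketch of how a two-device test might drive the updated context manager; it is not taken from this commit, and it assumes two CUDA devices, the import path shown, and illustrative helper/test names.

```python
# Hypothetical multi-device test sketch; assumes 2 GPUs and the import path below.
import torch
import torch.distributed
import torch.multiprocessing

from bionemo.testing import megatron_parallel_state_utils  # assumed import path


def _rank_worker(rank: int, world_size: int) -> None:
    # Each spawned process pins its own device, then joins the NCCL process group
    # through the context manager's rank/world_size arguments.
    torch.cuda.set_device(rank)
    with megatron_parallel_state_utils.distributed_model_parallel_state(
        rank=rank,
        world_size=world_size,
        tensor_model_parallel_size=world_size,  # forwarded to initialize_model_parallel
    ):
        x = torch.full((1,), float(rank + 1), device="cuda")
        torch.distributed.all_reduce(x)  # sums 1 + 2 across both ranks
        assert x.item() == sum(range(1, world_size + 1))


def test_two_device_all_reduce():
    world_size = 2
    torch.multiprocessing.spawn(_rank_worker, args=(world_size,), nprocs=world_size, join=True)
```
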
