[NeMo-UX] Adding GPTModel & MockDataModule #9011

Merged: 16 commits, Apr 27, 2024
Changes from 8 commits
1 change: 1 addition & 0 deletions nemo/collections/nlp/parts/mixins/nlp_adapter_mixins.py
@@ -507,6 +507,7 @@ def merge_inference_cfg(cls, path: str, cfg: DictConfig) -> DictConfig:
"""

peft_cfg = cls.restore_from(path, return_config=True)

if hasattr(peft_cfg, 'peft') and peft_cfg.peft.peft_scheme not in [None, 'none']:
# before PEFT migrates to distributed ckpt, eval must use same TP/PP as training
for p in ['tensor_model_parallel_size', 'pipeline_model_parallel_size']:
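The loop body is truncated in this hunk. As a rough illustration of the constraint it enforces (not the PR's actual implementation), forcing evaluation to reuse the training-time parallelism settings could look like the following sketch, assuming plain OmegaConf configs with hypothetical values:

from omegaconf import OmegaConf

# Hypothetical configs for illustration: the restored training-time PEFT config
# and the inference config being merged.
peft_cfg = OmegaConf.create({"tensor_model_parallel_size": 2, "pipeline_model_parallel_size": 1})
cfg = OmegaConf.create({"tensor_model_parallel_size": 1, "pipeline_model_parallel_size": 1})

# Until PEFT checkpoints move to distributed checkpointing, evaluation must use
# the same TP/PP sizes as training, so the inference config is overridden.
for p in ["tensor_model_parallel_size", "pipeline_model_parallel_size"]:
    cfg[p] = peft_cfg[p]

print(OmegaConf.to_yaml(cfg))  # tensor_model_parallel_size is now 2
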
1 change: 1 addition & 0 deletions nemo/io/pl.py
@@ -91,6 +91,7 @@ def load_checkpoint(
# return pl_load(path, map_location=map_location)

checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=str(path))

checkpoint = _fix_tensors_device(checkpoint)

return checkpoint
24 changes: 24 additions & 0 deletions nemo/lightning/__init__.py
@@ -0,0 +1,24 @@
from typing import Union

from lightning.pytorch import plugins as _pl_plugins
from lightning_fabric.plugins.environments import slurm

from nemo.lightning.base import get_vocab_size, teardown
from nemo.lightning.pytorch.plugins import MegatronDataSampler
from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler
from nemo.lightning.pytorch.strategies import MegatronStrategy


# We monkey-patch because NVIDIA uses a naming convention for interactive SLURM jobs
def _is_slurm_interactive_mode():
job_name = slurm.SLURMEnvironment.job_name()
return job_name is None or job_name.endswith("bash") or job_name.endswith("interactive")


slurm._is_slurm_interactive_mode = _is_slurm_interactive_mode # noqa: SLF001


_pl_plugins._PLUGIN_INPUT = Union[_pl_plugins._PLUGIN_INPUT, _data_sampler.DataSampler] # noqa: SLF001


__all__ = ["MegatronStrategy", "MegatronDataSampler", "get_vocab_size", "teardown"]
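The patch above relaxes Lightning's interactive-session detection so that job names merely ending in "bash" or "interactive" (NVIDIA's convention) are treated as interactive allocations rather than managed batch jobs. A quick sketch of how the patched check classifies a few job names (the names are made up for illustration):

for name in (None, "bash", "srun_interactive", "pretrain_gpt_7b"):
    is_interactive = name is None or name.endswith("bash") or name.endswith("interactive")
    print(f"{name!r}: interactive={is_interactive}")
# None, 'bash', and 'srun_interactive' count as interactive sessions;
# 'pretrain_gpt_7b' is treated as a regular batch job.
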
51 changes: 51 additions & 0 deletions nemo/lightning/base.py
@@ -0,0 +1,51 @@
import gc
import os
from pathlib import Path
from typing import Optional

import torch
import torch.distributed
from pytorch_lightning import Trainer
from torch import nn


NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE = "NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE"
DEFAULT_NEMO_CACHE_HOME = Path.home() / ".cache" / "nemo"
NEMO_CACHE_HOME = Path(os.getenv("NEMO_HOME", DEFAULT_NEMO_CACHE_HOME))
DEFAULT_NEMO_DATASETS_CACHE = NEMO_CACHE_HOME / "datasets"
NEMO_DATASETS_CACHE = Path(os.getenv("NEMO_DATASETS_CACHE", DEFAULT_NEMO_DATASETS_CACHE))


def get_vocab_size(config, vocab_size: int, make_vocab_size_divisible_by: int = 128,) -> int:
from nemo.utils import logging

after = vocab_size
multiple = make_vocab_size_divisible_by * config.tensor_model_parallel_size
while (after % multiple) != 0:
after += 1
logging.info(
f"Padded vocab_size: {after}, original vocab_size: {vocab_size}, dummy tokens:" f" {after - vocab_size}."
)

return after


def teardown(trainer: Trainer, model: Optional[nn.Module] = None) -> None:
# Destroy torch distributed
if torch.distributed.is_initialized():
from megatron.core import mpu

mpu.destroy_model_parallel()
torch.distributed.destroy_process_group()

trainer._teardown() # noqa: SLF001
if model is not None:
for obj in gc.get_objects():
if torch.is_tensor(obj) and obj.is_cuda:
del obj

gc.collect()
torch.cuda.empty_cache()


__all__ = ["get_vocab_size", "teardown"]
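As a worked example of the vocabulary-padding logic above (a minimal sketch assuming NeMo is importable; TPConfig is a hypothetical stand-in exposing only the attribute get_vocab_size reads):

from dataclasses import dataclass

from nemo.lightning.base import get_vocab_size


@dataclass
class TPConfig:  # hypothetical minimal config for illustration
    tensor_model_parallel_size: int = 1


# GPT-2's 50257-token vocabulary is padded up to the next multiple of
# 128 * tensor_model_parallel_size = 128, i.e. 50304 (47 dummy tokens).
padded = get_vocab_size(TPConfig(), vocab_size=50257, make_vocab_size_divisible_by=128)
assert padded == 50304
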
281 changes: 281 additions & 0 deletions nemo/lightning/data.py
@@ -0,0 +1,281 @@
import abc
import logging
import os
from itertools import chain
from typing import List, Literal, Optional

import torch
from torch.utils.data import DataLoader, Dataset


def create_dataloader(
dataset: "Dataset", drop_last: bool = True, pad_samples_to_global_batch_size=False, **kwargs
) -> DataLoader:
output = DataLoader(dataset, collate_fn=dataset.collate_fn, **kwargs)

output._drop_last = drop_last # noqa: SLF001
output._pad_samples_to_global_batch_size = pad_samples_to_global_batch_size # noqa: SLF001

return output


def setup_microbatch_calculator(
global_rank: int, micro_batch_size: int, global_batch_size: int, rampup_batch_size: Optional[List[int]] = None,
) -> None:
"""
Initializes the data for distributed training by setting up the microbatch calculator
based on the provided global rank and data configuration.

This function checks if the microbatch calculator has already been initialized. If it has,
the function validates that the current configuration matches the initialized settings. If the
calculator has not been initialized, it sets up a new one with the provided configuration.

Args:
global_rank (int): The global rank of the current process.
        micro_batch_size (int): The micro batch size.
        global_batch_size (int): The global batch size.
        rampup_batch_size (Optional[List[int]]): Optional batch-size ramp-up schedule.

Raises
------
Exception: If the microbatch calculator has already been initialized with different settings.

"""
    from nemo.lightning.base import NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE

from nemo.utils import AppState

app_state = AppState()

if os.environ.get(NEMO_MEGATRON_MODEL_PARALLEL_APPSTATE_OVERRIDE, "false").lower() == "true":
init_global_rank = app_state.global_rank
else:
init_global_rank = global_rank

from apex.transformer.microbatches import ConstantNumMicroBatches
from apex.transformer.pipeline_parallel.utils import (
_GLOBAL_NUM_MICROBATCHES_CALCULATOR,
setup_microbatch_calculator,
)

if _GLOBAL_NUM_MICROBATCHES_CALCULATOR is None:
setup_microbatch_calculator(
rank=init_global_rank,
global_batch_size=global_batch_size,
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatches):
assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.current_global_batch_size == global_batch_size
assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.micro_batch_size == micro_batch_size
assert _GLOBAL_NUM_MICROBATCHES_CALCULATOR.num_micro_batches == global_batch_size // (
micro_batch_size * app_state.data_parallel_size
)
else:
raise Exception("Microbatch calculator already initialized.")


def add_megatron_sampler(
dataloader: DataLoader,
micro_batch_size: int,
global_batch_size: int,
rampup_batch_size: Optional[List[int]] = None,
consumed_samples: int = 0,
dataloader_type: Literal["single", "cyclic"] = "single",
# data_sharding: bool = False
) -> DataLoader:
from megatron.core import mpu

if dataloader_type == 'single':
batch_sampler = MegatronPretrainingSampler(
total_samples=len(dataloader.dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
global_batch_size=global_batch_size,
rampup_batch_size=rampup_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
drop_last=getattr(dataloader, "_drop_last", False),
pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
)
elif dataloader_type == 'cyclic':
batch_sampler = MegatronPretrainingRandomSampler(
dataloader.dataset,
total_samples=len(dataloader.dataset),
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=mpu.get_data_parallel_rank(),
data_parallel_size=mpu.get_data_parallel_world_size(),
pad_samples_to_global_batch_size=getattr(dataloader, "_pad_samples_to_global_batch_size", False),
# data_sharding=data_sharding
)
else:
raise Exception(f'{dataloader_type} dataloader type is not supported.')

return DataLoader(
dataloader.dataset,
batch_sampler=batch_sampler,
num_workers=dataloader.num_workers,
pin_memory=dataloader.pin_memory,
persistent_workers=dataloader.persistent_workers,
collate_fn=dataloader.collate_fn,
)


# TODO: Replace this with megatron.core.data.data_samplers after we upgrade
class BaseMegatronSampler:
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
rampup_batch_size: Optional[list] = None,
pad_samples_to_global_batch_size: Optional[bool] = False,
) -> None:
# Sanity checks.
if total_samples <= 0:
raise RuntimeError(f"no sample to consume: {total_samples}")
if consumed_samples >= total_samples:
raise RuntimeError(f"no samples left to consume: {consumed_samples}, {total_samples}")
if micro_batch_size <= 0:
raise RuntimeError(f"micro_batch_size size must be greater than 0, but {micro_batch_size}")
if data_parallel_size <= 0:
raise RuntimeError(f"data parallel size must be greater than 0, but {data_parallel_size}")
if data_parallel_rank >= data_parallel_size:
raise RuntimeError(
f"data_parallel_rank should be smaller than data size, but {data_parallel_rank} >= {data_parallel_size}"
)
if global_batch_size is not None and rampup_batch_size is None:
if global_batch_size % (micro_batch_size * data_parallel_size) != 0:
raise RuntimeError(
f"`global_batch_size` ({global_batch_size}) is not divisible by "
f"`micro_batch_size ({micro_batch_size}) x data_parallel_size "
f"({data_parallel_size})`"
)
if pad_samples_to_global_batch_size and global_batch_size is None:
raise RuntimeError(
"`pad_samples_to_global_batch_size` can be `True` only when "
"`global_batch_size` is set to an integer value"
)

# Keep a copy of input params for later use.
self.total_samples = total_samples
self.consumed_samples = consumed_samples
self.micro_batch_size = micro_batch_size
self.data_parallel_rank = data_parallel_rank
self.micro_batch_times_data_parallel_size = self.micro_batch_size * data_parallel_size
self.drop_last = drop_last
self.global_batch_size = global_batch_size
self.pad_samples_to_global_batch_size = pad_samples_to_global_batch_size

logging.info(
f"Instantiating MegatronPretrainingSampler with total_samples: {total_samples} and"
f" consumed_samples: {consumed_samples}"
)

def __len__(self):
num_available_samples: int = self.total_samples - self.consumed_samples
if self.global_batch_size is not None:
if self.drop_last:
return num_available_samples // self.global_batch_size
else:
return (num_available_samples + self.global_batch_size - 1) // self.global_batch_size
else:
return (num_available_samples - 1) // self.micro_batch_times_data_parallel_size + 1

@abc.abstractmethod
def __iter__(self):
...


class MegatronPretrainingSampler(BaseMegatronSampler):
def get_start_end_idx(self):
start_idx = self.data_parallel_rank * self.micro_batch_size
end_idx = start_idx + self.micro_batch_size
return start_idx, end_idx

def __iter__(self):
batch = []
        # The last partial batch will be dropped unless drop_last is set to False
indices = range(self.consumed_samples, self.total_samples)
if (not self.drop_last) and self.pad_samples_to_global_batch_size:
pad_samples_num = -len(indices) % self.global_batch_size
pad_indices = range(-1, -pad_samples_num - 1, -1)
indices = chain(indices, pad_indices)

for idx in indices:
batch.append(idx)
if len(batch) == self.micro_batch_times_data_parallel_size:
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]
batch = []

        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
assert (
not self.pad_samples_to_global_batch_size
), "with pad_samples_to_global_batch_size all batches should be complete"
start_idx, end_idx = self.get_start_end_idx()
yield batch[start_idx:end_idx]


class MegatronPretrainingRandomSampler(BaseMegatronSampler):
def __init__(
self,
total_samples: int,
consumed_samples: int,
micro_batch_size: int,
data_parallel_rank: int,
data_parallel_size: int,
drop_last: bool = True,
global_batch_size: Optional[int] = None,
pad_samples_to_global_batch_size: Optional[bool] = False,
) -> None:
super().__init__(
total_samples=total_samples,
consumed_samples=consumed_samples,
micro_batch_size=micro_batch_size,
data_parallel_rank=data_parallel_rank,
data_parallel_size=data_parallel_size,
drop_last=drop_last,
global_batch_size=global_batch_size,
pad_samples_to_global_batch_size=pad_samples_to_global_batch_size,
)
assert (
not pad_samples_to_global_batch_size
), "`MegatronPretrainingRandomSampler` does not support sample padding"
self.last_batch_size = self.total_samples % self.micro_batch_times_data_parallel_size

def __iter__(self):
active_total_samples = self.total_samples - self.last_batch_size
self.epoch = self.consumed_samples // active_total_samples
current_epoch_samples = self.consumed_samples % active_total_samples
assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

# data sharding and random sampling
bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
bucket_offset = current_epoch_samples // self.data_parallel_size
start_idx = self.data_parallel_rank * bucket_size

g = torch.Generator()
g.manual_seed(self.epoch)
random_idx = torch.randperm(bucket_size, generator=g).tolist()
idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

batch = []
        # The last batch will be dropped if it is incomplete.
for idx in idx_range:
batch.append(idx)
if len(batch) == self.micro_batch_size:
self.consumed_samples += self.micro_batch_times_data_parallel_size
yield batch
batch = []

        # Yield the last partial batch if drop_last is not set
if len(batch) > 0 and not self.drop_last:
yield batch
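As a usage sketch for the single-rank case (constructing MegatronPretrainingSampler directly with a toy dataset, so no Megatron parallel state needs to be initialized; the dataset and batch sizes are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

from nemo.lightning.data import MegatronPretrainingSampler

dataset = TensorDataset(torch.arange(32).unsqueeze(1))
sampler = MegatronPretrainingSampler(
    total_samples=len(dataset),
    consumed_samples=0,
    micro_batch_size=4,
    global_batch_size=8,
    data_parallel_rank=0,
    data_parallel_size=1,
)
loader = DataLoader(dataset, batch_sampler=sampler)

# len(sampler) counts global batches (32 // 8 == 4), while iteration yields one
# per-rank micro-batch of 4 consecutive indices at a time.
for batch in loader:
    print(batch[0].shape)  # torch.Size([4, 1])

In the trainer path, add_megatron_sampler performs the same wiring on an existing DataLoader, reading the data-parallel rank and world size from megatron.core's mpu.
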
2 changes: 1 addition & 1 deletion nemo/lightning/pytorch/callbacks/__init__.py
@@ -1,3 +1,3 @@
-from nemo_ext.lightning.pytorch.callbacks.progress import MegatronProgressBar
+from nemo.lightning.pytorch.callbacks.progress import MegatronProgressBar

__all__ = ["MegatronProgressBar"]
3 changes: 3 additions & 0 deletions nemo/lightning/pytorch/plugins/__init__.py
@@ -0,0 +1,3 @@
from nemo.lightning.pytorch.plugins.data_sampler import MegatronDataSampler

__all__ = ["MegatronDataSampler"]