[PyTorch] Add context parallel support for packed dataset in THD format #9540

Status: Closed · wants to merge 7 commits
51 changes: 35 additions & 16 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -21,6 +21,7 @@
 from functools import cache, partial
 from importlib.metadata import version
 from typing import Any, Dict, Iterator, List, Optional, Union
+import transformer_engine_extensions as tex
 
 import torch
 from omegaconf import OmegaConf
@@ -1153,22 +1154,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
         cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
-            for key, val in batch.items():
-                if val is not None:
-                    seq_dim = 1 if key != 'attention_mask' else 2
-                    val = val.view(
-                        *val.shape[0:seq_dim],
-                        2 * cp_size,
-                        val.shape[seq_dim] // (2 * cp_size),
-                        *val.shape[(seq_dim + 1) :],
-                    )
-                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
-                        non_blocking=True
-                    )
-                    val = val.index_select(seq_dim, index)
-                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
-                    batch[key] = val
+            # check if the batch is not in THD format
+            if 'cu_seqlens' not in batch:
+                for key, val in batch.items():
+                    if val is not None:
+                        seq_dim = 1 if key != 'attention_mask' else 2
+                        val = val.view(
+                            *val.shape[0:seq_dim],
+                            2 * cp_size,
+                            val.shape[seq_dim] // (2 * cp_size),
+                            *val.shape[(seq_dim + 1) :],
+                        )
+                        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
+                            non_blocking=True
+                        )
+                        val = val.index_select(seq_dim, index)
+                        val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                        batch[key] = val
         batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
 
         return batch
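Note (for context, not part of the diff): in the padded, non-THD path above, each sequence is split along the sequence dimension into 2 * cp_size chunks, and rank r keeps chunks r and 2 * cp_size - 1 - r, so every rank gets one early (cheap) and one late (expensive) chunk of the causal-attention workload. A minimal standalone sketch with toy sizes chosen only for illustration:

import torch

# Toy illustration of the padded (non-THD) context-parallel slicing above:
# the sequence axis is cut into 2 * cp_size chunks and rank r keeps
# chunks r and 2 * cp_size - 1 - r, balancing causal-attention work.
cp_size = 2                                    # assumed CP world size for this example
seq_len = 8                                    # must be divisible by 2 * cp_size
tokens = torch.arange(seq_len).unsqueeze(0)    # shape [batch=1, seq_len]

for cp_rank in range(cp_size):
    chunks = tokens.view(1, 2 * cp_size, seq_len // (2 * cp_size))
    index = torch.tensor([cp_rank, 2 * cp_size - cp_rank - 1])
    local = chunks.index_select(1, index).view(1, -1)
    print(cp_rank, local.tolist())
# prints: 0 [[0, 1, 6, 7]]
#         1 [[2, 3, 4, 5]]

This divisibility requirement is why the data-preparation changes below pad each individual sequence in a packed sample to a multiple of 2 * cp_size.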
@@ -1239,6 +1241,23 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
                     )
                     raise e
 
+                # get packed sequences for this context parallel rank
+                cp_size = parallel_state.get_context_parallel_world_size()
+                if cp_size > 1:
+                    cp_rank = parallel_state.get_context_parallel_rank()
+                    for key in required_keys:
+                        val = batch[key]
+                        if key not in {"cu_seqlens"}:
+                            index = tex.thd_get_partitioned_indices(cu_seqlens, val.size(1), cp_size, cp_rank)
+                            val = val.index_select(1, index)
+                            batch[key] = val
+                    cu_seqlens = cu_seqlens // cp_size
+                forward_args = {
+                    'input_ids': batch['tokens'],
+                    'position_ids': batch['position_ids'],
[Inline review thread on 'position_ids']

Collaborator: Does position_ids mean the token id within the packed sequence? How is this argument used in the training fwd and bwd passes?

Contributor (author): position_ids holds the position of each token within its own sequence (e.g. [0, 1, 2, ..., seq_len-1]). A packed sequence carries a list of position_ids, since it is composed of many individual sequences, so I'm not sure whether that is what you mean by token_id. It is used the same way as input_ids in the training fwd and bwd passes. (See the standalone sketch after this hunk.)
+                    'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
+                    'labels': batch['labels'] if 'labels' in batch else None
+                }
                 forward_args['packed_seq_params'] = PackedSeqParams(
                     cu_seqlens_q=cu_seqlens,
                     cu_seqlens_kv=cu_seqlens,
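To make the THD path and the position_ids discussion above concrete, here is a small, self-contained sketch in plain PyTorch. The helper thd_partitioned_indices_reference is an invented stand-in written from the load-balanced splitting scheme described in TransformerEngine PR 641 (each sequence cut into 2 * cp_size chunks, rank r keeping chunks r and 2 * cp_size - 1 - r); the kernel actually used by this PR is tex.thd_get_partitioned_indices, and the sequence lengths here are made up for the example.

import torch

def packed_position_ids(cu_seqlens):
    """position_ids restart at 0 for every sequence packed into the THD sample."""
    starts, ends = cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()
    return torch.cat([torch.arange(e - s) for s, e in zip(starts, ends)])

def thd_partitioned_indices_reference(cu_seqlens, cp_size, cp_rank):
    """Illustrative stand-in for tex.thd_get_partitioned_indices.

    Every sequence (whose length must be divisible by 2 * cp_size) is split into
    2 * cp_size chunks; rank r keeps chunks r and 2 * cp_size - 1 - r.
    """
    keep = []
    for s, e in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        chunk = (e - s) // (2 * cp_size)
        for c in (cp_rank, 2 * cp_size - 1 - cp_rank):
            keep.append(torch.arange(s + c * chunk, s + (c + 1) * chunk))
    return torch.cat(keep)

# two sequences of lengths 8 and 4 packed into one THD sample; cp_size = 2
cu_seqlens = torch.tensor([0, 8, 12])
print(packed_position_ids(cu_seqlens).tolist())                      # [0,1,2,3,4,5,6,7, 0,1,2,3]
print(thd_partitioned_indices_reference(cu_seqlens, 2, 0).tolist())  # rank 0 keeps [0,1,6,7, 8,11]
print((cu_seqlens // 2).tolist())                                    # per-rank cu_seqlens: [0, 4, 6]

Dividing cu_seqlens by cp_size, as the diff does, matches the fact that each rank holds exactly 1/cp_size of every packed sequence.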
40 changes: 36 additions & 4 deletions scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -17,6 +17,7 @@
 from typing import TYPE_CHECKING, Tuple
 
 import numpy as np
+import torch

[Code scanning / CodeQL notice on the added line: Unused import — import of 'torch' is not used.]

 from nemo.collections.nlp.data.language_modeling.megatron.gpt_sft_dataset import GPTSFTDataset
 from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer
@@ -65,7 +66,6 @@
 the unpacked case.
 """
 
-
 def tokenize_dataset(cfg: 'DictConfig'):
     """
     Tokenizes a dataset using the same configuration file as finetuning with GPTSFTDataset.
@@ -83,12 +83,26 @@ def tokenize_dataset(cfg: 'DictConfig'):
     # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
     # are identical to normal SFT training
     data_cfg = cfg.model.data.train_ds
+    pad_seq_length_to_mult=16
+    cp_size = cfg.model.context_parallel_size
+
+    # if context parallel is used, each individual data length in one packed dataset sample
+    # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
+    if cp_size > 1:
+        # Function to calculate the LCM of two numbers
+        def lcm(x, y):
+            larger = max(x, y)
+            for i in range(larger, x * y + 1):
+                if i % x == 0 and i % y == 0:
+                    return i
+        pad_seq_length_to_mult = lcm(pad_seq_length_to_mult, cp_size * 2)
 
     dataset = GPTSFTDataset(
         file_path=data_cfg.file_names[0],
         tokenizer=get_nmt_tokenizer(library="sentencepiece", tokenizer_model=cfg.tokenizer_path),
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
-        pad_seq_length_to_mult=16,  # adds padding in collate_fn so this value is irrelevant here
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get('add_bos', False),
         add_eos=data_cfg.get('add_eos', True),
         add_sep=data_cfg.get('add_sep', False),
@@ -109,8 +123,26 @@ def tokenize_dataset(cfg: 'DictConfig'):
         special_tokens=data_cfg.get('chat_prompt_tokens', None),
         is_test=True,
     )
 
-    return np.array([dataset[i] for i in range(len(dataset))])
+    max_seq_length = dataset.max_seq_length
+    pad_id = dataset.tokenizer.eos_id
+    pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
+    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    if cp_size > 1:
+        def pre_pad_dataset(data, max_length, pad_id):
+            '''
+            pad each individual data point to the length of max_length
+            '''
+            for key,val in data.items():
+                if key in {'input_ids', 'context_ids'}:
+                    val = val + [pad_id] * (max_length - len(val) + 1)  # add 1 for cp
+                    data[key]=val
+            return
+        ceil_to_nearest = lambda n, m : (n + m - 1) // m * m
+        for data in dataset:
+            max_length = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
+            assert max_length <= max_seq_length
+            pre_pad_dataset(data, max_length, pad_id)
+    return dataset
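The padding rule introduced above can be summarized as: round each example up to a multiple of lcm(16, 2 * cp_size), capped at max_seq_length, so that every sequence inside a packed sample can later be split into 2 * cp_size context-parallel chunks. A condensed sketch of that arithmetic, assuming Python 3.9+ for math.lcm in place of the script's brute-force lcm helper, with the sample lengths invented for the example:

import math

def cp_padding_multiple(cp_size, base_mult=16):
    # equivalent to the lcm() computed above: pad to a multiple of lcm(16, 2 * cp_size)
    return math.lcm(base_mult, 2 * cp_size) if cp_size > 1 else base_mult

def padded_length(seq_len, cp_size, max_seq_length):
    mult = cp_padding_multiple(cp_size)
    ceil_to_nearest = (seq_len + mult - 1) // mult * mult
    # the script additionally appends one extra pad token per example ("add 1 for cp"),
    # which is omitted here for brevity
    return min(max_seq_length, ceil_to_nearest)

# e.g. cp_size = 3 -> every example is padded to a multiple of lcm(16, 6) = 48
for n in (10, 17, 100):
    print(n, "->", padded_length(n, cp_size=3, max_seq_length=2048))
# 10 -> 48, 17 -> 48, 100 -> 144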


@dataclass