diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 1cdee42f580e..ca8cbdbf050a 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -23,6 +23,7 @@ from typing import Any, Dict, Iterator, List, Optional, Union
 
 import torch
+import transformer_engine_extensions as tex
 from omegaconf import OmegaConf
 from omegaconf.dictconfig import DictConfig
 from pkg_resources import packaging
@@ -1166,22 +1167,23 @@ def get_batch_on_this_context_parallel_rank(self, batch):
         cp_size = parallel_state.get_context_parallel_world_size()
         if cp_size > 1:
             cp_rank = parallel_state.get_context_parallel_rank()
-            for key, val in batch.items():
-                if val is not None:
-                    seq_dim = 1 if key != 'attention_mask' else 2
-                    val = val.view(
-                        *val.shape[0:seq_dim],
-                        2 * cp_size,
-                        val.shape[seq_dim] // (2 * cp_size),
-                        *val.shape[(seq_dim + 1) :],
-                    )
-                    index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True).cuda(
-                        non_blocking=True
-                    )
-                    val = val.index_select(seq_dim, index)
-                    val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
-                    batch[key] = val
-
+            # check if the batch is not in THD format
+            if 'cu_seqlens' not in batch:
+                for key, val in batch.items():
+                    if val is not None:
+                        seq_dim = 1 if key != 'attention_mask' else 2
+                        val = val.view(
+                            *val.shape[0:seq_dim],
+                            2 * cp_size,
+                            val.shape[seq_dim] // (2 * cp_size),
+                            *val.shape[(seq_dim + 1) :],
+                        )
+                        index = torch.tensor(
+                            [cp_rank, (2 * cp_size - cp_rank - 1)], device="cpu", pin_memory=True
+                        ).cuda(non_blocking=True)
+                        val = val.index_select(seq_dim, index)
+                        val = val.view(*val.shape[0:seq_dim], -1, *val.shape[(seq_dim + 2) :])
+                        batch[key] = val
         batch['num_valid_tokens_in_ub'] = num_valid_tokens_in_ub
 
         return batch
@@ -1252,6 +1254,26 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
                     )
                     raise e
 
+                # get packed sequences for this context parallel rank
+                cp_size = parallel_state.get_context_parallel_world_size()
+                if cp_size > 1:
+                    cp_rank = parallel_state.get_context_parallel_rank()
+                    for key in required_keys:
+                        val = batch[key]
+                        if key != "cu_seqlens":
+                            seq_dim = 1 if key != 'attention_mask' else 2
+                            index = tex.thd_get_partitioned_indices(
+                                cu_seqlens, val.size(seq_dim), cp_size, cp_rank
+                            )
+                            val = val.index_select(seq_dim, index)
+                            batch[key] = val
+                    cu_seqlens = cu_seqlens // cp_size
+
+                forward_args = {
+                    'input_ids': batch['tokens'],
+                    'position_ids': batch['position_ids'],
+                    'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'],
+                    'labels': batch['labels'] if 'labels' in batch else None,
+                }
                 forward_args['packed_seq_params'] = PackedSeqParams(
                     cu_seqlens_q=cu_seqlens,
                     cu_seqlens_kv=cu_seqlens,
diff --git a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
index b3251e75c84e..5fdbaed508f2 100644
--- a/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
+++ b/scripts/nlp_language_modeling/prepare_packed_ft_dataset.py
@@ -83,12 +83,20 @@ def tokenize_dataset(cfg: 'DictConfig'):
     # using the same template as SFT/PEFT script. This may be overkill but guarantees the preprocess settings
     # are identical to normal SFT training
     data_cfg = cfg.model.data.train_ds
+    pad_seq_length_to_mult = 16
+    cp_size = cfg.model.context_parallel_size
+
+    # if context parallel is used, each individual data length in one packed dataset sample
+    # needs to be a multiple of (cp_size * 2): https://github.com/NVIDIA/TransformerEngine/pull/641
+    if cp_size > 1:
+        pad_seq_length_to_mult = max(pad_seq_length_to_mult, cp_size * 2)
+
     dataset = GPTSFTDataset(
         file_path=data_cfg.file_names[0],
         tokenizer=get_nmt_tokenizer(library="sentencepiece", tokenizer_model=cfg.tokenizer_path),
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
-        pad_seq_length_to_mult=16,  # adds padding in collate_fn so this value is irrelevant here
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get('add_bos', False),
         add_eos=data_cfg.get('add_eos', True),
         add_sep=data_cfg.get('add_sep', False),
@@ -109,8 +117,29 @@ def tokenize_dataset(cfg: 'DictConfig'):
         special_tokens=data_cfg.get('chat_prompt_tokens', None),
         is_test=True,
     )
-
-    return np.array([dataset[i] for i in range(len(dataset))])
+    max_seq_length = dataset.max_seq_length
+    pad_id = dataset.tokenizer.eos_id
+    pad_seq_length_to_mult = dataset.pad_seq_length_to_mult
+    dataset = np.array([dataset[i] for i in range(len(dataset))])
+    if cp_size > 1:
+
+        def pre_pad_dataset(data, max_length, pad_id):
+            '''
+            pad each individual data point to the length of max_length
+            '''
+            for key, val in data.items():
+                if key in {'input_ids', 'context_ids'}:
+                    # because input_ids is truncated by 1 for labels in the collate_fn of GPTSFTPackedDataset
+                    # in gpt_sft_dataset.py, we add 1 extra padding here
+                    val = val + [pad_id] * (max_length - len(val) + 1)
+                    data[key] = val
+            return
+
+        ceil_to_nearest = lambda n, m: (n + m - 1) // m * m
+        for data in dataset:
+            max_length = min(max_seq_length, ceil_to_nearest(len(data['input_ids']), pad_seq_length_to_mult))
+            pre_pad_dataset(data, max_length, pad_id)
+    return dataset
 
 
 @dataclass
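For reviewers, here is a minimal, self-contained sketch of the partitioning this change relies on for packed (THD) batches. It mirrors the head/tail chunk selection `[cp_rank, 2 * cp_size - cp_rank - 1]` already used in the non-packed path, applied per packed sequence via `cu_seqlens`. `thd_partitioned_indices_sketch` is a hypothetical stand-in for illustration only, not the actual `tex.thd_get_partitioned_indices` kernel; it assumes every per-sequence length is a multiple of `2 * cp_size`, which is exactly what the pre-padding in `prepare_packed_ft_dataset.py` guarantees.

```python
import torch


def thd_partitioned_indices_sketch(cu_seqlens: torch.Tensor, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Illustrative only: token indices owned by `cp_rank` when every packed sequence
    is split into 2 * cp_size chunks and rank r keeps chunks r and (2 * cp_size - 1 - r),
    which load-balances causal attention across context-parallel ranks."""
    indices = []
    for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
        # exact split requires (end - start) % (2 * cp_size) == 0, hence the pre-padding
        chunk = (end - start) // (2 * cp_size)
        head = cp_rank
        tail = 2 * cp_size - 1 - cp_rank
        indices.append(torch.arange(start + head * chunk, start + (head + 1) * chunk))
        indices.append(torch.arange(start + tail * chunk, start + (tail + 1) * chunk))
    return torch.cat(indices)


# Two packed sequences of lengths 8 and 4, cp_size = 2:
cu_seqlens = torch.tensor([0, 8, 12])
print(thd_partitioned_indices_sketch(cu_seqlens, cp_size=2, cp_rank=0))
# tensor([ 0,  1,  6,  7,  8, 11])
```

After `index_select` with these indices each rank holds an equal share of every sequence, so dividing `cu_seqlens` by `cp_size` keeps the cumulative offsets consistent with the shortened per-rank sequences.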