From 5416242cfa035d4115d84100406d594e953497bb Mon Sep 17 00:00:00 2001 From: Sandeep Subramanian Date: Tue, 11 Jul 2023 12:56:17 -0700 Subject: [PATCH] RoPE length extrapolation with interpolation (#7005) * Push changes Signed-off-by: MaximumEntropy * Fixes Signed-off-by: MaximumEntropy * add continue training script Signed-off-by: MaximumEntropy * [WIP] nonlinear interp Signed-off-by: MaximumEntropy * Fix Signed-off-by: MaximumEntropy * override encoder_seq_len Signed-off-by: MaximumEntropy * Remove nonlinear Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * sft with pi (#7006) * sft with pi Signed-off-by: Evelina * update values only if not None" Signed-off-by: Evelina --------- Signed-off-by: Evelina * Address comments Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add info Signed-off-by: MaximumEntropy * Empty Signed-off-by: MaximumEntropy --------- Signed-off-by: MaximumEntropy Signed-off-by: Evelina Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Evelina <10428420+ekmb@users.noreply.github.com> Signed-off-by: Gerald Shen --- .../conf/megatron_gpt_config.yaml | 1 + .../megatron_gpt_continue_training.py | 193 ++++++++++++++++++ .../tuning/conf/megatron_gpt_sft.yaml | 2 + .../tuning/megatron_gpt_peft_eval.py | 4 + .../tuning/megatron_gpt_sft.py | 9 + .../language_modeling/megatron/gpt_model.py | 2 + .../language_modeling/megatron_gpt_model.py | 17 +- .../modules/common/megatron/language_model.py | 7 +- .../nlp/modules/common/megatron/module.py | 4 +- .../rotary_position_embedding.py | 17 +- 10 files changed, 249 insertions(+), 7 deletions(-) create mode 100644 examples/nlp/language_modeling/megatron_gpt_continue_training.py diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml index e588e94a6720..c2b0343c2ff7 100755 --- a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml @@ -83,6 +83,7 @@ model: share_embeddings_and_output_weights: True # Share embedding and output layer weights. overlap_p2p_comm: False # Overlap p2p communication with computes. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 batch_p2p_comm: True # Batch consecutive inter-peer send/recv operations. This argument is valid only when `virtual_pipeline_model_parallel_size` is larger than 1 + seq_len_interpolation_factor: null # RoPE Interpolation factor for sequence length. This is used to build long-context models with RoPE ex: https://arxiv.org/abs/2306.15595. tokenizer: library: 'megatron' diff --git a/examples/nlp/language_modeling/megatron_gpt_continue_training.py b/examples/nlp/language_modeling/megatron_gpt_continue_training.py new file mode 100644 index 000000000000..e90198833595 --- /dev/null +++ b/examples/nlp/language_modeling/megatron_gpt_continue_training.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import tempfile + +from omegaconf.omegaconf import OmegaConf, open_dict +from pytorch_lightning import Trainer +from pytorch_lightning.plugins.environments import TorchElasticEnvironment +from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector + +from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPStrategy, + NLPSaveRestoreConnector, + PipelineMixedPrecisionPlugin, +) +from nemo.core.config import hydra_runner +from nemo.utils import AppState, logging +from nemo.utils.exp_manager import exp_manager +from nemo.utils.model_utils import inject_model_parallel_rank + + +def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): + """ + This function modifies the original gpt pre-training config (t5_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. + """ + OmegaConf.set_struct(gpt_cfg, True) + OmegaConf.resolve(cfg) + with open_dict(gpt_cfg): + gpt_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + gpt_cfg.micro_batch_size = cfg.model.micro_batch_size + gpt_cfg.global_batch_size = cfg.model.global_batch_size + gpt_cfg.sequence_parallel = cfg.model.get("sequence_parallel", False) + gpt_cfg.activations_checkpoint_granularity = cfg.model.get("activations_checkpoint_granularity", None) + gpt_cfg.activations_checkpoint_num_layers = cfg.model.get("activations_checkpoint_num_layers", None) + gpt_cfg.activations_checkpoint_method = cfg.model.get("activations_checkpoint_method", None) + gpt_cfg.data = cfg.model.data + gpt_cfg.optim = cfg.model.optim + gpt_cfg.precision = cfg.trainer.precision + gpt_cfg.restore_from_path = cfg.restore_from_path + gpt_cfg.resume_from_checkpoint = cfg.model.resume_from_checkpoint + gpt_cfg.gradient_as_bucket_view = cfg.model.gradient_as_bucket_view + gpt_cfg.encoder_seq_length = cfg.model.encoder_seq_length + gpt_cfg.max_position_embeddings = cfg.model.max_position_embeddings + gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + gpt_cfg.use_flash_attention = cfg.model.use_flash_attention + + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(gpt_cfg) + gpt_cfg.cfg = gpt_cfg + + return gpt_cfg + + +def load_from_nemo(cls, cfg, trainer, gpt_cfg, modify_confg_fn): + gpt_cfg = modify_confg_fn(gpt_cfg, cfg, add_cfg_to_tree=False) + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.restore_from_path + model = cls.restore_from( + restore_path=cfg.restore_from_path, + trainer=trainer, + override_config_path=gpt_cfg, + save_restore_connector=save_restore_connector, + ) + return model + + +def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn): + app_state = AppState() + if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank, + ) + checkpoint_path = inject_model_parallel_rank( + os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name) + ) + hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file) + gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True) + with tempfile.NamedTemporaryFile(suffix='.yaml') as f: + OmegaConf.save(config=gpt_cfg, f=f.name) + model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,) + return model + + +def validate_checkpoint_loading_args(cfg): + if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir): + raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.') + if cfg.checkpoint_name is None: + raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.') + if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file): + raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.') + + +@hydra_runner(config_path="conf", config_name="megatron_gpt_config") +def main(cfg) -> None: + logging.info("\n\n************** Experiment configuration ***********") + logging.info(f'\n{OmegaConf.to_yaml(cfg)}') + + megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False) + with_distributed_adam = cfg.model.optim.get('name', 'fused_adam') == 'distributed_fused_adam' + plugins = [] + strategy = NLPDDPStrategy( + no_ddp_communication_hook=True, + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + find_unused_parameters=False, + ) + if cfg.trainer.precision in [16, 'bf16']: + scaler = None + if cfg.trainer.precision == 16: + scaler = GradScaler( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + hysteresis=cfg.model.get('hysteresis', 2), + ) + if megatron_amp_o2 and not with_distributed_adam: + plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + + if cfg.get('cluster_type', None) == 'BCP': + plugins.append(TorchElasticEnvironment()) + + trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer) + + exp_manager(trainer, cfg.exp_manager) + + # update resume from checkpoint found by exp_manager + if cfg.model.resume_from_checkpoint is not None: + resume_from_checkpoint = cfg.model.resume_from_checkpoint + else: + resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path + logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}') + + trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint) + + if cfg.restore_from_path: + save_restore_connector = NLPSaveRestoreConnector() + if os.path.isdir(cfg.restore_from_path): + save_restore_connector.model_extracted_dir = cfg.restore_from_path + gpt_cfg = MegatronGPTModel.restore_from( + restore_path=cfg.restore_from_path, + trainer=trainer, + return_config=True, + save_restore_connector=save_restore_connector, + ) + model = load_from_nemo(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config) + elif cfg.model.get("pretrained_checkpoint", None) is not None: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronGPTModel, cfg, trainer, gpt_cfg, modify_confg_fn=_modify_config) + else: + print(' > WARNING: No checkpoint provided. Starting from scratch.') + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams + with open_dict(cfg): + cfg.model.precision = cfg.trainer.precision + model = MegatronGPTModel(cfg.model, trainer) + trainer.fit(model) + + +if __name__ == '__main__': + main() diff --git a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml index f8a8e6b9dbc0..0e3f0d712dd6 100644 --- a/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml +++ b/examples/nlp/language_modeling/tuning/conf/megatron_gpt_sft.yaml @@ -60,6 +60,8 @@ model: activations_checkpoint_num_layers: null # not used with 'selective' answer_only_loss: False # not used right now gradient_as_bucket_view: False + seq_len_interpolation_factor: null # if not None, seq_len_interpolation_factor will match the base model's value + use_flash_attention: null # if not None, will match the base model's value hidden_dropout: 0.0 attention_dropout: 0.0 diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py index fc427a60d172..ed60328fd812 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_peft_eval.py @@ -127,6 +127,10 @@ def main(cfg) -> None: peft_model_cfg.data.test_ds = cfg.model.data.test_ds peft_model_cfg.activations_checkpoint_granularity = None peft_model_cfg.activations_checkpoint_method = None + if peft_model_cfg.get("use_flash_attention", False): + peft_model_cfg.use_flash_attention = cfg.model.use_flash_attention + if cfg.model.get("seq_len_interpolation_factor", None) is not None: + peft_model_cfg["seq_len_interpolation_factor"] = cfg.model.seq_len_interpolation_factor with open_dict(cfg): # update the config with the trained model config diff --git a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py index 0737d55cc514..eb4bd3125cd0 100644 --- a/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py +++ b/examples/nlp/language_modeling/tuning/megatron_gpt_sft.py @@ -64,6 +64,15 @@ def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False): sft_cls = MegatronGPTSFTModel gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}" + if cfg.model.get('use_flash_attention', None) is not None: + gpt_cfg.use_flash_attention = cfg.model.use_flash_attention + + if cfg.model.get('seq_len_interpolation_factor', None) is not None: + gpt_cfg.seq_len_interpolation_factor = cfg.model.seq_len_interpolation_factor + + sft_cls = MegatronGPTSFTModel + gpt_cfg.target = f"{sft_cls.__module__}.{sft_cls.__name__}" + # This is needed when modifying a hparam file directly to load `.ckpt` files. # This is not needed to modify the cfg in `.nemo` files. if add_cfg_to_tree: diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py index 8e28b6cab362..d70c3e06bf01 100755 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -166,6 +166,7 @@ def __init__( use_emha=False, ub_tp_comm_overlap=False, use_flash_attention=False, + seq_len_interpolation_factor=None, ): super(GPTModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -249,6 +250,7 @@ def __init__( use_emha=use_emha, ub_tp_comm_overlap=ub_tp_comm_overlap, use_flash_attention=use_flash_attention, + seq_len_interpolation_factor=seq_len_interpolation_factor, ) if self.share_embeddings_and_output_weights: diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 44b484b28949..55c3786a3d96 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -249,10 +249,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): if isinstance(self.model, list): converted_model = [] for module in self.model: - converted_model.append(Float16Module(module=module, precision=cfg.precision)) + converted_model.append( + Float16Module( + module=module, + precision=cfg.precision, + share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True), + ) + ) self.model = converted_model else: - self.model = Float16Module(module=self.model, precision=cfg.precision) + self.model = Float16Module( + module=self.model, + precision=cfg.precision, + share_token_embeddings=self.cfg.get('share_embeddings_and_output_weights', True), + ) if self.trainer.precision == 'bf16': self.autocast_dtype = torch.bfloat16 @@ -360,6 +370,7 @@ def model_provider_func(self, pre_process, post_process): ub_tp_comm_overlap=self.cfg.get('ub_tp_comm_overlap', False), use_flash_attention=self.cfg.get('use_flash_attention', False), megatron_legacy=self.cfg.get('megatron_legacy', False), + seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None), ) return model @@ -981,7 +992,7 @@ def build_pretraining_data_loader( data_parallel_size=parallel_state.get_data_parallel_world_size(), drop_last=drop_last, global_batch_size=self.cfg.global_batch_size, - rampup_batch_size=self.cfg.rampup_batch_size, + rampup_batch_size=self.cfg.get('rampup_batch_size', None), pad_samples_to_global_batch_size=pad_samples_to_global_batch_size, ) elif self.cfg.data.dataloader_type == 'cyclic': diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py index 683163246379..2aa2e8a3860e 100755 --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -123,6 +123,7 @@ def get_language_model( use_emha=False, ub_tp_comm_overlap=False, use_flash_attention=False, + seq_len_interpolation_factor=None, ): """Build language model and return along with the key to save.""" @@ -200,6 +201,7 @@ def get_language_model( use_emha=use_emha, ub_tp_comm_overlap=ub_tp_comm_overlap, use_flash_attention=use_flash_attention, + seq_len_interpolation_factor=seq_len_interpolation_factor, ) # key used for checkpoints. language_model_key = 'language_model' @@ -508,6 +510,7 @@ def __init__( use_emha=False, ub_tp_comm_overlap=False, use_flash_attention=False, + seq_len_interpolation_factor=None, ): super(TransformerLanguageModel, self).__init__(share_token_embeddings=share_embeddings_and_output_weights) @@ -559,7 +562,9 @@ def __init__( assert 0 < rotary_percentage <= 1 if rotary_percentage < 1: rotary_dim = int(rotary_dim * rotary_percentage) - self.rotary_pos_emb = RotaryEmbedding(rotary_dim) + self.rotary_pos_emb = RotaryEmbedding( + rotary_dim, seq_len_interpolation_factor=seq_len_interpolation_factor + ) elif position_embedding_type == 'alibi': # TODO: If this is used for encoder-decodemax_position_embeddingsr model, implement proper logic and following diff --git a/nemo/collections/nlp/modules/common/megatron/module.py b/nemo/collections/nlp/modules/common/megatron/module.py index 22a223013fd2..0c8c811c2661 100644 --- a/nemo/collections/nlp/modules/common/megatron/module.py +++ b/nemo/collections/nlp/modules/common/megatron/module.py @@ -254,12 +254,12 @@ def float_conversion(val): class Float16Module(MegatronModule): - def __init__(self, module, precision): + def __init__(self, module, precision, share_token_embeddings=True): if not HAVE_MEGATRON_CORE: raise ImportError( "Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) - super().__init__() + super().__init__(share_token_embeddings=share_token_embeddings) self.precision = precision if precision == 'bf16': diff --git a/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py b/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py index 5a8d6d7dd333..c97010ecb911 100644 --- a/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py +++ b/nemo/collections/nlp/modules/common/megatron/position_embedding/rotary_position_embedding.py @@ -21,13 +21,28 @@ class RotaryEmbedding(nn.Module): - def __init__(self, dim): + """ + Implements Rotary Position Embedding from https://arxiv.org/abs/2104.09864. + """ + + def __init__(self, dim: int, seq_len_interpolation_factor: int = None): + """ + Args: + + dim (int): rotary embedding dimension + seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated + by this factor via the trick in https://arxiv.org/abs/2306.15595. + """ super().__init__() + self.seq_len_interpolation_factor = seq_len_interpolation_factor inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer('inv_freq', inv_freq) def forward(self, max_seq_len, offset=0): seq = torch.arange(max_seq_len, device=self.inv_freq.device) + offset + if self.seq_len_interpolation_factor is not None: + seq = seq.type_as(self.inv_freq) + seq *= 1 / self.seq_len_interpolation_factor freqs = einsum('i , j -> i j', seq.type_as(self.inv_freq), self.inv_freq) # first part even vector components, second part odd vector components, # 2 * dim in dimension size