From 4d5b1e35680abde24abdb47fb131cb98a77b58c8 Mon Sep 17 00:00:00 2001 From: Sandeep Subramanian Date: Tue, 15 Nov 2022 16:17:20 -0800 Subject: [PATCH] Support for finetuning and finetuning inference with .ckpt files & batch size refactoring (#5339) * Initial refactor Signed-off-by: MaximumEntropy * Resolve config before passing to load_from_checkpoint Signed-off-by: MaximumEntropy * Fixes for model parallel and nemo restore Signed-off-by: MaximumEntropy * Fixes for eval Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Revert config changes Signed-off-by: MaximumEntropy * Refactor Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix typo Signed-off-by: MaximumEntropy * Remove comments Signed-off-by: MaximumEntropy * Minor Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix validation reconfiguration Signed-off-by: MaximumEntropy * Remove old comment Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes for test_ds Signed-off-by: MaximumEntropy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: MaximumEntropy Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../conf/megatron_t0_config.yaml | 8 +- .../megatron_t5_config_finetune_eval.yaml | 6 +- ...megatron_t5_config_finetune_glue_eval.yaml | 9 +- ...megatron_t5_config_finetune_glue_mnli.yaml | 6 +- ...megatron_t5_config_finetune_glue_xnli.yaml | 8 +- .../conf/megatron_t5_finetune.yaml | 6 +- .../megatron_t5_seq2seq_eval.py | 119 ++++++------ .../megatron_t5_seq2seq_finetune.py | 172 +++++++++++++----- .../megatron_finetune_model.py | 132 +++++--------- 9 files changed, 270 insertions(+), 196 deletions(-) diff --git a/examples/nlp/language_modeling/conf/megatron_t0_config.yaml b/examples/nlp/language_modeling/conf/megatron_t0_config.yaml index 04503cac769f..503dc17d2acc 100644 --- a/examples/nlp/language_modeling/conf/megatron_t0_config.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t0_config.yaml @@ -36,7 +36,11 @@ exp_manager: save_best_model: True model: - restore_from_path: ??? # Path to a trained T5 or LM-adapted T5 .nemo file + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 @@ -82,7 +86,7 @@ model: num_classes: null replace_bos_with_pad: ${data.train_ds.replace_bos_with_pad} add_bos_to_input: ${data.train_ds.add_bos_to_input} - add_eos_to_input: ${data.train_ds.replace_bos_with_pad} + add_eos_to_input: ${data.train_ds.add_eos_to_input} seed: 1234 optim: diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_eval.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_eval.yaml index 8be471a78dde..bc1a7420df48 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_eval.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_eval.yaml @@ -17,7 +17,11 @@ exp_manager: create_checkpoint_callback: False model: - restore_from_path: ??? # Path to a finetuned T5 .nemo file + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) megatron_amp_O2: False # Enable O2 optimization for megatron amp diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_eval.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_eval.yaml index 87ce5ac03eb5..024ad5f66ae9 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_eval.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_eval.yaml @@ -17,9 +17,16 @@ exp_manager: create_checkpoint_callback: False model: - restore_from_path: ??? # Path to a finetuned T5 .nemo file + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) megatron_amp_O2: False # Enable O2 optimization for megatron amp + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + pipeline_model_parallel_split_rank: 0 data: validation_ds: diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_mnli.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_mnli.yaml index ac68b57e0216..ff61c5fde20c 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_mnli.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_mnli.yaml @@ -37,7 +37,11 @@ exp_manager: save_best_model: True model: - restore_from_path: ??? # Path to a trained T5 .nemo file + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 diff --git a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_xnli.yaml b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_xnli.yaml index 1b08bc37246e..486a6da14135 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_xnli.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_config_finetune_glue_xnli.yaml @@ -37,9 +37,13 @@ exp_manager: save_best_model: True model: - restore_from_path: ??? + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. tensor_model_parallel_size: 1 - pipeline_model_parallel_size: 2 + pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 1 gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) resume_from_checkpoint: null diff --git a/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml b/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml index 9a5cf15cfe74..8c383aad9c78 100644 --- a/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml +++ b/examples/nlp/language_modeling/conf/megatron_t5_finetune.yaml @@ -36,7 +36,11 @@ exp_manager: save_best_model: True model: - restore_from_path: ??? # Path to a trained T5 .nemo file + restore_from_path: null # Path to a trained T5 .nemo file + pretrained_checkpoint: + checkpoint_dir: null # Path to a folder that contains a .ckpt file + checkpoint_name: null # Name of the .ckpt file within the checkpoint_dir. + hparams_file: null # Path to a .yaml file that contains the hyperparameters of the checkpoint. tensor_model_parallel_size: 1 pipeline_model_parallel_size: 1 pipeline_model_parallel_split_rank: 0 diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py index 25fd84d800d4..e78d34adee65 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_eval.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from megatron_t5_seq2seq_finetune import load_from_checkpoint_dir, load_from_nemo, validate_checkpoint_loading_args from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer from pytorch_lightning.callbacks.timer import Timer @@ -21,17 +22,51 @@ from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model -from nemo.collections.nlp.parts.nlp_overrides import ( - GradScaler, - MegatronHalfPrecisionPlugin, - NLPDDPStrategy, - NLPSaveRestoreConnector, -) +from nemo.collections.nlp.parts.nlp_overrides import GradScaler, MegatronHalfPrecisionPlugin, NLPDDPStrategy from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import StatelessTimer, exp_manager +def _modify_config(t5_cfg, cfg, add_cfg_to_tree=False): + """ + This function modifies the original t5 pre-training config (t5_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. + """ + OmegaConf.set_struct(t5_cfg, True) + with open_dict(t5_cfg): + t5_cfg.precision = cfg.trainer.precision + # Overwrite data configs + if cfg.model.data.validation_ds.get('src_file_name', None) is not None: + logging.info( + 'Found validation_ds.src_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.' + ) + t5_cfg.data.validation_ds.src_file_name = cfg.model.data.validation_ds.src_file_name + if cfg.model.data.validation_ds.get('tgt_file_name', None) is not None: + logging.info( + 'Found validation_ds.tgt_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.' + ) + t5_cfg.data.validation_ds.tgt_file_name = cfg.model.data.validation_ds.tgt_file_name + + if "write_predictions_to_file" in cfg.model.data.validation_ds: + t5_cfg.data.validation_ds.write_predictions_to_file = ( + cfg.model.data.validation_ds.write_predictions_to_file + ) + if "output_file_path_prefix" in cfg.model.data.validation_ds: + t5_cfg.data.validation_ds.output_file_path_prefix = cfg.model.data.validation_ds.output_file_path_prefix + + t5_cfg.data.validation_ds.micro_batch_size = cfg.model.data.validation_ds.micro_batch_size + t5_cfg.data.validation_ds.global_batch_size = cfg.model.data.validation_ds.global_batch_size + + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(t5_cfg) + t5_cfg.cfg = t5_cfg + + return t5_cfg + + @hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune_glue_eval") def main(cfg) -> None: logging.info("\n\n************** Experiment configuration ***********") @@ -69,59 +104,33 @@ def main(cfg) -> None: if isinstance(callback, Timer): trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) - t5_cfg = MegatronT5GLUEModel.restore_from( - restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True - ) - - # Override the T5 configuration with the one from the config file. - # NOTE: Only data can be overriden here since this the file being restored here should already correspond to a GLUE/XNLI finetuned model. - OmegaConf.set_struct(t5_cfg, True) - with open_dict(t5_cfg): - t5_cfg.precision = cfg.trainer.precision - # Overwrite data configs - if cfg.model.data.validation_ds.get('src_file_name', None) is not None: - logging.info( - 'Found validation_ds.src_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.' + if hasattr(cfg.model.data.validation_ds, 'task_name'): + if cfg.model.restore_from_path: + t5_cfg = MegatronT5GLUEModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True ) - t5_cfg.data.validation_ds.src_file_name = cfg.model.data.validation_ds.src_file_name - if cfg.model.data.validation_ds.get('tgt_file_name', None) is not None: - logging.info( - 'Found validation_ds.tgt_file_name in the config file. Overriding the finetuned model config file with the values from the new config file.' - ) - t5_cfg.data.validation_ds.tgt_file_name = cfg.model.data.validation_ds.tgt_file_name - - if "write_predictions_to_file" in cfg.model.data.validation_ds: - t5_cfg.data.validation_ds.write_predictions_to_file = ( - cfg.model.data.validation_ds.write_predictions_to_file - ) - if "output_file_path_prefix" in cfg.model.data.validation_ds: - t5_cfg.data.validation_ds.output_file_path_prefix = cfg.model.data.validation_ds.output_file_path_prefix - t5_cfg.data.validation_ds.src_file_name = cfg.model.data.validation_ds.src_file_name - - t5_cfg.data.validation_ds.micro_batch_size = cfg.model.data.validation_ds.micro_batch_size - t5_cfg.data.validation_ds.global_batch_size = cfg.model.data.validation_ds.global_batch_size - - if hasattr(cfg.model.data.validation_ds, 'task_name'): - model = MegatronT5GLUEModel.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), - ) - elif hasattr(cfg.model.data.validation_ds, 'file_names'): - model = MegatronT0Model.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), + model = load_from_nemo(MegatronT5GLUEModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronT5GLUEModel, cfg, trainer, modify_confg_fn=_modify_config) + elif hasattr(cfg.model.data.validation_ds, 'file_names'): + if cfg.model.restore_from_path: + t5_cfg = MegatronT0Model.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True ) + model = load_from_nemo(MegatronT0Model, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) else: - model = MegatronT5FinetuneModel.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, modify_confg_fn=_modify_config) + else: + if cfg.model.restore_from_path: + t5_cfg = MegatronT5FinetuneModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True ) + model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronT5FinetuneModel, cfg, trainer, modify_confg_fn=_modify_config) model.freeze() trainer.validate(model) diff --git a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py index 22883657736f..84b78739f673 100644 --- a/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py +++ b/examples/nlp/language_modeling/megatron_t5_seq2seq_finetune.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import tempfile + from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer from pytorch_lightning.callbacks.timer import Timer @@ -21,6 +24,7 @@ from nemo.collections.nlp.models.language_modeling.megatron_finetune_model import MegatronT5FinetuneModel from nemo.collections.nlp.models.language_modeling.megatron_glue_model import MegatronT5GLUEModel from nemo.collections.nlp.models.language_modeling.megatron_t0_model import MegatronT0Model +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel from nemo.collections.nlp.parts.nlp_overrides import ( GradScaler, MegatronHalfPrecisionPlugin, @@ -29,8 +33,99 @@ PipelineMixedPrecisionPlugin, ) from nemo.core.config import hydra_runner -from nemo.utils import logging +from nemo.utils import AppState, logging from nemo.utils.exp_manager import StatelessTimer, exp_manager +from nemo.utils.model_utils import inject_model_parallel_rank + + +def _modify_config(t5_cfg, cfg, add_cfg_to_tree=False): + """ + This function modifies the original t5 pre-training config (t5_cfg) with attributes from the finetuning config (cfg). + The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`. + """ + OmegaConf.set_struct(t5_cfg, True) + with open_dict(t5_cfg): + t5_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) + if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'): + t5_cfg.encoder.masked_softmax_fusion = False + t5_cfg.decoder.masked_softmax_fusion = False + t5_cfg.encoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) + t5_cfg.decoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) + if hasattr(t5_cfg.encoder, 'ffn_dropout'): + t5_cfg.encoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1) + if hasattr(t5_cfg.decoder, 'ffn_dropout'): + t5_cfg.decoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1) + else: + t5_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) + t5_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.1) + t5_cfg.masked_softmax_fusion = False + t5_cfg.data = cfg.model.data + t5_cfg.precision = cfg.trainer.precision + t5_cfg.optim = cfg.model.optim + t5_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size + t5_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size + # XNLI has eval languages in the yaml config. + if hasattr(cfg.model, 'eval_languages'): + t5_cfg.eval_languages = cfg.model.eval_languages + + # This is needed when modifying a hparam file directly to load `.ckpt` files. + # This is not needed to modify the cfg in `.nemo` files. + if add_cfg_to_tree: + OmegaConf.resolve(t5_cfg) + t5_cfg.cfg = t5_cfg + + return t5_cfg + + +def load_from_nemo(cls, cfg, trainer, t5_cfg, modify_confg_fn): + t5_cfg = modify_confg_fn(t5_cfg, cfg, add_cfg_to_tree=False) + model = cls.restore_from( + restore_path=cfg.model.restore_from_path, + trainer=trainer, + override_config_path=t5_cfg, + save_restore_connector=NLPSaveRestoreConnector(), + ) + return model + + +def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn): + app_state = AppState() + if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size + app_state.tensor_model_parallel_size = cfg.model.tensor_model_parallel_size + app_state.pipeline_model_parallel_size = cfg.model.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + app_state.data_parallel_size, + app_state.pipeline_model_parallel_split_rank, + app_state.virtual_pipeline_model_parallel_rank, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size, + pipeline_model_parallel_split_rank_=cfg.model.pipeline_model_parallel_split_rank, + ) + checkpoint_path = inject_model_parallel_rank( + os.path.join(cfg.model.pretrained_checkpoint.checkpoint_dir, cfg.model.pretrained_checkpoint.checkpoint_name) + ) + hparams_file = OmegaConf.load(cfg.model.pretrained_checkpoint.hparams_file) + t5_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True) + with tempfile.NamedTemporaryFile(suffix='.yaml') as f: + OmegaConf.save(config=t5_cfg, f=f.name) + model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,) + return model + + +def validate_checkpoint_loading_args(cfg): + if cfg.checkpoint_dir is None or not os.path.isdir(cfg.checkpoint_dir): + raise ValueError(f'Checkpoint directory {cfg.checkpoint_dir} does not exist or is not a directory.') + if cfg.checkpoint_name is None: + raise ValueError(f'Checkpoint name {cfg.checkpoint_name} is not valid.') + if cfg.hparams_file is None or not os.path.isfile(cfg.hparams_file): + raise ValueError(f'Hparams file {cfg.hparams_file} does not exist or is not a file.') @hydra_runner(config_path="conf", config_name="megatron_t5_config_finetune_glue_mnli") @@ -78,58 +173,35 @@ def main(cfg) -> None: if isinstance(callback, Timer): trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) - # Get the T5 Base configuration. - t5_cfg = MegatronT5FinetuneModel.restore_from( - restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True - ) - - # Override the T5 configuration with the one from the config file. - OmegaConf.set_struct(t5_cfg, True) - with open_dict(t5_cfg): - t5_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False) - if hasattr(t5_cfg, 'encoder') and hasattr(t5_cfg, 'decoder'): - t5_cfg.encoder.masked_softmax_fusion = False - t5_cfg.decoder.masked_softmax_fusion = False - t5_cfg.encoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) - t5_cfg.decoder.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) - if hasattr(t5_cfg.encoder, 'ffn_dropout'): - t5_cfg.encoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1) - if hasattr(t5_cfg.decoder, 'ffn_dropout'): - t5_cfg.decoder.ffn_dropout = cfg.model.get('ffn_dropout', 0.1) - else: - t5_cfg.hidden_dropout = cfg.model.get('hidden_dropout', 0.1) - t5_cfg.attention_dropout = cfg.model.get('attention_dropout', 0.1) - t5_cfg.masked_softmax_fusion = False - t5_cfg.data = cfg.model.data - t5_cfg.precision = cfg.trainer.precision - t5_cfg.optim = cfg.model.optim - t5_cfg.micro_batch_size = cfg.model.data.train_ds.micro_batch_size - t5_cfg.global_batch_size = cfg.model.data.train_ds.global_batch_size - # XNLI has eval languages in the yaml config. - if hasattr(cfg.model, 'eval_languages'): - t5_cfg.eval_languages = cfg.model.eval_languages - if hasattr(cfg.model.data.train_ds, 'task_name'): - model = MegatronT5GLUEModel.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), - ) + if cfg.model.restore_from_path: + t5_cfg = MegatronT5GLUEModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True + ) + model = load_from_nemo(MegatronT5GLUEModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronT5GLUEModel, cfg, trainer, modify_confg_fn=_modify_config) elif hasattr(cfg.model.data.train_ds, 'file_names'): - model = MegatronT0Model.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), - ) + if cfg.model.restore_from_path: + t5_cfg = MegatronT0Model.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True + ) + model = load_from_nemo(MegatronT0Model, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir(MegatronT0Model, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) else: - model = MegatronT5FinetuneModel.restore_from( - restore_path=cfg.model.restore_from_path, - trainer=trainer, - override_config_path=t5_cfg, - save_restore_connector=NLPSaveRestoreConnector(), - ) + if cfg.model.restore_from_path: + t5_cfg = MegatronT5FinetuneModel.restore_from( + restore_path=cfg.model.restore_from_path, trainer=trainer, return_config=True + ) + model = load_from_nemo(MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config) + else: + validate_checkpoint_loading_args(cfg.model.pretrained_checkpoint) + model = load_from_checkpoint_dir( + MegatronT5FinetuneModel, cfg, trainer, t5_cfg, modify_confg_fn=_modify_config + ) trainer.fit(model) trainer.validate(model) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py index 4da8ab57f367..c49d7b50580a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_finetune_model.py @@ -132,22 +132,9 @@ def setup(self, stage=None): self.setup_training_data() def _process_global_batch(self, global_batch): - """Process a list of microbatches into a global batch.""" - # If there is no language information in the global batch (ex: English MNLI), we can use the parent global batch processor as is. - if 'lang' not in global_batch[0]: - return self._process_global_batch_without_megatron_batch_sampler(global_batch) - - # For validation data (XNLI), we need to process the global batch and and then deal with language info separately. - else: - assert all(['lang' in micro_batch for micro_batch in global_batch]) - langs_list = [] - processed_global_batch = self._process_global_batch_without_megatron_batch_sampler( - [{k: v for k, v in micro_batch.items() if k != 'lang'} for micro_batch in global_batch] - ) - for micro_batch in global_batch: - langs_list.extend(micro_batch['lang']) - processed_global_batch['lang'] = langs_list - return processed_global_batch + """Optionally processes a global batch.""" + # TODO: maybe remove this now that we've refactored data batch sizes. + return global_batch def on_validation_epoch_start(self): app_state = AppState() @@ -160,7 +147,26 @@ def on_validation_epoch_start(self): ) return super().on_validation_epoch_start() + def on_test_epoch_start(self): + app_state = AppState() + _reconfigure_microbatch_calculator( + rank=app_state.global_rank, + rampup_batch_size=None, + global_batch_size=self.cfg.data.test_ds.global_batch_size, + micro_batch_size=self.cfg.data.test_ds.micro_batch_size, + data_parallel_size=parallel_state.get_data_parallel_world_size(), + ) + return super().on_test_epoch_start() + + def on_test_epoch_end(self): + self.on_inference_epoch_end(self.cfg.data.test_ds) + return super().on_test_epoch_end() + def on_validation_epoch_end(self): + self.on_inference_epoch_end(self.cfg.data.validation_ds) + return super().on_validation_epoch_end() + + def on_inference_epoch_end(self, ds): app_state = AppState() if hasattr(self, "_train_ds"): _reconfigure_microbatch_calculator( @@ -176,35 +182,32 @@ def on_validation_epoch_end(self): _reconfigure_microbatch_calculator( rank=app_state.global_rank, rampup_batch_size=None, - global_batch_size=self.cfg.data.validation_ds.global_batch_size, - micro_batch_size=self.cfg.data.validation_ds.micro_batch_size, + global_batch_size=ds.global_batch_size, + micro_batch_size=ds.micro_batch_size, data_parallel_size=parallel_state.get_data_parallel_world_size(), ) - return super().on_validation_epoch_end() - def on_train_epoch_start(self) -> None: # Same logic as validation epoch end, but this may be need if there is no validation sanity check to trigger validation_epoch_end() self.on_validation_epoch_end() return super().on_train_epoch_start() def training_step(self, batch, batch_idx): - micro_batch_size = batch[0]['text_enc'].size(0) + global_batch_size_per_gpu = batch['text_enc'].size(0) # This should happen only on the last batch of the dataset. - if micro_batch_size != self.cfg.data.train_ds.micro_batch_size: + if ( + global_batch_size_per_gpu + != self.cfg.data.train_ds.global_batch_size // parallel_state.get_data_parallel_world_size() + ): + # NOTE: This should never really be called since `drop_last=True` is required for training datasets. app_state = AppState() _reconfigure_microbatch_calculator( rank=app_state.global_rank, rampup_batch_size=None, - global_batch_size=micro_batch_size - * parallel_state.get_data_parallel_world_size() - * get_num_microbatches(), - micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu // get_num_microbatches(), data_parallel_size=parallel_state.get_data_parallel_world_size(), ) - # At this point batch is a list of dictionaries where eatch dict is a microbatch. - # After the process_global_batch call, batch will be a single dictionary containing the global batch. - # This is required since the parent class expects a single global batch dictioanry. batch = self._process_global_batch(batch) return super().training_step(batch, batch_idx) @@ -263,33 +266,30 @@ def cast_for_metric(self, pred, label, metric_name, class_labels=None, labels_ar return pred, label - def _reconfigure_and_process_inference_batch(self, batch): - micro_batch_size = batch[0]['text_enc'].size(0) + def _reconfigure_and_process_inference_batch(self, batch, ds_config): + global_batch_size_per_gpu = batch['text_enc'].size(0) # This should happen only on the last batch of the dataset. - if micro_batch_size != self.cfg.data.validation_ds.micro_batch_size: + if global_batch_size_per_gpu != ds_config.global_batch_size // parallel_state.get_data_parallel_world_size(): + # NOTE: This is reconfiguring to make sure there is no grad-acc for validation batches. app_state = AppState() _reconfigure_microbatch_calculator( rank=app_state.global_rank, rampup_batch_size=None, - global_batch_size=micro_batch_size - * parallel_state.get_data_parallel_world_size() - * get_num_microbatches(), - micro_batch_size=micro_batch_size, + global_batch_size=global_batch_size_per_gpu * parallel_state.get_data_parallel_world_size(), + micro_batch_size=global_batch_size_per_gpu, data_parallel_size=parallel_state.get_data_parallel_world_size(), ) - # At this point processed_batch is a list of dictionaries where eatch dict is a microbatch. - # After the process_global_batch call, processed_batch will be a single dictionary containing the global batch. - # This is required since the parent class expects a single global batch dictioanry. processed_batch = self._process_global_batch(batch) - return processed_batch def inference_step(self, batch, batch_idx, mode, dataloader_idx=0): # Regular finetuning datasets will return a list of dicts for each microbatch. But T0 datasets will return a single dict for the global batch. batch_has_lang_information = isinstance(batch, list) and len(batch[0]) == 7 - processed_batch = self._reconfigure_and_process_inference_batch(batch) + processed_batch = self._reconfigure_and_process_inference_batch( + batch, self.cfg.data.validation_ds if mode == 'validation' else self.cfg.data.test_ds + ) # Call parent validation step to get the loss. # NOTE: There could be extra keys in the processed_batch dictionary such as "langs" for XNLI, this will be ignored in the parent class. @@ -317,8 +317,12 @@ def inference_step(self, batch, batch_idx, mode, dataloader_idx=0): pred=pred, label=label, metric_name=self.val_metric_name if mode == 'validation' else self.test_metric_name, - class_labels=self.cfg.data.validation_ds.metric.get('class_labels', None), - labels_are_strings=self.cfg.data.validation_ds.metric.get('labels_are_strings', False), + class_labels=self.cfg.data.validation_ds.metric.get('class_labels', None) + if mode == 'validation' + else self.cfg.data.test_ds.metric.get('class_labels', None), + labels_are_strings=self.cfg.data.validation_ds.metric.get('labels_are_strings', False) + if mode == 'validation' + else self.cfg.data.test_ds.metric.get('labels_are_strings', False), ) if batch_has_lang_information: _ = metric(pred, label, category) @@ -505,15 +509,7 @@ def test_epoch_end(self, outputs): _ = self.inference_epoch_end(outputs, 'test', self.cfg.data.test_ds) def build_data_loader( - self, - dataset, - micro_batch_size, - global_batch_size, - shuffle, - num_workers, - pin_memory, - drop_last, - check_validation_interval, + self, dataset, global_batch_size, shuffle, num_workers, pin_memory, drop_last, ): """Buld dataloader given an input dataset.""" @@ -525,20 +521,6 @@ def build_data_loader( sampler = torch.utils.data.distributed.DistributedSampler( dataset, num_replicas=world_size, rank=rank, shuffle=shuffle ) - # This check makes sure the val_check_interval is less than the number of global batches. - # Normally, PTL would do this check and properly account for gradient accumulation. - # But now, it is implicit in the apex fwd/bwd functions and so we need to check for this somewhere. - # The consequence of not doing this is that training loop will never run validation. - # NOTE: Prog bar is also broken as a result of this. - global_batch_size_per_gpu = micro_batch_size * get_num_microbatches() - if ( - self.trainer.val_check_interval > (sampler.num_samples // global_batch_size_per_gpu) - and check_validation_interval - ): - raise ValueError( - f"trainer.val_check_interval {self.trainer.val_check_interval} is > number of global batches {sampler.num_samples // global_batch_size}" - ) - if isinstance(dataset, ConcatMapDataset): collate_fn = dataset.datasets[0].collate_fn else: @@ -548,7 +530,7 @@ def build_data_loader( dataset, collate_fn=collate_fn, sampler=sampler, - batch_size=micro_batch_size, + batch_size=global_batch_size // parallel_state.get_data_parallel_world_size(), num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, @@ -557,13 +539,11 @@ def build_data_loader( def setup_training_data(self): self._train_dl = self.build_data_loader( self._train_ds, - micro_batch_size=self.cfg.data.train_ds.micro_batch_size, global_batch_size=self.cfg.data.train_ds.global_batch_size, shuffle=self.cfg.data.train_ds.shuffle, num_workers=self.cfg.data.train_ds.num_workers, pin_memory=self.cfg.data.train_ds.pin_memory, drop_last=self.cfg.data.train_ds.drop_last, - check_validation_interval=True, ) def setup_eval_data(self, datasets, data_cfg): @@ -571,13 +551,11 @@ def setup_eval_data(self, datasets, data_cfg): for dataset in datasets: eval_dl = self.build_data_loader( dataset, - micro_batch_size=data_cfg.micro_batch_size, global_batch_size=data_cfg.global_batch_size, shuffle=data_cfg.shuffle, num_workers=data_cfg.num_workers, pin_memory=data_cfg.pin_memory, drop_last=data_cfg.drop_last, - check_validation_interval=False, ) dataloaders.append(eval_dl) return dataloaders @@ -689,15 +667,3 @@ def build_train_valid_test_datasets(self, stage): return self._train_ds = self._build_train_dataset(self.cfg.data.train_ds) logging.info(f'Finished building datasets ...') - - def on_train_start(self) -> None: - """PTL hook used to override DataFetcher with GlobalBatchDataFetcher """ - self.trainer.fit_loop._data_fetcher = GlobalBatchDataFetcher() - - def on_validation_start(self) -> None: - """PTL hook used to override DataFetcher with GlobalBatchDataFetcher """ - self.trainer.fit_loop.epoch_loop.val_loop._data_fetcher = GlobalBatchDataFetcher() - self.trainer.validate_loop._data_fetcher = GlobalBatchDataFetcher() - - def on_test_start(self) -> None: - self.trainer.test_loop._data_fetcher = GlobalBatchDataFetcher()