diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py index 32b22df22d2c..4d6870b89e99 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py @@ -74,6 +74,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): raise ImportError( "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt." ) + assert trainer.max_steps > 0, "max_steps for SFT can't be negative as its required to build the dataset" super().__init__(cfg, trainer=trainer) self.sep_id = cfg.get('sep_id', 49704) if hasattr(self.cfg.data, "validation_ds"):