diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
index 9dcd6d4451f4..ea2a783d991e 100644
--- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
+++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
@@ -39,6 +39,7 @@
 import torch
 from apex.transformer import parallel_state
+from pytorch_lightning.core.saving import _load_state as ptl_load_state
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -220,10 +221,10 @@ def add_optimizer_state(lm_checkpoint, new_checkpoint, megatron_amp_o2=True):
 def load_model(cls, checkpoint, strict, **kwargs):
     try:
         if 'cfg' in kwargs:
-            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
+            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
         else:
-            model = cls._load_model_state(
-                checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
+            model = ptl_load_state(
+                cls, checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
             )
         # register the artifacts
         cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg