diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
index 2f57fabf6f24..cade1d0c235e 100644
--- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
+++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
@@ -43,6 +43,7 @@
 
 import torch
 from apex.transformer import parallel_state
+from pytorch_lightning.core.saving import _load_state as ptl_load_state
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -224,10 +225,10 @@ def add_optimizer_state(lm_checkpoint, new_checkpoint, megatron_amp_o2=True):
 def load_model(cls, checkpoint, strict, **kwargs):
     try:
         if 'cfg' in kwargs:
-            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
+            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
         else:
-            model = cls._load_model_state(
-                checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
+            model = ptl_load_state(
+                cls, checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
             )
             # register the artifacts
             cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg
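
For context, here is a minimal sketch of the API move that motivates the change above: newer PyTorch Lightning releases expose checkpoint restoration as the module-level _load_state helper in pytorch_lightning.core.saving, whereas older releases only offered the cls._load_model_state classmethod that this patch stops calling. The shim below is not part of the patch; the name load_model_compat is invented for illustration only.

# Hedged sketch, not part of the patch: ``load_model_compat`` is a made-up
# name used only to illustrate the PyTorch Lightning API move behind this diff.
try:
    # Newer Lightning releases provide the restore helper at module level.
    from pytorch_lightning.core.saving import _load_state as ptl_load_state
except ImportError:
    ptl_load_state = None


def load_model_compat(cls, checkpoint, strict=True, **kwargs):
    """Restore ``cls`` from a checkpoint dict with whichever helper is available."""
    if ptl_load_state is not None:
        # New style: pass the target class explicitly to the module-level helper.
        return ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
    # Old style: the classmethod the original example relied on.
    return cls._load_model_state(checkpoint, strict=strict, **kwargs)

The patch itself simply switches to the new call form unconditionally, passing the class as the first argument, rather than keeping a fallback like the one sketched here.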