From 5b8148d04d0b4c09c7744f44497530cb19e2bae5 Mon Sep 17 00:00:00 2001
From: David Mosallanezhad
Date: Wed, 28 Sep 2022 19:54:47 -0700
Subject: [PATCH] fixed megatron lm conversion bug (PTL related)

Signed-off-by: David Mosallanezhad
---
 examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
index 9dcd6d4451f4..ea2a783d991e 100644
--- a/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
+++ b/examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
@@ -39,6 +39,7 @@

 import torch
 from apex.transformer import parallel_state
+from pytorch_lightning.core.saving import _load_state as ptl_load_state
 from pytorch_lightning.core.saving import load_hparams_from_tags_csv, load_hparams_from_yaml
 from pytorch_lightning.trainer.trainer import Trainer
 from pytorch_lightning.utilities.cloud_io import load as pl_load
@@ -220,10 +221,10 @@ def add_optimizer_state(lm_checkpoint, new_checkpoint, megatron_amp_o2=True):
 def load_model(cls, checkpoint, strict, **kwargs):
     try:
         if 'cfg' in kwargs:
-            model = cls._load_model_state(checkpoint, strict=strict, **kwargs)
+            model = ptl_load_state(cls, checkpoint, strict=strict, **kwargs)
         else:
-            model = cls._load_model_state(
-                checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
+            model = ptl_load_state(
+                cls, checkpoint, strict=strict, cfg=checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg, **kwargs
             )
             # register the artifacts
             cfg = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].cfg