diff --git a/src/super_gradients/training/sg_trainer/sg_trainer.py b/src/super_gradients/training/sg_trainer/sg_trainer.py
index 4fa7275a47..4aa00c4288 100755
--- a/src/super_gradients/training/sg_trainer/sg_trainer.py
+++ b/src/super_gradients/training/sg_trainer/sg_trainer.py
@@ -76,6 +76,7 @@
     read_ckpt_state_dict,
     load_checkpoint_to_model,
     load_pretrained_weights,
+    get_scheduler_state,
 )
 from super_gradients.training.datasets.datasets_utils import DatasetStatisticsTensorboardLogger
 from super_gradients.training.utils.callbacks import (
@@ -616,7 +617,7 @@ def _save_checkpoint(
             state["processing_params"] = processing_params

         if self._torch_lr_scheduler is not None:
-            state["torch_scheduler_state_dict"] = self._torch_lr_scheduler.state_dict()
+            state["torch_scheduler_state_dict"] = get_scheduler_state(self._torch_lr_scheduler)

         # SAVES CURRENT MODEL AS ckpt_latest
         self.sg_logger.add_checkpoint(tag="ckpt_latest.pth", state_dict=state, global_step=epoch)
diff --git a/src/super_gradients/training/utils/checkpoint_utils.py b/src/super_gradients/training/utils/checkpoint_utils.py
index 304ee27216..819a9ac29a 100644
--- a/src/super_gradients/training/utils/checkpoint_utils.py
+++ b/src/super_gradients/training/utils/checkpoint_utils.py
@@ -1,11 +1,12 @@
 import collections
 import os
 import tempfile
-from typing import Union, Mapping
+from typing import Union, Mapping, Dict

 import pkg_resources
 import torch
 from torch import nn, Tensor
+from torch.optim.lr_scheduler import CyclicLR

 from super_gradients.common.abstractions.abstract_logger import get_logger
 from super_gradients.common.data_interface.adnn_model_repository_data_interface import ADNNModelRepositoryDataInterfaces
@@ -1597,3 +1598,18 @@ def load_pretrained_weights_local(model: torch.nn.Module, architecture: str, pre

     pretrained_state_dict = torch.load(pretrained_weights, map_location=map_location)
     _load_weights(architecture, model, pretrained_state_dict)
+
+
+def get_scheduler_state(scheduler) -> Dict:
+    """
+    Wrapper for getting a torch lr scheduler state dict, resolving some issues with CyclicLR
+    (see https://github.com/pytorch/pytorch/pull/91400)
+    :param scheduler: torch.optim.lr_scheduler._LRScheduler, the scheduler
+    :return: the scheduler's state_dict
+    """
+    from super_gradients.training.utils import torch_version_is_greater_or_equal
+
+    state = scheduler.state_dict()
+    if isinstance(scheduler, CyclicLR) and not torch_version_is_greater_or_equal(2, 0):
+        del state["_scale_fn_ref"]
+    return state
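
For context, a minimal standalone sketch of the workaround follows. It is not part of the diff: `_torch_version_lt_2` is an illustrative stand-in for super_gradients' `torch_version_is_greater_or_equal` helper, and the model, optimizer, and checkpoint filename are arbitrary. The point it demonstrates is why the new helper strips `_scale_fn_ref` from a CyclicLR state dict before the checkpoint is serialized on torch < 2.0, where that entry can hold a non-picklable reference to the scale function.

import torch
from torch import nn
from torch.optim import SGD
from torch.optim.lr_scheduler import CyclicLR


def _torch_version_lt_2() -> bool:
    # Stand-in for super_gradients' torch_version_is_greater_or_equal(2, 0);
    # a crude major-version parse is enough for this sketch.
    return int(torch.__version__.split(".")[0]) < 2


def get_scheduler_state(scheduler) -> dict:
    # Mirrors the helper added in this diff: on torch < 2.0, CyclicLR's state
    # dict can contain a non-picklable reference to its scale function, so it
    # is dropped before the checkpoint is saved.
    state = scheduler.state_dict()
    if isinstance(scheduler, CyclicLR) and _torch_version_lt_2():
        state.pop("_scale_fn_ref", None)
    return state


model = nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.01)
scheduler = CyclicLR(optimizer, base_lr=1e-4, max_lr=1e-2)

# Saving the cleaned state avoids the pickling error on torch < 2.0;
# the scheduler can later be restored with scheduler.load_state_dict(...).
torch.save({"torch_scheduler_state_dict": get_scheduler_state(scheduler)}, "ckpt_latest.pth")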