From 42e59c6add29a5f91654a9c3a76febbe435df981 Mon Sep 17 00:00:00 2001 From: Roger Shieh Date: Fri, 20 Nov 2020 19:53:05 +0800 Subject: [PATCH] Cast hparams to dict when not using omegaconf (#4770) * init fix * init test * more specific dict assert * update changelog * Update tests/checkpointing/test_model_checkpoint.py Co-authored-by: Jirka Borovec Co-authored-by: chaton Co-authored-by: Jirka Borovec --- CHANGELOG.md | 23 +++++++++++++ .../connectors/checkpoint_connector.py | 5 ++- tests/checkpointing/test_model_checkpoint.py | 34 +++++++++++++++++++ 3 files changed, 59 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8a03085024776..1a4da285179ed 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -60,6 +60,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). +## [unreleased.BugFix] - YYYY-MM-DD + +### Added + + + +### Changed + + + +### Deprecated + + + +### Removed + + + +### Fixed + +- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770)) + + ## [1.0.7] - 2020-11-17 ### Added diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index d98a3137be34a..eff7a1c7ae9dc 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -328,10 +328,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: if hasattr(model, '_hparams_name'): checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name # dump arguments - if OMEGACONF_AVAILABLE: + if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container): checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams - if isinstance(model.hparams, Container): - checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) + checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams) else: checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams) diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py index 371e0cd6b2cd4..480acf9076739 100644 --- a/tests/checkpointing/test_model_checkpoint.py +++ b/tests/checkpointing/test_model_checkpoint.py @@ -27,6 +27,8 @@ import cloudpickle import pytest import torch +from omegaconf import Container, OmegaConf +from argparse import Namespace import tests.base.develop_utils as tutils from pytorch_lightning import Trainer, seed_everything @@ -911,3 +913,35 @@ def training_step(self, *args): expected = float("inf" if mode == "min" else "-inf") assert model_checkpoint.best_model_score == expected assert model_checkpoint.current_score == expected + + +@pytest.mark.parametrize("hparams_type", [dict, Container]) +def test_hparams_type(tmpdir, hparams_type): + class TestModel(BoringModel): + def __init__(self, hparams): + super().__init__() + self.save_hyperparameters(hparams) + + model_checkpoint = ModelCheckpoint( + dirpath=tmpdir, + save_top_k=1, + monitor="foo", + ) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + callbacks=[model_checkpoint], + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + hp = {"test_hp_0": 1, "test_hp_1": 2} + hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp) + model = TestModel(hp) + trainer.fit(model) + ckpt = trainer.checkpoint_connector.dump_checkpoint() + if hparams_type == Container: + assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type) + else: + # make sure it's not AttributeDict + assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type