From 6172b02c0c295d92add0af1cc59bdf2ae62954f0 Mon Sep 17 00:00:00 2001
From: Roger Shieh
Date: Fri, 20 Nov 2020 12:53:05 +0100
Subject: [PATCH] Cast hparams to dict when not using omegaconf (#4770)

* init fix

* init test

* more specific dict assert

* update changelog

* Update tests/checkpointing/test_model_checkpoint.py

Co-authored-by: Jirka Borovec

Co-authored-by: chaton
Co-authored-by: Jirka Borovec
(cherry picked from commit 42e59c6add29a5f91654a9c3a76febbe435df981)
---
 CHANGELOG.md                                 |  4 +++
 .../connectors/checkpoint_connector.py       |  5 ++-
 tests/checkpointing/test_model_checkpoint.py | 34 +++++++++++++++++++
 3 files changed, 40 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1f07502b381b3..afc19f4645f1e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -31,6 +31,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed checkpoint `hparams` dict casting when `omegaconf` is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))
+
+
 
 
 ## [1.0.7] - 2020-11-17
diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
index 3b44ce96c02ad..e8c8bbc254e56 100644
--- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py
+++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -326,10 +326,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             if hasattr(model, '_hparams_name'):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # dump arguments
-            if OMEGACONF_AVAILABLE:
+            if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-                if isinstance(model.hparams, Container):
-                    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
             else:
                 checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
 
diff --git a/tests/checkpointing/test_model_checkpoint.py b/tests/checkpointing/test_model_checkpoint.py
index 55d1fdace673a..2c8b7d10d8cfd 100644
--- a/tests/checkpointing/test_model_checkpoint.py
+++ b/tests/checkpointing/test_model_checkpoint.py
@@ -16,6 +16,7 @@
 import pickle
 import platform
 import re
+from argparse import Namespace
 from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
@@ -24,6 +25,7 @@
 import pytest
 import torch
 import yaml
+from omegaconf import Container, OmegaConf
 
 import pytorch_lightning as pl
 import tests.base.develop_utils as tutils
@@ -812,3 +814,35 @@ def test_configure_model_checkpoint(tmpdir):
 
     with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
         Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)
+
+
+@pytest.mark.parametrize("hparams_type", [dict, Container])
+def test_hparams_type(tmpdir, hparams_type):
+    class TestModel(BoringModel):
+        def __init__(self, hparams):
+            super().__init__()
+            self.save_hyperparameters(hparams)
+
+    model_checkpoint = ModelCheckpoint(
+        dirpath=tmpdir,
+        save_top_k=1,
+        monitor="foo",
+    )
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        fast_dev_run=True,
+        callbacks=[model_checkpoint],
+        logger=False,
+        weights_summary=None,
+        progress_bar_refresh_rate=0,
+    )
+    hp = {"test_hp_0": 1, "test_hp_1": 2}
+    hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
+    model = TestModel(hp)
+    trainer.fit(model)
+    ckpt = trainer.checkpoint_connector.dump_checkpoint()
+    if hparams_type == Container:
+        assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
+    else:
+        # make sure it's not AttributeDict
+        assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
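
Note for readers skimming the patch: the following is a minimal, self-contained sketch of the behaviour it introduces, not part of the patch itself. The constant names echo `LightningModule.CHECKPOINT_HYPER_PARAMS_KEY` / `CHECKPOINT_HYPER_PARAMS_TYPE`, but their string values and the helper function `dump_hparams` are illustrative stand-ins rather than the library's actual checkpoint connector:

# Illustrative sketch only (not part of the patch above). It restates the
# hparams-casting rule applied by the patched dump_checkpoint as a standalone
# function: keep an OmegaConf Container as-is and record its type, otherwise
# cast the hparams mapping to a plain dict.

try:
    from omegaconf import Container, OmegaConf
    OMEGACONF_AVAILABLE = True
except ImportError:
    OMEGACONF_AVAILABLE = False

# Assumed placeholder key names, standing in for the LightningModule constants.
CHECKPOINT_HYPER_PARAMS_KEY = "hyper_parameters"
CHECKPOINT_HYPER_PARAMS_TYPE = "hparams_type"


def dump_hparams(hparams) -> dict:
    """Mirror the patched branch: OmegaConf containers are stored unchanged
    together with their type; any other mapping-like hparams object is cast
    to a plain dict."""
    checkpoint = {}
    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
        checkpoint[CHECKPOINT_HYPER_PARAMS_KEY] = hparams
        checkpoint[CHECKPOINT_HYPER_PARAMS_TYPE] = type(hparams)
    else:
        checkpoint[CHECKPOINT_HYPER_PARAMS_KEY] = dict(hparams)
    return checkpoint


if __name__ == "__main__":
    # Plain mapping: the stored hparams come back as an exact dict,
    # matching the test's "make sure it's not AttributeDict" assertion.
    plain = dump_hparams({"test_hp_0": 1, "test_hp_1": 2})
    assert type(plain[CHECKPOINT_HYPER_PARAMS_KEY]) is dict

    # OmegaConf config: stored as a Container, with its type recorded.
    if OMEGACONF_AVAILABLE:
        conf = dump_hparams(OmegaConf.create({"test_hp_0": 1, "test_hp_1": 2}))
        assert isinstance(conf[CHECKPOINT_HYPER_PARAMS_KEY], Container)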