Cast hparams to dict when not using omegaconf (#4770)
* init fix

* init test

* more specific dict assert

* update changelog

* Update tests/checkpointing/test_model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored Nov 20, 2020
1 parent 4803f68 commit 42e59c6
Showing 3 changed files with 59 additions and 3 deletions.
23 changes: 23 additions & 0 deletions CHANGELOG.md
@@ -60,6 +60,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).



## [unreleased.BugFix] - YYYY-MM-DD

### Added



### Changed



### Deprecated



### Removed



### Fixed

- Fixed checkpoint hparams dict casting when omegaconf is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))


## [1.0.7] - 2020-11-17

### Added
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -328,10 +328,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         if hasattr(model, '_hparams_name'):
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
         # dump arguments
-        if OMEGACONF_AVAILABLE:
+        if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-            if isinstance(model.hparams, Container):
-                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
         else:
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)
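
For readers skimming the diff, here is a minimal standalone sketch of the casting rule after this change. The hyperparameter values and variable names are illustrative, not taken from the library code, and it assumes omegaconf is installed:

    # Illustrative sketch only: mirrors the branch introduced above, outside of Lightning.
    from omegaconf import Container, OmegaConf

    for hparams in (OmegaConf.create({"lr": 0.01}), {"lr": 0.01}):
        if isinstance(hparams, Container):
            stored = hparams              # genuine OmegaConf containers are kept as-is
            stored_type = type(hparams)   # and their concrete type is recorded alongside
        else:
            stored = dict(hparams)        # anything else is cast to a plain dict
        print(type(stored).__name__)      # prints DictConfig, then dict

Before this commit, the dict(...) cast only ran when omegaconf could not be imported at all, so with omegaconf installed a model whose hparams were not an OmegaConf Container (for example, hyperparameters passed as an argparse Namespace and held as Lightning's AttributeDict) ended up in the checkpoint without being converted to a plain dict.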

34 changes: 34 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -27,6 +27,8 @@
import cloudpickle
import pytest
import torch
from omegaconf import Container, OmegaConf
from argparse import Namespace

import tests.base.develop_utils as tutils
from pytorch_lightning import Trainer, seed_everything
@@ -911,3 +913,35 @@ def training_step(self, *args):
expected = float("inf" if mode == "min" else "-inf")
assert model_checkpoint.best_model_score == expected
assert model_checkpoint.current_score == expected


@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
class TestModel(BoringModel):
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)

model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
save_top_k=1,
monitor="foo",
)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=[model_checkpoint],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
hp = {"test_hp_0": 1, "test_hp_1": 2}
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
model = TestModel(hp)
trainer.fit(model)
ckpt = trainer.checkpoint_connector.dump_checkpoint()
if hparams_type == Container:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type
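
Assuming a development checkout of the repository with pytest and omegaconf installed, the added test can be run on its own with: pytest tests/checkpointing/test_model_checkpoint.py -k test_hparams_type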
