Skip to content

Commit

Permalink
Cast hparams to dict when not using omegaconf (#4770)
Browse files Browse the repository at this point in the history
* init fix

* init test

* more specific dict assert

* update changelog

* Update tests/checkpointing/test_model_checkpoint.py

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

(cherry picked from commit 42e59c6)
  • Loading branch information
s-rog authored and Borda committed Nov 23, 2020
1 parent abff810 commit 6172b02
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed checkpoint `hparams` dict casting when `omegaconf` is available ([#4770](https://github.com/PyTorchLightning/pytorch-lightning/pull/4770))




## [1.0.7] - 2020-11-17

Expand Down
5 changes: 2 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,10 +326,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
# dump arguments
if OMEGACONF_AVAILABLE:
if OMEGACONF_AVAILABLE and isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
if isinstance(model.hparams, Container):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
else:
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = dict(model.hparams)

Expand Down
34 changes: 34 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pickle
import platform
import re
from argparse import Namespace
from pathlib import Path
from unittest import mock
from unittest.mock import Mock
Expand All @@ -24,6 +25,7 @@
import pytest
import torch
import yaml
from omegaconf import Container, OmegaConf

import pytorch_lightning as pl
import tests.base.develop_utils as tutils
Expand Down Expand Up @@ -812,3 +814,35 @@ def test_configure_model_checkpoint(tmpdir):

with pytest.raises(MisconfigurationException, match="checkpoint_callback=False but found ModelCheckpoint"):
Trainer(checkpoint_callback=False, callbacks=[callback1], **kwargs)


@pytest.mark.parametrize("hparams_type", [dict, Container])
def test_hparams_type(tmpdir, hparams_type):
class TestModel(BoringModel):
def __init__(self, hparams):
super().__init__()
self.save_hyperparameters(hparams)

model_checkpoint = ModelCheckpoint(
dirpath=tmpdir,
save_top_k=1,
monitor="foo",
)
trainer = Trainer(
default_root_dir=tmpdir,
fast_dev_run=True,
callbacks=[model_checkpoint],
logger=False,
weights_summary=None,
progress_bar_refresh_rate=0,
)
hp = {"test_hp_0": 1, "test_hp_1": 2}
hp = OmegaConf.create(hp) if hparams_type == Container else Namespace(**hp)
model = TestModel(hp)
trainer.fit(model)
ckpt = trainer.checkpoint_connector.dump_checkpoint()
if hparams_type == Container:
assert isinstance(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY], hparams_type)
else:
# make sure it's not AttributeDict
assert type(ckpt[model.CHECKPOINT_HYPER_PARAMS_KEY]) == hparams_type

0 comments on commit 6172b02

Please sign in to comment.