Skip to content

Commit

Permalink
Prune deprecated hparams setter (#6207)
Browse files Browse the repository at this point in the history
  • Loading branch information
carmocca authored Feb 27, 2021
1 parent 40d5a9d commit 111d9c7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 101 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))


- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))


### Fixed

- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
Expand Down
34 changes: 0 additions & 34 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import copy
import inspect
import os
import re
import tempfile
import uuid
from abc import ABC
Expand Down Expand Up @@ -1806,39 +1805,6 @@ def hparams_initial(self) -> AttributeDict:
# prevent any change
return copy.deepcopy(self._hparams_initial)

@hparams.setter
def hparams(self, hp: Union[dict, Namespace, Any]):
# TODO: remove this method in v1.3.0.
rank_zero_warn(
"The setter for self.hparams in LightningModule is deprecated since v1.1.0 and will be"
" removed in v1.3.0. Replace the assignment `self.hparams = hparams` with "
" `self.save_hyperparameters()`.", DeprecationWarning
)
hparams_assignment_name = self.__get_hparams_assignment_variable()
self._hparams_name = hparams_assignment_name
self._set_hparams(hp)
# this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = copy.deepcopy(self._hparams)

def __get_hparams_assignment_variable(self):
"""
looks at the code of the class to figure out what the user named self.hparams
this only happens when the user explicitly sets self.hparams
"""
try:
class_code = inspect.getsource(self.__class__)
lines = class_code.split("\n")
for line in lines:
line = re.sub(r"\s+", "", line, flags=re.UNICODE)
if ".hparams=" in line:
return line.split("=")[1]
# todo: specify the possible exception
except Exception:
return "hparams"

return None

@property
def model_size(self) -> float:
# todo: think about better way without need to dump model to drive
Expand Down
30 changes: 0 additions & 30 deletions tests/deprecated_api/test_remove_1-3.py

This file was deleted.

38 changes: 5 additions & 33 deletions tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ def __init__(self, hparams):
self.save_hyperparameters(hparams)


class AssignHparamsModel(BoringModel):
""" Tests that a model can take an object with explicit setter """

def __init__(self, hparams):
super().__init__()
self.hparams = hparams


def decorate(func):

@functools.wraps(func)
Expand All @@ -68,16 +60,6 @@ def __init__(self, hparams, *my_args, **my_kwargs):
self.save_hyperparameters(hparams)


class AssignHparamsDecoratedModel(BoringModel):
""" Tests that a model can take an object with explicit setter"""

@decorate
@decorate
def __init__(self, hparams, *my_args, **my_kwargs):
super().__init__()
self.hparams = hparams


# -------------------------
# STANDARD TESTS
# -------------------------
Expand Down Expand Up @@ -114,7 +96,7 @@ def _run_standard_hparams_test(tmpdir, model, cls, try_overwrite=False):


@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_namespace_hparams(tmpdir, cls):
# init model
Expand All @@ -125,7 +107,7 @@ def test_namespace_hparams(tmpdir, cls):


@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_dict_hparams(tmpdir, cls):
# init model
Expand All @@ -136,7 +118,7 @@ def test_dict_hparams(tmpdir, cls):


@pytest.mark.parametrize(
"cls", [SaveHparamsModel, AssignHparamsModel, SaveHparamsDecoratedModel, AssignHparamsDecoratedModel]
"cls", [SaveHparamsModel, SaveHparamsDecoratedModel]
)
def test_omega_conf_hparams(tmpdir, cls):
# init model
Expand Down Expand Up @@ -580,8 +562,7 @@ class SuperClassPositionalArgs(BoringModel):

def __init__(self, hparams):
super().__init__()
self._hparams = None # pretend BoringModel did not call self.save_hyperparameters()
self.hparams = hparams
self._hparams = hparams # pretend BoringModel did not call self.save_hyperparameters()


class SubClassVarArgs(SuperClassPositionalArgs):
Expand Down Expand Up @@ -617,8 +598,6 @@ def test_init_arg_with_runtime_change(tmpdir, cls):
assert model.hparams.running_arg == 123
model.hparams.running_arg = -1
assert model.hparams.running_arg == -1
model.hparams = Namespace(abc=42)
assert model.hparams.abc == 42

trainer = Trainer(
default_root_dir=tmpdir,
Expand Down Expand Up @@ -664,18 +643,11 @@ class TestHydraModel(BoringModel):

def __init__(self, args_0, args_1, args_2, kwarg_1=None):
self.save_hyperparameters()
self.test_hparams()
config_file = f"{tmpdir}/hparams.yaml"
save_hparams_to_yaml(config_file, self.hparams)
self.hparams = load_hparams_from_yaml(config_file)
self.test_hparams()
super().__init__()

def test_hparams(self):
assert self.hparams.args_0.log == "Something"
assert self.hparams.args_1['cfg'].log == "Something"
assert self.hparams.args_2[0].log == "Something"
assert self.hparams.kwarg_1['cfg'][0].log == "Something"
super().__init__()

with initialize(config_path="conf"):
args_0 = compose(config_name="config")
Expand Down
13 changes: 9 additions & 4 deletions tests/trainer/test_trainer_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning.utilities import _NATIVE_AMP_AVAILABLE, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base import EvalModelTemplate
from tests.helpers import BoringModel
from tests.helpers.datamodules import MNISTDataModule


Expand Down Expand Up @@ -282,10 +283,14 @@ def dataloader(self, *args, **kwargs):

def test_auto_scale_batch_size_duplicate_attribute_warning(tmpdir):
""" Test for a warning when model.batch_size and model.hparams.batch_size both present. """
hparams = EvalModelTemplate.get_default_hparams()
model = EvalModelTemplate(**hparams)
model.hparams = hparams
# now we have model.batch_size and model.hparams.batch_size
class TestModel(BoringModel):
def __init__(self, batch_size=1):
super().__init__()
# now we have model.batch_size and model.hparams.batch_size
self.batch_size = 1
self.save_hyperparameters()

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, max_steps=1, max_epochs=1000, auto_scale_batch_size=True)
expected_message = "Field `model.batch_size` and `model.hparams.batch_size` are mutually exclusive!"
with pytest.warns(UserWarning, match=expected_message):
Expand Down

0 comments on commit 111d9c7

Please sign in to comment.