Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace Hparams by init args #1896

Merged
merged 101 commits into from
May 24, 2020
Merged

replace Hparams by init args #1896

merged 101 commits into from
May 24, 2020

Conversation

williamFalcon
Copy link
Contributor

@williamFalcon williamFalcon commented May 19, 2020

Problem

hparams was a temporary fix for not auto storing args by users. It’s something everyone hacks around, is not intuitive and makes the pl module somehow less like at pt module.

end of hparams!

This PR

This PR removes that and instead:

  • Stores all the args passed in init automatically so checkpoints can have this information.
  • doesn’t store things like losses, etc... only primitives, lists, dicts, tuples and namespace
  • auto saves this info into checkpoints
  • it DOES NOT assign properties automatically

Backward compatibility

  • this PR is still backward compatible for people who want to continue using hparams directly.

Summary

Before:

hparams = dict or Namespace

class LitModel(pl.LightningModule):
    def __init__(self, hparams, my_pretrained_nn_module):
        super().__init__()
        self.hparams = hparams
        self.l1 = nn.Linear(hparams.in_dim, hparams.out_dim)
        self.feature_extractor = my_pretrained_nn_module()

# old way had a ton of problems with this
model = LitModel.load_from_checkpoint(PATH)

New:

class LitModel(pl.LightningModule):
    def __init__(self, in_dim, out_dim, my_pretrained_nn_module):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        
        # self.in_dim, etc were auto registered to the module
        self.l1 = nn.Linear(in_dim, out_dim)
        self.feature_extractor = my_pretrained_nn_module()

# load from checkpoint still works as normal, but objects and such need to be specified
model = LitModel.load_from_checkpoint(PATH, my_pretrained_nn_module=MyModule)

# or can overwrite the old settings as well
model = LitModel.load_from_checkpoint(PATH, in_dim=some_new_dim, my_pretrained_nn_module=MyModule)

@mergify mergify bot requested a review from a team May 19, 2020 20:03
@williamFalcon
Copy link
Contributor Author

@awaelchli @justusschock

Need to figure out the loading/saving stuff and update tests

@williamFalcon williamFalcon changed the title No hparams [WIP] No hparams May 19, 2020
@pep8speaks
Copy link

pep8speaks commented May 19, 2020

Hello @williamFalcon! Thanks for updating this PR.

Line 289:52: W504 line break after binary operator

Comment last updated at 2020-05-24 17:00:58 UTC

@awaelchli
Copy link
Contributor

awaelchli commented May 19, 2020

if someone renames "self" to "this" for example, it could break, right?
need to pull out the first one in the arg list and only consider the rest

@williamFalcon
Copy link
Contributor Author

if someone renames "self" to "this" for example, it could break, right?
need to pull out the first one in the arg list and only consider the rest

is it an ordered dict? otherwise what if someone names it not this but something else

"""
# two frames back is the init of the child module
frame = inspect.currentframe()
args = frame.f_back.f_back.f_locals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mentioned this in #1735, but if my LightningModule was a subclass of another LightningModule, this wouldn't work right? We have to dynamically determine how many levels we should go since we always need to get to the leaf level.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good point. any suggestions?
I guess we could always backtrack right up to the nn.Module?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yukw777 hadn't read that carefully actually haha. Good suggestion!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(also, happy to co-author this since it's fairly involved)

Copy link
Contributor

@yukw777 yukw777 May 19, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

haha no worries. I'd be happy to help out.

I had to implement something similar at my job, and I ended up going with #1735 (comment). I tried to see if I can somehow automate the whole thing, but it was more trouble than it's worth, so I decided to keep things more "declarative". This does mean that a PL user would need to implement that abstract property, which makes LightningModule less transparent... We could make LightningModule a data class as @mateuszpieniak, but it's only available in 3.7, and it also makes LightningModule less transparent. It does seem like we do need to add something like this though, as it's impossible for PL to figure out whose __init__() args to save automatically...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about i take a stab at a v1 and ping you to finish it haha.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sounds good! looking forward to it!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re the dataclass, we could always depend on https://pypi.org/project/dataclasses/ to get dataclasses in 3.6.

I think adding a serialize_args abstract method makes sense. That's what I use at work (on a non lightning training pipeline) and it works pretty well. We have sensible serialization defaults, so it only needs to be overriden if the training module has custom, non serializable types.

# set module_arguments in child
setattr(child, 'module_arguments', module_arguments)

def _is_allowed_hparam_value(self, value):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yukw777 this should be good no? allows for basically anything except objects (but allows dicts, lists, tuples)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previously any picklable objects could be in hparams, so I think we should keep that behavior, which is actually quite useful for things like custom vocabulary dictionaries. It also makes it easy to invert dependencies to write tests.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon Other examples: this would not allow to use dataclasses and OmegaConf which are easy to use now

@Borda
Copy link
Member

Borda commented May 19, 2020

if someone renames "self" to "this" for example, it could break, right?
need to pull out the first one in the arg list and only consider the rest

is it an ordered dict? otherwise what if someone names it not this but something else

in python does not matter if you call it self or king just the first argument is treated such way and hold for the particular method, the self is just convention as well as this in Java

Comment on lines 194 to 202
replay_size,
warm_start_steps,
gamma, eps_start,
eps_end,
eps_last_frame,
sync_rate,
lr,
episode_length,
batch_size) -> None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add types...

@mergify mergify bot requested a review from a team May 19, 2020 21:55
@@ -185,7 +182,7 @@ def main(hparams):
# ------------------------
# 1 INIT LIGHTNING MODEL
# ------------------------
model = SegModel(hparams)
model = SegModel(**hparams)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we shall be careful if the hparams contains any argument (key) not listed in the init it will crash

Comment on lines 40 to 47
drop_prob=0.2,
batch_size=2,
in_features=28 * 28,
learning_rate=0.001 * 8,
optimizer_name='adam',
data_root='./datasets',
out_features=10,
hidden_dim=1000,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add types...

@mergify mergify bot requested a review from a team May 19, 2020 21:56
@Borda Borda added the feature Is an improvement or enhancement label May 19, 2020
@Borda Borda added this to the 0.8.0 milestone May 19, 2020
@Borda
Copy link
Member

Borda commented May 19, 2020

for sure with this change and previously merged matrics we have to go with v0.8 as next release :]
correct me if I am wrong but I feel this is quite major API change... @PyTorchLightning/core-contributors

@williamFalcon
Copy link
Contributor Author

williamFalcon commented May 19, 2020

@yukw777 ok... the last issue to handle is what you mentioned.
I need to log off for today. Want to take a look at it?
The issue is definitely with the subclassing. Check the test_auto_hparams test in test_trainer.py

=========================================================================================== short test summary info ===========================================================================================
FAILED tests/trainer/test_trainer.py::test_auto_hparams - TypeError: __init__() got an unexpected keyword argument 'batch_size'
FAILED tests/trainer/test_trainer.py::test_dict_namespace_param_save_load - TypeError: __init__() got an unexpected keyword argument 'drop_prob'

@justusschock fixed all the checkpoint and yaml stuff. Take a look?

Comment on lines 22 to 58
def test_auto_hparams(tmpdir):
class SubClassEvalModelTemplate(EvalModelTemplate):
def __init__(self, subclass_arg=1200):
super().__init__()

class SubSubClassEvalModelTemplate(SubClassEvalModelTemplate):
pass

classes = [SubClassEvalModelTemplate, EvalModelTemplate, SubSubClassEvalModelTemplate]

for CLASS in classes:
# test that the model automatically sets the args passed into init as attrs
model = CLASS()
assert model.batch_size == 32
model = CLASS(batch_size=179)
assert model.batch_size == 179

if isinstance(model, SubClassEvalModelTemplate):
assert model.subclass_arg == 1200

# verify that the checkpoint saved the correct values
trainer = Trainer(max_steps=20)
trainer.fit(model)
raw_checkpoint_path = os.listdir(trainer.checkpoint_callback.dirpath)
raw_checkpoint_path = [x for x in raw_checkpoint_path if '.ckpt' in x][0]
raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
raw_checkpoint = torch.load(raw_checkpoint_path)
assert 'module_arguments' in raw_checkpoint
assert raw_checkpoint['module_arguments']['batch_size'] == 179

# verify that model loads correctly
model = CLASS.load_from_checkpoint(raw_checkpoint_path)
assert model.batch_size == 179

# verify that we can overwrite whatever we want
model = CLASS.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
assert model.batch_size == 99
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@justusschock added this test. Any cases missing?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, you only pass yaml serializable stuff in there. For example there may be users who pass their loss functions that way if they experiment with then, but you can't serialize stuff like torch.nn.MSELoss with yaml

@williamFalcon
Copy link
Contributor Author

@yukw777 @tullie this is what i added. Why do we need the datamodules?

    def _auto_register_hparams(self):
        """
        Removes the need to pass in hparams. Instead, we register every argument in init
        to the module with some caveats:
        1. we don't overwrite the property if it already exists
        2. we also store a module_arguments property for model loading and saving
        """
        # two frames back is the init of the child module
        frame = inspect.currentframe()
        frame_args = frame.f_back.f_back.f_locals

        # we'll save hparams automatically (renamed to module_arguments)
        module_arguments = {}

        # pull out the child itself to make sure we have no issues
        child = frame_args['self']

        # auto set the attr which enables self.attr anywhere in the code
        for name, value in frame_args.items():

            # don't add self
            if name not in ['self']:

                # only track some things
                is_trackable = self._is_allowed_hparam_value(value)

                # don't overwrite something already set
                if not hasattr(child, name) and is_trackable:
                    setattr(child, name, value)

                if is_trackable:
                    module_arguments[name] = value

        # set module_arguments in child
        setattr(child, 'module_arguments', module_arguments)

    def _is_allowed_hparam_value(self, value):
        if isinstance(value, Namespace):
            return True
        return not hasattr(value, '__dict__')

@mergify mergify bot requested a review from a team May 20, 2020 13:47
@Borda
Copy link
Member

Borda commented May 20, 2020

have you seen this recently?
AttributeError: module 'tensorflow' has no attribute 'io'

pass


class AggSubClassEvalModel(SubClassEvalModel):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon here it is as init arg

@mergify mergify bot requested a review from a team May 24, 2020 17:51
@@ -119,6 +120,12 @@
else:
HOROVOD_AVAILABLE = True

PRIMITIVE_TYPES = (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@williamFalcon @awaelchli @yukw777 @festeh any other primitives shall be stored in checkpoint?

Copy link
Contributor

@festeh festeh May 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm using dataclasses. They are actually great in the way that they give you auto-completion (plain config classes would allow it too) and can provide some validation. But I never heard that somebody else also doing that, so I'd probably could just patch this variable in my code or invent some other hack.

In general I think we cannot handle all cases here so it would be beneficial to allow user to manually save some picklable argument, maybe via @should_pickle(argument) decorator. I'll try to design this feature.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

that sounds interesting, I would just stay for this PR to get done with the complete list, and kindly ask you to make a follow-up PR with your suggestion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, sure

@mergify mergify bot requested a review from a team May 24, 2020 18:08

def _collect_init_args(frame, path_args: list) -> list:
"""Recursive search for all children."""
if '__class__' in frame.f_locals:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this do?
Here is a an example we should try to handle or at least give a warning

from pytorch_lightning import LightningModule
class Example(LightningModule):

    def __init__(this, arg):
        super().__init__()
        this.arg = arg

    def forward(self, x):
        pass

x = Example(1)
print(x.module_arguments)
#  {'this': Example(), 'arg': 1}

module_arguments contains the object itself.
If this the PR needs to be merged asap I hope we can put it at least on the list of TODOs.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

true, it would contain the aka self but it will be filtered out as it is not a primitive...
so you recommend doing the primitive filtering already here?

Copy link
Contributor

@awaelchli awaelchli May 24, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, but we need to also inspect the constructor to get a list of accepted args and filter based on that, because of this example:

from pytorch_lightning import LightningModule

class Example(LightningModule):

    def __init__(self, arg):
        my_local_var = 2
        super().__init__()

    def forward(self, x):
        pass

x = Example(1)
print(x.module_arguments)
# {'arg': 1, 'my_local_var': 2}

This will fail when we try to restore and pass in the local var which is not an argument in the constructor.

We should probably filter based on
inspect.signature(Example.__init__).parameters
and only save these locals, not the others.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

x = Example(1)
print(x.module_arguments)
# {'arg': 1, 'my_local_var': 2}

this works as in the frame it appears with arg name always...
and as it is saved you can call x = Example(arg=1)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but there is an extra arg.
Example.load_from_checkpoint(...)
will load module argumets
module_arguments = {'arg': 1, 'my_local_var': 2}
and then call
Example(**module_arguments)
Which means
my_local_var is passed in but this name is not accepted.
Thus yielding a TypeError: __init__ got an unexpected argument "my_local_var"
I am 99% certain.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, wait if my_local_var is not among __init__ arguments, it won't be present in module_arguments neither...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@awaelchli mind adding an edge case to the test so we know the exact case and also we can truly test it... :]

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in follow up or here? because william wants to merge asap he said

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let’s add it as a follow up.

ready to merge? happy to merge now

@mergify mergify bot requested a review from a team May 24, 2020 18:43
@williamFalcon williamFalcon merged commit caa9c67 into master May 24, 2020
@williamFalcon
Copy link
Contributor Author

Great job @Borda, thanks for the feedback @yukw777 @awaelchli!

@PyTorchLightning/core-contributors play with this for a bit? also verify than hydra still works? @tullie

@yukw777
Copy link
Contributor

yukw777 commented May 24, 2020

Thank you for pushing through this @Borda ! Glad I was able to help out here.

@awaelchli
Copy link
Contributor

Really cool idea this PR! It will simplify checkpointing a lot :)

I compiled a list to keep track of unresolved issues discussed in this thread here. Feel free to add anything I have missed.
See #1937

@DKandrew
Copy link
Contributor

DKandrew commented May 28, 2020

Received a warning when running the Trainer UserWarning: Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters. Does it relate to the legacy code of hparams?

@Borda
Copy link
Member

Borda commented May 28, 2020

Received a warning when running the Trainer UserWarning: Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters. Does it relate to the legacy code of hparams?

mind shoot an issue?

@DKandrew
Copy link
Contributor

Received a warning when running the Trainer UserWarning: Did not find hyperparameters at model hparams. Saving checkpoint without hyperparameters. Does it relate to the legacy code of hparams?

mind shoot an issue?

Sure!

@drozzy
Copy link

drozzy commented Jun 1, 2020

This is awesome, when can we expect this in the release?

@drozzy
Copy link

drozzy commented Jun 1, 2020

Wait, does this solve the issue of using params in jupyter notebooks? In other words, can I omit argparse with this?

williamFalcon pushed a commit that referenced this pull request Jun 17, 2020
* Misleading exception raised during batch scaling

Use batch_size from `model.hparams.batch_size` instead of `model.batch_size`

* Improvements considering #1896

* Apply suggestions from code review

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
williamFalcon pushed a commit that referenced this pull request Jul 29, 2020
* Misleading exception raised during batch scaling

Use batch_size from `model.hparams.batch_size` instead of `model.batch_size`

* Improvements considering #1896

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
@austinmw
Copy link

austinmw commented Apr 2, 2021

Hi, what's the recommended way to use this with argparse?

model = LitModel(**vars(args))?

@awaelchli
Copy link
Contributor

@austinmw Yes

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
feature Is an improvement or enhancement help wanted Open to be worked on
Projects
None yet
Development

Successfully merging this pull request may close these issues.