replace Hparams by init args #1896
Conversation
Need to figure out the loading/saving stuff and update tests.
Hello @williamFalcon! Thanks for updating this PR.
Comment last updated at 2020-05-24 17:00:58 UTC
if someone renames "self" to "this" for example, it could break, right?
is it an ordered dict? otherwise, what if someone names it not "this" but something else?
pytorch_lightning/core/lightning.py
Outdated
"""
# two frames back is the init of the child module
frame = inspect.currentframe()
args = frame.f_back.f_back.f_locals
I mentioned this in #1735, but if my LightningModule was a subclass of another LightningModule, this wouldn't work, right? We have to dynamically determine how many levels to go up, since we always need to get to the leaf level.
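For illustration, a minimal standalone sketch of that concern (plain Python classes, not the PR code; the inspection sits directly in the base __init__ here, so it is one f_back instead of two):

import inspect

class Base:
    def __init__(self):
        # assume the caller's frame holds the leaf __init__ locals --
        # only true when the instantiated class is a direct child of Base
        self.captured = inspect.currentframe().f_back.f_locals

class Child(Base):
    def __init__(self, child_arg=1):
        super().__init__()

class GrandChild(Child):
    def __init__(self, grand_arg=2):
        super().__init__()

print('child_arg' in Child(child_arg=5).captured)       # True
print('grand_arg' in GrandChild(grand_arg=7).captured)  # False: stopped one level too early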
good point. any suggestions?
I guess we could always backtrack right up to the nn.Module?
@yukw777 hadn't read that carefully actually haha. Good suggestion!
(also, happy to co-author this since it's fairly involved)
haha no worries. I'd be happy to help out.
I had to implement something similar at my job, and I ended up going with #1735 (comment). I tried to see if I could somehow automate the whole thing, but it was more trouble than it's worth, so I decided to keep things more "declarative". This does mean that a PL user would need to implement that abstract property, which makes LightningModule less transparent... We could make LightningModule a dataclass as @mateuszpieniak suggested, but dataclasses are only available in 3.7, and it also makes LightningModule less transparent. It does seem like we need to add something like this though, as it's impossible for PL to figure out whose __init__() args to save automatically...
how about I take a stab at a v1 and ping you to finish it haha.
sounds good! looking forward to it!
Re the dataclass, we could always depend on https://pypi.org/project/dataclasses/ to get dataclasses in 3.6.
I think adding a serialize_args abstract method makes sense. That's what I use at work (on a non-Lightning training pipeline) and it works pretty well. We have sensible serialization defaults, so it only needs to be overridden if the training module has custom, non-serializable types.
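A rough sketch of what such a declarative hook could look like (names hypothetical, only to illustrate the idea, not the actual proposal):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):
    def __init__(self, hidden_dim=128, vocab=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.vocab = vocab  # e.g. a custom vocabulary object, not YAML-friendly

    @property
    def serialize_args(self) -> dict:
        # the user declares explicitly which init args go into the checkpoint
        return {'hidden_dim': self.hidden_dim}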
pytorch_lightning/core/lightning.py
Outdated
# set module_arguments in child
setattr(child, 'module_arguments', module_arguments)

def _is_allowed_hparam_value(self, value):
@yukw777 this should be good, no? It allows basically anything except objects (but allows dicts, lists, tuples).
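Roughly the kind of check being discussed, as a sketch (not the actual implementation in this PR):

ALLOWED_TYPES = (bool, int, float, str, tuple, list, dict)  # hypothetical allow-list

def is_allowed_hparam_value(value) -> bool:
    # accept None and simple scalars/containers, reject arbitrary objects
    return value is None or isinstance(value, ALLOWED_TYPES)

print(is_allowed_hparam_value(0.02))          # True
print(is_allowed_hparam_value({'lr': 0.02}))  # True
print(is_allowed_hparam_value(object()))      # False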
Previously any picklable objects could be in hparams, so I think we should keep that behavior, which is actually quite useful for things like custom vocabulary dictionaries. It also makes it easy to invert dependencies to write tests.
@williamFalcon Other examples: this would not allow using dataclasses or OmegaConf, which are easy to use now.
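For example (hypothetical config class), a dataclass config is perfectly picklable but would be rejected by a primitive-only check:

from dataclasses import dataclass

@dataclass
class OptimConfig:
    lr: float = 0.02
    momentum: float = 0.9

config = OptimConfig()
# a primitive-only filter would drop this from module_arguments
print(isinstance(config, (bool, int, float, str, tuple, list, dict)))  # False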
in python it does not matter what you call it (self, this, or anything else)
replay_size,
warm_start_steps,
gamma, eps_start,
eps_end,
eps_last_frame,
sync_rate,
lr,
episode_length,
batch_size) -> None:
let's add types...
@@ -185,7 +182,7 @@ def main(hparams):
     # ------------------------
     # 1 INIT LIGHTNING MODEL
     # ------------------------
-    model = SegModel(hparams)
+    model = SegModel(**hparams)
we shall be careful: if the hparams contains any argument (key) not listed in the init, it will crash
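A minimal stand-in showing that failure mode (not the actual SegModel; keys are made up):

hparams = {'lr': 0.02, 'some_extra_flag': True}

class TinySegModel:
    def __init__(self, lr):
        self.lr = lr

try:
    TinySegModel(**hparams)
except TypeError as err:
    print(err)  # __init__() got an unexpected keyword argument 'some_extra_flag'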
drop_prob=0.2,
batch_size=2,
in_features=28 * 28,
learning_rate=0.001 * 8,
optimizer_name='adam',
data_root='./datasets',
out_features=10,
hidden_dim=1000,
let's add types...
for sure, with this change and the previously merged metrics we have to go with v0.8 as the next release :]
@yukw777 ok... the last issue to handle is what you mentioned.
@justusschock fixed all the checkpoint and yaml stuff. Take a look?
tests/trainer/test_trainer.py
Outdated
def test_auto_hparams(tmpdir):
    class SubClassEvalModelTemplate(EvalModelTemplate):
        def __init__(self, subclass_arg=1200):
            super().__init__()

    class SubSubClassEvalModelTemplate(SubClassEvalModelTemplate):
        pass

    classes = [SubClassEvalModelTemplate, EvalModelTemplate, SubSubClassEvalModelTemplate]

    for CLASS in classes:
        # test that the model automatically sets the args passed into init as attrs
        model = CLASS()
        assert model.batch_size == 32
        model = CLASS(batch_size=179)
        assert model.batch_size == 179

        if isinstance(model, SubClassEvalModelTemplate):
            assert model.subclass_arg == 1200

        # verify that the checkpoint saved the correct values
        trainer = Trainer(max_steps=20)
        trainer.fit(model)
        raw_checkpoint_path = os.listdir(trainer.checkpoint_callback.dirpath)
        raw_checkpoint_path = [x for x in raw_checkpoint_path if '.ckpt' in x][0]
        raw_checkpoint_path = os.path.join(trainer.checkpoint_callback.dirpath, raw_checkpoint_path)
        raw_checkpoint = torch.load(raw_checkpoint_path)
        assert 'module_arguments' in raw_checkpoint
        assert raw_checkpoint['module_arguments']['batch_size'] == 179

        # verify that model loads correctly
        model = CLASS.load_from_checkpoint(raw_checkpoint_path)
        assert model.batch_size == 179

        # verify that we can overwrite whatever we want
        model = CLASS.load_from_checkpoint(raw_checkpoint_path, batch_size=99)
        assert model.batch_size == 99
@justusschock added this test. Any cases missing?
Yes, but only if you pass yaml-serializable stuff in there. For example, there may be users who pass their loss functions that way if they experiment with them, but you can't serialize stuff like torch.nn.MSELoss with yaml.
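For instance, a quick check with PyYAML's safe dumper (assuming PyYAML and torch are installed):

import torch
import yaml
from yaml.representer import RepresenterError

try:
    yaml.safe_dump({'loss': torch.nn.MSELoss()})
except RepresenterError as err:
    print(err)  # e.g. "cannot represent an object: MSELoss()"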
@yukw777 @tullie this is what I added. Why do we need the datamodules?
have you seen this recently?
pass


class AggSubClassEvalModel(SubClassEvalModel):
@williamFalcon here it is as an init arg
@@ -119,6 +120,12 @@
else:
    HOROVOD_AVAILABLE = True


PRIMITIVE_TYPES = (
@williamFalcon @awaelchli @yukw777 @festeh are there any other primitives that shall be stored in the checkpoint?
I'm using dataclasses. They are actually great in that they give you auto-completion (plain config classes would allow it too) and can provide some validation. But I've never heard of anybody else doing that, so I could probably just patch this variable in my code or invent some other hack.
In general I think we cannot handle all cases here, so it would be beneficial to allow the user to manually save some picklable argument, maybe via a @should_pickle(argument) decorator. I'll try to design this feature.
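Purely as a thought experiment, such a decorator might look something like this (hypothetical API, not part of this PR):

def should_pickle(*arg_names):
    # mark init args to be pickled into the checkpoint even if they are not primitives
    def decorator(init_fn):
        init_fn._pickled_args = set(arg_names)
        return init_fn
    return decorator

class MyModule:
    @should_pickle('vocab')
    def __init__(self, vocab, hidden_dim=128):
        self.vocab = vocab
        self.hidden_dim = hidden_dim

print(MyModule.__init__._pickled_args)  # {'vocab'}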
that sounds interesting. I would stick with the complete list for this PR to get it done, and kindly ask you to make a follow-up PR with your suggestion?
yeah, sure
def _collect_init_args(frame, path_args: list) -> list:
    """Recursive search for all children."""
    if '__class__' in frame.f_locals:
What does this do?
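As far as I can tell (my reading of CPython behavior, not something stated in the PR): any method that uses zero-argument super() carries an implicit __class__ cell in its locals, so this check tells the recursion whether the parent frame still belongs to an __init__ in the class hierarchy. A quick check:

import inspect

class A:
    def __init__(self):
        super().__init__()  # zero-arg super() creates the implicit __class__ cell
        print('__class__' in inspect.currentframe().f_locals)  # True

def plain_function():
    print('__class__' in inspect.currentframe().f_locals)  # False

A()
plain_function()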
Here is an example we should try to handle or at least give a warning for:

from pytorch_lightning import LightningModule

class Example(LightningModule):
    def __init__(this, arg):
        super().__init__()
        this.arg = arg

    def forward(self, x):
        pass

x = Example(1)
print(x.module_arguments)
# {'this': Example(), 'arg': 1}

module_arguments contains the object itself.
If the PR needs to be merged asap, I hope we can put it at least on the list of TODOs.
true, it would contain the `this` aka self, but it will be filtered out as it is not a primitive...
so you recommend doing the primitive filtering already here?
not sure, but we need to also inspect the constructor to get a list of accepted args and filter based on that, because of this example:

from pytorch_lightning import LightningModule

class Example(LightningModule):
    def __init__(self, arg):
        my_local_var = 2
        super().__init__()

    def forward(self, x):
        pass

x = Example(1)
print(x.module_arguments)
# {'arg': 1, 'my_local_var': 2}

This will fail when we try to restore and pass in the local var, which is not an argument in the constructor.
We should probably filter based on inspect.signature(Example.__init__).parameters and only save these locals, not the others.
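A quick standalone sketch of that filtering idea (not this PR's code):

import inspect

class FilteredExample:
    def __init__(self, arg):
        my_local_var = 2  # stray local that should not be saved
        local_vars = dict(inspect.currentframe().f_locals)
        init_params = inspect.signature(type(self).__init__).parameters
        # keep only real __init__ parameters, and drop self
        self.module_arguments = {
            name: value for name, value in local_vars.items()
            if name in init_params and name != 'self'
        }

print(FilteredExample(1).module_arguments)  # {'arg': 1}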
x = Example(1)
print(x.module_arguments)
# {'arg': 1, 'my_local_var': 2}

this works, as in the frame it always appears under the arg name...
and as it is saved, you can call x = Example(arg=1)
but there is an extra arg. Example.load_from_checkpoint(...) will load the module arguments

module_arguments = {'arg': 1, 'my_local_var': 2}

and then call

Example(**module_arguments)

which means my_local_var is passed in but this name is not accepted, thus yielding a TypeError: __init__() got an unexpected keyword argument 'my_local_var'.
I am 99% certain.
well, wait: if my_local_var is not among the __init__ arguments, it won't be present in module_arguments either...
@awaelchli mind adding an edge case to the test so we know the exact case and also we can truly test it... :]
in a follow-up or here? because William said he wants to merge asap
let’s add it as a follow up.
ready to merge? happy to merge now
Great job @Borda, thanks for the feedback @yukw777 @awaelchli! @PyTorchLightning/core-contributors play with this for a bit? also verify that hydra still works? @tullie
Thank you for pushing through this @Borda! Glad I was able to help out here.
Really cool idea, this PR! It will simplify checkpointing a lot :) I compiled a list to keep track of unresolved issues discussed in this thread here. Feel free to add anything I have missed.
Received a warning when running the Trainer
mind shooting an issue?
Sure!
This is awesome, when can we expect this in the release?
Wait, does this solve the issue of using params in jupyter notebooks? In other words, can I omit argparse with this?
Hi, what's the recommended way to use this with argparse?
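One pattern that should work with the new behavior (not an official recommendation; the model and arguments below are made up):

from argparse import ArgumentParser
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, learning_rate=1e-3, batch_size=32):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size

parser = ArgumentParser()
parser.add_argument('--learning_rate', type=float, default=1e-3)
parser.add_argument('--batch_size', type=int, default=32)
args = parser.parse_args()

# unpack the parsed Namespace straight into the init args
model = LitModel(**vars(args))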
@austinmw Yes
Problem
hparams was a temporary fix for not auto-storing the args passed by users. It's something everyone hacks around, it is not intuitive, and it makes the PL module somehow less like a PT module.
end of hparams!
This PR
This PR removes that and instead automatically collects the __init__ arguments and stores them with the checkpoint as module_arguments.
Backward compatibility
Summary
Before:
New:
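As an illustrative sketch of the change in usage (my own example, assuming the new auto-collection of init args, not the PR description's original snippets):

import pytorch_lightning as pl
from argparse import Namespace

# Before: arguments were bundled into a single hparams Namespace
class LitModelBefore(pl.LightningModule):
    def __init__(self, hparams):
        super().__init__()
        self.learning_rate = hparams.learning_rate
        self.batch_size = hparams.batch_size

model_old = LitModelBefore(Namespace(learning_rate=0.02, batch_size=32))

# New: plain init args; they are collected automatically and stored in the
# checkpoint under module_arguments, so load_from_checkpoint can rebuild the model
class LitModelAfter(pl.LightningModule):
    def __init__(self, learning_rate=0.02, batch_size=32):
        super().__init__()
        self.learning_rate = learning_rate
        self.batch_size = batch_size

model_new = LitModelAfter(batch_size=179)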