
add support for save_hyperparameters with Python Data Class #3494

Closed
tbenst opened this issue Sep 14, 2020 · 19 comments · Fixed by #7992
Labels
feature Is an improvement or enhancement help wanted Open to be worked on priority: 2 Low priority task
Comments

@tbenst

tbenst commented Sep 14, 2020

Python Data Classes are convenient in that they automatically generate a bunch of boilerplate code for assigning data to a class. They are particularly useful for PyTorch models that have a lot of hyperparameters and thus a lot of boilerplate.
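For illustration, a minimal stdlib-only sketch of the boilerplate that `@dataclass` generates (the `HParams` names here are hypothetical, not from the issue):

```python
from dataclasses import dataclass

# Without @dataclass: __init__ (and __repr__, __eq__, ...) written by hand.
class PlainHParams:
    def __init__(self, lr=1e-3, hidden=128):
        self.lr = lr
        self.hidden = hidden

# With @dataclass: the same methods are generated from the field declarations.
@dataclass
class HParams:
    lr: float = 1e-3
    hidden: int = 128

print(HParams())  # generated __repr__: HParams(lr=0.001, hidden=128)
```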

🐛 Bug

I believe #1896 introduced a new bug: when using a data class, save_hyperparameters no longer works, since it depends on `__init__` args while dataclasses do their setup in `__post_init__` instead. Explicitly passing argument names as strings does not work either.

Code sample

```python
import pytorch_lightning as pl
from dataclasses import dataclass

@dataclass
class ConvDecoder(pl.LightningModule):
    imageChannels: int = 3

    def __post_init__(self):
        super().__init__()
        # both of these fail:
        # self.save_hyperparameters()
        self.save_hyperparameters('imageChannels')

model = ConvDecoder()
model.hparams
```

Expected behavior

There should be a way to use save_hyperparameters with data classes.

@tbenst tbenst added bug Something isn't working help wanted Open to be worked on labels Sep 14, 2020
@awaelchli
Contributor

awaelchli commented Sep 14, 2020

It would be great to have this. However, I wouldn't say this is a bug: save_hyperparameters was never designed with data classes in mind, though the idea has been mentioned before.

Or does it say somewhere in the docs that it is supported?

@awaelchli awaelchli changed the title save_hyperparameters not compatible with Python Data Class add support for save_hyperparameters with Python Data Class Sep 14, 2020
@awaelchli awaelchli added feature Is an improvement or enhancement and removed bug Something isn't working labels Sep 14, 2020
@tbenst
Author

tbenst commented Sep 14, 2020

@awaelchli I think it's a bug: a method should depend on the object, not on the context in which it's called. As I understand it, in this case the method can only be called from `__init__`, which is not documented.

Edit: there is a reference to __init__ in the docs: https://pytorch-lightning.readthedocs.io/en/latest/hyperparameters.html#lightningmodule-hyperparameters. But worth noting that a Python Data class still passes all these args to __init__ so it does seem to meet the requirements as currently worded.

Edit 2: I just realized another issue with the current init-args-based implementation: if you use the learning rate finder, which mutates the lr in the model, the updated lr is not logged. The model logs the wrong lr because it is mutated after init.

@awaelchli awaelchli added the bug Something isn't working label Sep 16, 2020
@awaelchli
Contributor

OK, adding back the bug label.

> just realized that another issue with the current implementation based on init args is that if you use the learning rate finder, which mutates the lr in the model, then this updated lr is not logged.

Is that with the latest version? Can you make a separate issue please? Thanks

@tbenst
Author

tbenst commented Sep 16, 2020

Yes, I think it's appropriate to keep this in the same issue. The learning rate finder mutates the lr parameter of the LightningModule, which is reasonable; the problem is that save_hyperparameters captures the init args, when it should capture the parameters' current state at the time it is called. Let me know if you think differently, and I'm happy to open another issue.
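To make the staleness concrete, here is a toy sketch (a plain dataclass with hypothetical names, not Lightning code) contrasting an init-time snapshot with a call-time snapshot:

```python
from dataclasses import dataclass

@dataclass
class Model:
    lr: float = 1e-3

m = Model()
init_snapshot = {'lr': m.lr}     # what an init-args-based save would record
m.lr = 5e-4                      # the lr finder later mutates the attribute
current_snapshot = {'lr': m.lr}  # what a call-time save would record

print(init_snapshot, current_snapshot)  # the two disagree after mutation
```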

@edenlightning
Contributor

@tbenst would you like to send a PR with the changes? I'm still not convinced there is a bug.

@edenlightning edenlightning added this to the 0.9.x milestone Sep 23, 2020
@edenlightning edenlightning modified the milestones: 0.9.x, 1.0 Oct 4, 2020
@edenlightning edenlightning added v1.0 allowed and removed help wanted Open to be worked on labels Oct 4, 2020
@Borda Borda self-assigned this Oct 5, 2020
@edenlightning edenlightning modified the milestones: 1.0, 1.1 Oct 7, 2020
@edenlightning edenlightning removed the bug Something isn't working label Oct 19, 2020
@Borda Borda added help wanted Open to be worked on priority: 2 Low priority task labels Nov 13, 2020
@edenlightning edenlightning removed this from the 1.1 milestone Nov 18, 2020
@stale

stale bot commented Mar 19, 2021

This issue has been automatically marked as stale because it hasn't had any recent activity. This issue will be closed in 7 days if no further activity occurs. Thank you for your contributions, Pytorch Lightning Team!

@stale stale bot added the won't fix This will not be worked on label Mar 19, 2021
@stale stale bot closed this as completed Mar 29, 2021
@dave-epstein

Hi, what's the current status of this issue? Is there any way to make a LightningModule also be a Python dataclass while still being able to use save_hyperparameters to load successfully from checkpoints?

Thanks!

@awaelchli awaelchli added this to the v1.4 milestone Jun 15, 2021
@awaelchli awaelchli reopened this Jun 15, 2021
@stale stale bot removed the won't fix This will not be worked on label Jun 15, 2021
@awaelchli
Contributor

@dave-epstein Does this work?

```python
@dataclass
class ConvDecoder(pl.LightningModule):
    imageChannels: int = 3

    def __init__(self, **kwargs):
        self.save_hyperparameters()
        super().__init__(**kwargs)
```

@dave-epstein

dave-epstein commented Jun 15, 2021

Thanks for the prompt response. This isn't working in my terminal (Python version 3.9.4, PL version 1.3.4). I get the following error: AssertionError: failed to inspect the obj init (L201, utilities/parsing.py, in save_hyperparameters), which is the same error I got when I tried this workaround:

```python
@dataclass
class Test(LightningModule):
    x: int = 3
    y: float = 5.0

    def __post_init__(self):
        self.save_hyperparameters([f.name for f in fields(self)])
```

@awaelchli
Contributor

Sorry for the beginner question, but let's put save_hyperparameters aside for a second. How would we build an nn.Module with a dataclass without defining an `__init__`?

```python
@dataclass
class BoringModel(LightningModule):
    a: int

    def __post_init__(self):
        # self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)  # does not work??
```

@dave-epstein

In this case, you add super().__init__() at the beginning of __post_init__'s body, and then it all works as expected. super() here refers to LightningModule, not to anything dataclass-related. Note that the dataclass decorator does not alter the class's inheritance, which you can verify by running BoringModel.mro().

Excuse me if this turns out not to be 100% accurate either; this is just the best of my understanding :)
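A stdlib-only sketch of that pattern (a plain `Base` stands in for LightningModule, whose `__init__` must run before extra attributes are assigned):

```python
from dataclasses import dataclass

class Base:
    def __init__(self):
        self._modules = {}  # stand-in for nn.Module's internal state

@dataclass
class BoringModel(Base):
    a: int = 1

    def __post_init__(self):
        super().__init__()  # runs Base.__init__ after fields are assigned
        self.ready = True   # attributes beyond the declared fields

m = BoringModel(a=5)
print([c.__name__ for c in BoringModel.mro()])  # ['BoringModel', 'Base', 'object']
```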

@dave-epstein

I think you were trying to write an explicit __init__ and forward all the parameters to the dataclass __init__. The problem is that dataclass will not generate an __init__ method if one is explicitly defined. So a more promising solution is one that allows passing arbitrary field names to save_hyperparameters without them having to appear in __init__. I think this should be possible, since the frame inspection call used in get_init_args gives correct results.

That is, _, _, _, local_vars = inspect.getargvalues(f) on the correct frame (L132 of parsing.py) yields {'self': BoringModel(a=WHATEVER), 'a': WHATEVER}.
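A self-contained stdlib sketch of that frame inspection (`captured` is a hypothetical holder; the real code walks frames inside save_hyperparameters):

```python
import inspect
from dataclasses import dataclass

captured = {}

@dataclass
class BoringModel:
    a: int = 1

    def __post_init__(self):
        # one frame up is the dataclass-generated __init__
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        captured.update({k: v for k, v in local_vars.items() if k != 'self'})

BoringModel(a=7)
print(captured)  # {'a': 7}
```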

@awaelchli
Contributor

Okay, that works. In that case we can use fields() to determine the hyperparameters when the model is a dataclass (thanks for the tip!). So we need to modify save_hyperparameters to support this, I think.

@dave-epstein

dave-epstein commented Jun 15, 2021

Yes, that makes sense to me. Note that you can use dataclasses.is_dataclass for this check.
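A sketch of that check with a hypothetical helper (`collect_hparams` is not the actual Lightning implementation, just the idea):

```python
import dataclasses

def collect_hparams(obj):
    # Hypothetical helper: for a dataclass instance, read hyperparameters
    # from its declared fields instead of inspecting __init__ arguments.
    if not dataclasses.is_dataclass(obj):
        raise TypeError(f"{type(obj).__name__} is not a dataclass")
    return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}

@dataclasses.dataclass
class Test:
    x: int = 3
    y: float = 5.0

print(collect_hparams(Test()))  # {'x': 3, 'y': 5.0}
```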

@awaelchli
Contributor

Can't guarantee it covers all edge cases; running tests now. Here is the draft: #7992
Of course, any input is highly welcome.

@dave-epstein

That change looks reasonable to me, but I'm also not a dataclass expert. I think a lot of the heavy lifting is done by fields(), so save_hyperparameters should then automatically respect the semantics of the dataclass. Thanks for moving so fast on this!

@dave-epstein

Maybe I'm missing something here but I didn't find a release where this commit was added. Has it just not been incorporated yet?

Thanks!

@awaelchli
Contributor

This is a new feature that will be available with 1.4.
The 1.3.x releases contain bugfixes, but no new features.

@dave-epstein

Makes sense. Thanks :)
