add support for save_hyperparameters with Python Data Class #3494
Comments
It would be great to have this. However, I wouldn't say this is a bug. `save_hyperparameters` was never designed with data classes in mind, although the idea has been mentioned before. Or does it say somewhere in the docs that it is supported?
@awaelchli I think it's a bug, as a method should depend on the object, not on the context in which it's called. IIUC, in this case the method can only be called from […]

Edit: there is a reference to […]

Edit2: I just realized that another issue with the current implementation, which is based on init args, is that if you use the learning rate finder, which mutates the lr in the model, the updated lr is not logged. The model logs the wrong lr, since the lr is mutated after init.
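To illustrate the point about calling context: as far as I understand, `save_hyperparameters()` with no arguments inspects the *caller's* stack frame to recover the init arguments. The sketch below is not Lightning's code; it uses a hypothetical, heavily simplified helper `caller_args()` built on `inspect.getargvalues` to show why frame inspection comes up empty in a dataclass's `__post_init__`: the generated `__init__` receives the fields as arguments, but `__post_init__` only receives `self`.

```python
import inspect
from dataclasses import dataclass


def caller_args():
    """Hypothetical stand-in for frame-based hyperparameter capture:
    return the named arguments (minus ``self``) of whoever called us."""
    frame = inspect.currentframe().f_back
    arg_names, _, _, local_values = inspect.getargvalues(frame)
    return {name: local_values[name] for name in arg_names if name != "self"}


class PlainModel:
    def __init__(self, lr: float = 1e-3, layers: int = 2):
        # Called from __init__: this frame holds lr and layers as arguments.
        self.hparams = caller_args()


@dataclass
class DataclassModel:
    lr: float = 1e-3
    layers: int = 2

    def __post_init__(self):
        # Called from __post_init__: the only argument in this frame is self.
        self.hparams = caller_args()


print(PlainModel(lr=0.1).hparams)      # {'lr': 0.1, 'layers': 2}
print(DataclassModel(lr=0.1).hparams)  # {}
```

The dataclass instance still has `lr == 0.1` as an attribute; it is only the frame-based lookup that cannot see it.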
OK, adding back the bug label.
Is that with the latest version? Can you make a separate issue, please? Thanks!
Yes. I think it's appropriate to keep this in the same issue: the learning rate finder mutates the lr parameter of the LightningModule, which is reasonable, and the problem is that `save_hyperparameters` is based on the init args, when it should save the parameters' current state at the time `save_hyperparameters` is called. Let me know if you think differently; I'm happy to open another issue.
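A minimal plain-Python sketch of the difference being described, with no Lightning involved; `Model`, `init_hparams`, `current_hparams` and `lr_finder_stub` are all hypothetical names. Recording hyperparameters from the init args takes a snapshot at construction time, so a later in-place change (as the learning rate finder makes to `lr`) is missed, whereas reading the attributes at the moment of saving picks it up.

```python
class Model:
    def __init__(self, lr: float = 1e-3):
        self.lr = lr
        # Snapshot of the init args, taken once at construction time.
        self.init_hparams = {"lr": lr}

    def current_hparams(self):
        # Read the attribute's value at the moment this is called.
        return {"lr": self.lr}


def lr_finder_stub(model: Model) -> None:
    """Hypothetical stand-in for a tuner that overwrites the model's lr in place."""
    model.lr = 0.02


model = Model()
lr_finder_stub(model)
print(model.init_hparams)       # {'lr': 0.001}  (stale)
print(model.current_hparams())  # {'lr': 0.02}
```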
@tbenst would you like to send a PR with the changes? I'm still not sure what the bug is.
This issue has been automatically marked as stale because it hasn't had any recent activity. It will be closed in 7 days if no further activity occurs. Thank you for your contributions, PyTorch Lightning Team!
Hi, what's the current status of this issue? Is there any way to make a […] Thanks!
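@dave-epstein Does this work?

```python
from dataclasses import dataclass
import pytorch_lightning as pl

@dataclass()
class ConvDecoder(pl.LightningModule):
    imageChannels: int = 3

    def __init__(self, **kwargs):
        # nn.Module must be initialized before anything is assigned to self
        super().__init__(**kwargs)
        self.save_hyperparameters()
```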
Thanks for the prompt response. This isn't working in my terminal (Python version 3.9.4, PL version 1.3.4). I get the following error: […]

```python
from dataclasses import dataclass, fields
from pytorch_lightning import LightningModule

@dataclass
class Test(LightningModule):
    x: int = 3
    y: float = 5

    def __post_init__(self):
        self.save_hyperparameters([f.name for f in fields(self)])
```
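Sorry for the beginner questions, but let's put […]

```python
from dataclasses import dataclass
import torch
from pytorch_lightning import LightningModule

@dataclass
class BoringModel(LightningModule):
    a: int

    def __post_init__(self):
        # self.save_hyperparameters()
        self.layer = torch.nn.Linear(32, 2)  # does not work??
```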
In this case, you add […] Excuse me if this turns out not to be 100% accurate either; this is just the best of my understanding :)
I think you were trying to write an explicit […] That is, […]
Okay, that works. In that case, I think we can use `fields()` (thanks for the tip!) to determine the hyperparameters when the model is a dataclass. So I think we need to modify the hyperparameter-saving code to support this feature.
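A minimal plain-Python sketch of the `fields()` idea; `dataclass_hparams` and `DecoderConfig` are hypothetical names, and this is not the code from the PR. It maps every dataclass field to its *current* value, so it works from `__post_init__` (where no init args are visible in the frame) or at any later point.

```python
from dataclasses import dataclass, fields, is_dataclass
from typing import Any, Dict


def dataclass_hparams(obj: Any) -> Dict[str, Any]:
    """Hypothetical helper: map each dataclass field name to its current value."""
    # is_dataclass() is also True for the class itself, so reject classes explicitly.
    if not is_dataclass(obj) or isinstance(obj, type):
        raise TypeError(f"expected a dataclass instance, got {obj!r}")
    return {f.name: getattr(obj, f.name) for f in fields(obj)}


@dataclass
class DecoderConfig:
    image_channels: int = 3
    lr: float = 1e-3

    def __post_init__(self):
        # All fields are already assigned when the generated __init__ calls this hook.
        self.hparams = dataclass_hparams(self)


config = DecoderConfig(image_channels=1)
print(config.hparams)  # {'image_channels': 1, 'lr': 0.001}
```

`fields()` also returns fields declared with `field(init=False)`; filtering on `f.init` would restrict the result to constructor arguments, which is closer to the current init-args behaviour. Which of the two is preferable is a design choice.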
Yes, I think that makes sense to me. Note that you can use […]
I can't guarantee it covers all edge cases; running tests now. Here is the draft: #7992
That change looks reasonable to me, but I'm also not a dataclass expert. I think a lot of the heavy lifting is done by […]
Maybe I'm missing something here, but I couldn't find a release that includes this commit. Has it just not been incorporated yet? Thanks!
This is a new feature that will be available in 1.4.
Makes sense. Thanks :)
Python Data Classes are convenient in that they automatically generate a bunch of boilerplate code for assigning data to a class. They are particularly useful for PyTorch models that have a lot of hyperparameters and thus a lot of boilerplate.
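For readers less familiar with data classes, a plain-Python sketch (class and field names are made up for illustration): `@dataclass` writes the `__init__` that would otherwise be typed out by hand, and the generated `__init__` calls `__post_init__` only after every field has been assigned, which is why setup code goes there instead of into `__init__`.

```python
from dataclasses import dataclass


class HandWritten:
    def __init__(self, image_channels: int = 3, hidden_dim: int = 64, lr: float = 1e-3):
        # Every hyperparameter is assigned by hand.
        self.image_channels = image_channels
        self.hidden_dim = hidden_dim
        self.lr = lr


@dataclass
class Generated:
    # @dataclass generates an equivalent __init__ (plus __repr__ and __eq__) from these annotations.
    image_channels: int = 3
    hidden_dim: int = 64
    lr: float = 1e-3

    def __post_init__(self):
        # Runs at the end of the generated __init__, after all fields are set.
        self.seen_in_post_init = (self.image_channels, self.hidden_dim, self.lr)


print(Generated(hidden_dim=128))  # Generated(image_channels=3, hidden_dim=128, lr=0.001)
```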
🐛 Bug
I believe #1896 introduced a new bug: when using a data class, `save_hyperparameters` no longer works, since it depends on init args and with dataclasses we use `__post_init__` instead. Explicitly passing strings does not work either. Perhaps when passing […]
Code sample
Expected behavior
There should be a way to use `save_hyperparameters` with Data Classes.