Potential issue with AsdlInterface backend for SubnetLaplace (#271)
Unfortunately, the maintainer of … If you only need to do subnet Laplace "per-tensor", you can simply disable the gradients of the parameters you don't want to apply Laplace to and then do standard Laplace. See https://aleximmer.github.io/Laplace/#subnetwork-laplace, last paragraph. I'm happy to review PRs if you are keen, but I don't plan to maintain …
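A minimal sketch of the per-tensor workaround, assuming a toy `nn.Sequential` model (hypothetical) in which only the last layer should get a Laplace approximation. Gradients are switched off for every other parameter; the model would then be passed to the regular `Laplace` class, presumably with `subset_of_weights="all"` (check the linked documentation). That call is omitted so the sketch depends on PyTorch only.

```python
import torch.nn as nn

# Toy model (hypothetical): 16 + 10 = 26 parameters in total.
model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))

# Switch off gradients everywhere, then re-enable them only for the
# tensors that should receive a Laplace approximation (here: the last layer).
for p in model.parameters():
    p.requires_grad_(False)
for p in model[-1].parameters():
    p.requires_grad_(True)

# Number of parameters a gradient-based backend would still see.
n_laplace = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_laplace)  # 10
```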
Yeah, I tried doing …
This error is there because, to get the gradients/Jacobians of the first layer, all subsequent layers need to have … May I just ask why you plan to deprecate …?
I see, … Sure, … In any case, feel free to open a PR. Your question of whether it will break anything can be answered by running (and adding) unit tests :)
Regarding the "usefulness" of … Sounds good, I will open a PR in the next few days!
Hi!
I have a question about the `.diag()` method in the `AsdlInterface` backend class. When using `AsdlGGN(stochastic=True)`, I ran into issues because the dimension of `diag_ggn` did not match the number of model parameters. After looking into it more, I found that the issue was in the `for` loop over `self.model.modules()` (see below):

Laplace/laplace/curvature/asdl.py, lines 200 to 205 in 6f17099
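The skipping behaviour described above can be reproduced outside the library. This is a sketch, not the actual `asdl.py` code: `diag_stats` is a hypothetical stand-in for the per-module statistics the backend reads, and the first `Linear` layer is given `None` to mimic `fisher=None`.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
n_params = sum(p.numel() for p in model.parameters())  # 16 + 10 = 26

# Hypothetical per-module diagonal statistics (one flat vector per module
# that owns parameters). model[0] has none, mimicking fisher=None.
diag_stats = {
    model[0]: None,
    model[2]: torch.rand(sum(p.numel() for p in model[2].parameters())),
}

# Loop in the style described in the issue: every module without statistics
# (the Sequential container, the ReLU, and model[0] here) is skipped.
vec = []
for module in model.modules():
    stats = diag_stats.get(module)
    if stats is None:
        continue
    vec.append(stats)
diag_ggn = torch.cat(vec)

print(diag_ggn.numel(), n_params)  # 10 26
```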
Concretely, the problem seems to be that when the `if` statement evaluates to `True`, nothing is inserted into `vec`, so it is not surprising that the length of `diag_ggn` is not equal to the number of model parameters. That is a problem when using `self.subnetwork_indices`, since for those to work correctly the length of `diag_ggn` needs to equal the number of parameters in `self.model`:

Laplace/laplace/curvature/asdl.py, lines 206 to 207 in 6f17099
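Why the length matters: subnetwork indices are positions in the full flattened parameter vector, so indexing a shortened diagonal either selects the wrong parameter or goes out of range. A small illustration with a hypothetical 26-parameter layout (16 entries for a first layer, 10 for a second) and made-up indices:

```python
import torch

# Hypothetical layout matching Linear(3, 4) -> Linear(4, 2):
# positions 0-15 belong to the first layer, 16-25 to the second.
n_params = 26
diag_full = torch.arange(n_params, dtype=torch.float32)  # entry i <-> parameter i
diag_short = diag_full[16:]  # what remains if the first layer is skipped

# Subnetwork indices address positions in the *full* parameter vector.
subnetwork_indices = torch.tensor([5, 20])

print(diag_full[subnetwork_indices].tolist())  # [5.0, 20.0]
print(diag_short[torch.tensor([5])].tolist())  # [21.0]: a different parameter

raised = False
try:
    diag_short[subnetwork_indices]  # index 20 is out of range for size 10
except IndexError as err:
    raised = True
    print("IndexError:", err)
```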
After changing the implementation so that a zero vector is inserted whenever the `if` statement evaluates to `True`, the issue was resolved.

I am wondering whether you think my fix makes sense when using `DiagSubnetLaplace` with the `AsdlGGN(stochastic=True)` backend. Or does the fact that some modules have `fisher=None` point to a problem elsewhere?
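The zero-padding fix described in the issue can be sketched as follows. This is an illustration under assumptions, not the patch itself: `diag_stats` is again a hypothetical stand-in for the per-module statistics, each module's statistics are assumed to be flattened in the same order as its own parameters, and the zero block is sized by the parameters a module owns directly (`recurse=False`) so that containers such as `nn.Sequential` are not counted twice.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 2))
n_params = sum(p.numel() for p in model.parameters())

# Hypothetical per-module diagonals; model[0] has none (fisher=None).
diag_stats = {
    model[0]: None,
    model[2]: torch.ones(sum(p.numel() for p in model[2].parameters())),
}

vec = []
for module in model.modules():
    # Count only parameters owned directly by this module.
    n_own = sum(p.numel() for p in module.parameters(recurse=False))
    if n_own == 0:
        continue  # containers and activations own no parameters
    stats = diag_stats.get(module)
    if stats is None:
        # Pad with zeros instead of skipping, so later entries keep
        # the positions that subnetwork indices expect.
        stats = torch.zeros(n_own)
    vec.append(stats)
diag_ggn = torch.cat(vec)

print(diag_ggn.numel() == n_params)  # True
```

Zero-padding keeps every entry aligned with `model.parameters()`, which is what the subnetwork indexing relies on. It also treats the padded weights as having zero curvature, so in a diagonal Laplace their posterior precision falls back to the prior precision; whether that is acceptable, or whether `fisher=None` signals a bug upstream, is the open question in the issue.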