
Potential issue with AsdlInterface backend for SubnetLaplace #271

Open
metodj opened this issue Jan 14, 2025 · 4 comments

Comments

@metodj
Contributor

metodj commented Jan 14, 2025

Hi!

I have a question about the .diag() method in the AsdlInterface backend class. When using AsdlGGN(stochastic=True), I was running into issues because the dimension of diag_ggn did not match the number of model parameters. After looking into it further, I figured out that the issue was in the for loop over self.model.modules() (see below):

vec = list()
for module in self.model.modules():
    stats = getattr(module, "fisher", None)
    if stats is None:
        continue
    vec.extend(stats.to_vector())
diag_ggn = torch.cat(vec)

Concretely, the problem seemed to be that when the if statement evaluates to True, nothing is inserted into vec, so it is not surprising that the length of diag_ggn does not equal the number of model parameters. That is a problem when using self.subnetwork_indices, since for those to work correctly the length of diag_ggn needs to equal the number of parameters in self.model:

if self.subnetwork_indices is not None:
    diag_ggn = diag_ggn[self.subnetwork_indices]
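
To illustrate the mismatch with made-up shapes (a minimal sketch, not code from the library): subnetwork_indices index into the full flattened parameter vector, so if diag_ggn is shorter than that, the indexing either raises or silently selects the wrong entries.

import torch

n_params = 10                                      # total number of model parameters
subnetwork_indices = torch.tensor([0, 1, 8, 9])    # e.g., first- and last-layer params

diag_ggn_full = torch.zeros(n_params)              # expected length: n_params
diag_ggn_short = torch.zeros(6)                    # length when some modules were skipped

diag_ggn_full[subnetwork_indices]                  # works as intended
# diag_ggn_short[subnetwork_indices]               # IndexError: index 8 is out of bounds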

After changing the implementation so that a zero vector is inserted whenever the if statement evaluates to True, the issue was resolved:

def get_module_param_count(module):
    return sum(p.numel() for p in module.parameters(recurse=False))

device = f.device  # f is the model output computed earlier in .diag()

vec = list()
for module in self.model.modules():
    stats = getattr(module, "fisher", None)
    if stats is None:
        # Pad with zeros so that len(diag_ggn) matches the total number of
        # model parameters even for modules without Fisher stats.
        n_params = get_module_param_count(module)
        if n_params > 0:
            vec.append(torch.zeros(n_params, device=device))
        continue

    vec.extend(stats.to_vector())
diag_ggn = torch.cat(vec)

I am wondering whether you think my fix makes sense for the case of using DiagSubnetLaplace with the AsdlGGN(stochastic=True) backend? Or does the fact that some modules have fisher=None point to a problem elsewhere?

@wiseodd
Collaborator

wiseodd commented Jan 19, 2025

Unfortunately, the maintainer of SubnetLaplace (Erik) is not involved anymore.

If you only need to do subnet Laplace "per-tensor", you can just disable the grads of the params you don't want to apply Laplace on and do standard Laplace.

See https://aleximmer.github.io/Laplace/#subnetwork-laplace, last paragraph.
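
For reference, a minimal sketch of that approach (the model, train_loader, and the "head" module name are hypothetical, not from your setup): freeze everything except the tensors you want treated probabilistically, then fit a standard Laplace over the remaining trainable params.

from laplace import Laplace

# Hypothetical example: apply Laplace only to the parameters of model.head
for name, p in model.named_parameters():
    p.requires_grad_(name.startswith("head"))

la = Laplace(model, "classification",
             subset_of_weights="all",
             hessian_structure="diag")
la.fit(train_loader)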

I'm happy to review PRs if you are keen. But I don't plan to maintain SubnetLaplace. Maybe we will even deprecate it soon.

@metodj
Contributor Author

metodj commented Jan 20, 2025

Yeah, I tried setting p.requires_grad = False as suggested in the subnetwork example. However, in my case I want to do a first-and-last-layer Laplace approximation (e.g., as done in some experiments in Do Bayesian Neural Networks Need To Be Fully Stochastic?), so I think this option is not suitable, as I kept getting the error:

RuntimeError: One of the differentiated Tensors does not require grad

This error occurs because, to get the gradients/Jacobians of the first layer, all subsequent layers need to have p.requires_grad=True, I think. Because of that, I resorted to using SubnetLaplace with the AsdlGGN(stochastic=True) backend instead.

May I ask why you plan to deprecate SubnetLaplace? From my understanding, it is currently still the best option for subnetwork choices like mine, i.e., first-and-last-layer. Or am I overlooking something, and I actually do not need SubnetLaplace to do first-and-last-layer Laplace?

@wiseodd
Collaborator

wiseodd commented Jan 20, 2025

I see, p.requires_grad = False won't work since the backward path in the computation graph is blocked.

Sure, SubnetLaplace is more flexible, but the issue is that it's unmaintained and won't be given attention in the long run. Plus, it's not useful for large-scale NNs (like LLMs) since the Hessian/Jacobian of all params needs to be computed, unlike with the requires_grad = False approach. Hence, it's only useful for smallish nets, in which case, why not just use standard Laplace?

In any case, feel free to open a PR. Your question on whether it will break anything can be answered by running (and adding) unit tests :)

@metodj
Contributor Author

metodj commented Jan 21, 2025

Regarding the "usefulness" of SubnetLaplace: using the Laplace library, I was able to fit a first-and-last-layer LA approximation on ~0.5B-parameter NNs. The Laplace library was crucial here, since I could make use of some more efficient EF/GGN approximations in the AsdlGGN backend (i.e., using the MC Fisher instead of the exact Fisher). So I would still say SubnetLaplace can be useful on larger nets, and the Laplace library makes its implementation easier.
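
Roughly, the setup looks like this (a sketch; the manual index construction, the backend_kwargs passing, and the "first"/"last" module-name prefixes are my assumptions for illustration):

import torch
from laplace import Laplace
from laplace.curvature import AsdlGGN

# Build first-and-last-layer indices by hand from the flattened parameter vector.
indices, offset = [], 0
for name, p in model.named_parameters():
    if name.startswith("first") or name.startswith("last"):
        indices.append(torch.arange(offset, offset + p.numel()))
    offset += p.numel()
subnetwork_indices = torch.cat(indices)

la = Laplace(model, "classification",
             subset_of_weights="subnetwork",
             hessian_structure="diag",
             subnetwork_indices=subnetwork_indices,
             backend=AsdlGGN,
             backend_kwargs={"stochastic": True})   # MC Fisher instead of exact Fisher
la.fit(train_loader)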

Sounds good, will open a PR in the next few days!
