
Clarification Needed: GLM Predictive Functionality with Subset of Weights in ParametricLaplace #277

Open
xiaoyiming1999 opened this issue Feb 9, 2025 · 2 comments

Comments

@xiaoyiming1999

Hi, I would like to express my sincere gratitude for your work in the field of Laplace approximation. Your contributions have not only benefited me greatly but also significantly advanced the development of this area.

Recently, while studying the laplace library you developed, I noticed a comment in the __call__ method of the ParametricLaplace class:
"The GLM predictive is consistent with the curvature approximations used here. When Laplace is done only on a subset of parameters (i.e., some gradients are disabled), only nn predictive is supported."

This comment has left me puzzled because, in my experiments, when I set subset_of_weights to 'subnetwork', I was still able to use pred_type='glm' without any issues. In the paper "Bayesian Deep Learning via Subnetwork Inference", I also saw you combine the linearized Laplace approximation with subnetwork inference. I would greatly appreciate it if you could spare some time to clarify this matter for me. Your insights would be invaluable to my research, and I will do my utmost to promote your work within my field.

@aleximmer
Owner

Hi, thanks for your interest and question! Would you mind sharing an example of how you used the library so we can confirm this? Generally, SubnetworkLaplace should not support the glm predictive because that would require an implementation of functional_covariance, which is not implemented by default:

def functional_covariance(self, Js: torch.Tensor) -> torch.Tensor:

In principle, this could be added, but currently it should not be supported as far as I know.
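To make the missing piece concrete, here is a hedged numpy sketch (not the laplace library's actual code) of what a functional covariance would have to compute for the GLM predictive: f_var = J @ Sigma @ J.T, where J is the Jacobian of the network outputs with respect to the (subnetwork) parameters and Sigma is the posterior covariance over those parameters. All names and shapes below are illustrative assumptions.

```python
import numpy as np

# Illustrative sketch of the GLM functional covariance, f_var = J @ Sigma @ J.T.
# J: Jacobian of the n_out network outputs w.r.t. the n_params subnetwork weights.
# Sigma: posterior covariance over those same weights (must be SPD).
rng = np.random.default_rng(0)
n_out, n_params = 2, 20
J = rng.standard_normal((n_out, n_params))   # one Jacobian row per output
A = rng.standard_normal((n_params, n_params))
Sigma = A @ A.T + np.eye(n_params)           # SPD posterior covariance
f_var = J @ Sigma @ J.T                      # functional (output-space) covariance
```

Without an implementation of this mapping for the subnetwork case, the GLM predictive has no output-space covariance to work with, which is why the docstring restricts subset-of-weights Laplace to the nn predictive.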

@xiaoyiming1999
Author

import torch
import torch.nn as nn
from laplace import Laplace

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc_1 = nn.Linear(10, 5)
        self.fc_2 = nn.Linear(5, 2)

    def forward(self, x):
        return self.fc_2(self.fc_1(x))

parameter_index = list(range(20))
subnetwork_indices = torch.as_tensor(parameter_index)
model = SimpleModel()
la = Laplace(
    model,
    likelihood='classification',
    subset_of_weights='subnetwork',
    subnetwork_indices=subnetwork_indices,
    hessian_structure='full'
)

X = torch.randn(1000, 10)
y = torch.randint(0, 2, (1000,))
loader = torch.utils.data.DataLoader(list(zip(X, y)), batch_size=10)

la.fit(loader)

x_test = torch.randn(3, 10)

# number of MC samples
n_samples = 1000

# For model uncertainty and data uncertainty
samples = la.predictive_samples(x_test, n_samples=n_samples, pred_type='glm')

# GLM prediction
pred = la(x_test, pred_type='glm', link_approx='probit')

This is an example of how I use the library: la.predictive_samples(x_test, n_samples=n_samples, pred_type='glm') is used to easily measure uncertainty, and la(x_test, pred_type='glm', link_approx='probit') is used to make predictions directly. Both combine subnetwork inference with the GLM predictive, yet no errors were raised about them conflicting.
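For reference, link_approx='probit' does not require MC sampling: it maps the GLM Gaussian over the logits to approximate class probabilities in closed form. Below is a hedged numpy sketch of that standard probit approximation, p ≈ softmax(f_mu / sqrt(1 + (pi/8) * diag(f_var))); the function name and shapes are my own illustrative assumptions, not the library's internals.

```python
import numpy as np

def probit_predictive(f_mu, f_var):
    """Approximate class probabilities from Gaussian logits via the probit trick.

    f_mu:  mean logits, shape (batch, n_classes)
    f_var: predictive covariance over logits, shape (batch, n_classes, n_classes)
    """
    # Scale each logit by 1 / sqrt(1 + (pi/8) * var) before the softmax.
    kappa = 1.0 / np.sqrt(1.0 + (np.pi / 8.0)
                          * np.diagonal(f_var, axis1=-2, axis2=-1))
    z = kappa * f_mu
    z = z - z.max(axis=-1, keepdims=True)   # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

f_mu = np.array([[1.0, -0.5]])                    # 1 input, 2 classes
f_var = np.array([[[0.3, 0.0], [0.0, 2.0]]])      # predictive covariance
p = probit_predictive(f_mu, f_var)
```

Note that this, too, consumes diag(f_var), i.e. exactly the functional covariance discussed above, so a GLM predictive that silently runs without it would be worth double-checking.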
