Handle logits with more than 2 axes in torch.func diag and full backends #178
Labels: enhancement (New feature or request)
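Currently, the torch.func diag and full backends still assume that logits.shape == (batch_size, n_classes), so logits with more than 2 axes (e.g. (batch_size, seq_len, n_classes)) are not handled. E.g., Laplace/laplace/curvature/curvature.py, lines 332 to 334 at commit a4d3ed6.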
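One possible way to lift the 2-axis assumption is sketched below. It is not laplace-torch's actual implementation. The idea is to fold every axis except the last (class) axis into the batch axis, for both the logits and the per-sample Jacobians computed with torch.func. This is exact only for losses that decompose as a sum over those extra positions, such as token-wise cross-entropy. `flatten_logits` and `per_sample_logits` are hypothetical helpers, and the toy `torch.nn.Linear` model stands in for a real sequence model.

```python
import torch
from torch.func import functional_call, jacrev, vmap


def flatten_logits(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: merge all leading axes, keep the class axis last.

    (batch_size, n_classes)          -> unchanged
    (batch_size, seq_len, n_classes) -> (batch_size * seq_len, n_classes)
    """
    return logits.reshape(-1, logits.shape[-1])


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)  # maps (..., 4) -> (..., 3), i.e. 3 classes
params = {k: v.detach() for k, v in model.named_parameters()}

x = torch.randn(2, 5, 4)  # (batch_size=2, seq_len=5, in_features=4)
logits = model(x)  # 3-D logits: (2, 5, 3)
logits_2d = flatten_logits(logits)  # (10, 3)


def per_sample_logits(p, xi):
    # Hypothetical helper: logits of ONE batch element, shape (seq_len, n_classes).
    return functional_call(model, p, (xi,))


# Per-sample Jacobians w.r.t. the parameter dict; vmap maps over x's batch axis only.
jac = vmap(jacrev(per_sample_logits), in_dims=(None, 0))(params, x)
# jac["weight"]: (2, 5, 3, 3, 4), jac["bias"]: (2, 5, 3, 3)

# Flatten each parameter's trailing axes, concatenate in params order: (B, L, C, P), P = 15.
J = torch.cat([jac[k].flatten(start_dim=3) for k in params], dim=-1)

# Merge batch and sequence axes, mirroring flatten_logits: (B * L, C, P).
J_2d = J.reshape(-1, J.shape[-2], J.shape[-1])
print(logits_2d.shape, J_2d.shape)
```

Reshaping at the boundary, rather than rewriting each backend, means that code which already handles (batch_size, n_classes) logits could, under the loss-decomposition assumption above, consume the flattened tensors unchanged.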