Improve Bures-Wasserstein distance #468
Conversation
It seems like the TensorFlow backend does not support […]
Hello @francois-rozet, when a specific backend misses a feature we usually add it to the backend; in this case maybe we should add […]
Also @francois-rozet, could you please do a quick test of the timing before and after your speedup for different backends (at least NumPy and PyTorch) and put it in the text description of the PR? I like to have quantified performance gains in the history/GitHub to know why we changed stuff. Also, maybe add a quick test that checks that the new function returns the same thing as the `np.trace()` version up to numerical precision. It seems right, but such a test will help detect potential problems in the future.
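For illustration, such an equivalence test could look like the sketch below, written directly against NumPy/SciPy rather than a specific POT backend (the matrix size and tolerance are illustrative):

```python
import numpy as np
import scipy.linalg as sl

def test_trace_sqrtm_via_eigenvalues():
    # Build a random symmetric positive definite matrix.
    rng = np.random.default_rng(42)
    B = rng.standard_normal((64, 64))
    A = B @ B.T + 1e-6 * np.eye(64)

    # Reference: trace of the explicit matrix square root.
    reference = np.trace(sl.sqrtm(A)).real

    # Candidate: sum of the square roots of the eigenvalues.
    candidate = np.sum(np.sqrt(np.linalg.eigvalsh(A)))

    np.testing.assert_allclose(candidate, reference, rtol=1e-8)
```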
After some tests, I found that this implementation is only faster for NumPy. Computing the square root of a general matrix is indeed slower than computing its eigenvalues. However, computing the square root of a symmetric matrix takes more or less the same time as computing its eigenvalues. In fact, the PyTorch backend uses […]
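For reference, a minimal sketch of the eigenvalue route in PyTorch, assuming a symmetric positive semi-definite input (the clamp guards against tiny negative eigenvalues caused by round-off):

```python
import torch

# Symmetric positive definite test matrix.
A = torch.randn(512, 512, dtype=torch.float64) / 512 ** 0.5
A = A @ A.T + 1e-6 * torch.eye(512, dtype=torch.float64)

# eigvalsh is specialised for symmetric (Hermitian) matrices and returns
# real eigenvalues, so tr(sqrt(A)) is just their square roots summed.
trace_sqrt = torch.linalg.eigvalsh(A).clamp(min=0).sqrt().sum()
```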
OK, it makes sense, happy I pushed you to investigate. You can leave the eigvals function in the backend; it can be useful in the future (trace norm regularization, for instance).
Oops, I already removed it. Are the eigvals or the singular values necessary for trace norm regularization? And is it for a symmetric matrix? It turns out that […]
OK, no worries, we can add them (properly, depending on symmetry or not) later. I'm nearly OK for a merge, but please add a short description of the PR in the RELEASES file.
I added a line to the RELEASES file. For reference, the benchmark

```python
import numpy as np
import scipy.linalg as sl

A = np.random.rand(512, 512) / 512 ** 0.5
A = A @ A.T + np.eye(512) * 1e-6  # positive definite

%timeit np.linalg.eigh(A)
%timeit sl.sqrtm(A)
```

returns […]. The time gap increases for larger matrices.
All tests have passed, except the CircleCI one.
Types of changes
The implementation is based on two facts:
- the trace of a matrix is the sum of its eigenvalues;
- the eigenvalues of $\sqrt{A}$ are the square roots of the eigenvalues of $A$.

Then, $\mathrm{tr}(\sqrt{A})$ is the sum of the square roots of the eigenvalues of $A$.
See Lightning-AI/torchmetrics#1705.
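Spelled out, for a symmetric positive semi-definite $A$ with eigenvalues $\lambda_1(A), \dots, \lambda_n(A)$, these two facts combine as

$$\mathrm{tr}\left(\sqrt{A}\right) = \sum_{i=1}^{n} \lambda_i\left(\sqrt{A}\right) = \sum_{i=1}^{n} \sqrt{\lambda_i(A)},$$

which the new implementation computes with a symmetric eigenvalue solver instead of an explicit matrix square root.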
Motivation and context / Related issue
Computing the square root of a matrix is slow and numerically unstable.
How has this been tested (if it applies)
The new implementation still passes the tests (at least with the NumPy backend).
PR checklist