diff --git a/RELEASES.md b/RELEASES.md index 3366e2adb..586089bda 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -11,6 +11,7 @@ - Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457) - Major documentation cleanup (PR #462, #467) - Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466) +- Faster Bures-Wasserstein distance with NumPy backend (PR #468) ## 0.9.0 diff --git a/ot/backend.py b/ot/backend.py index a82c4486a..eecf9dd99 100644 --- a/ot/backend.py +++ b/ot/backend.py @@ -1235,7 +1235,8 @@ def inv(self, a): return scipy.linalg.inv(a) def sqrtm(self, a): - return scipy.linalg.sqrtm(a) + L, V = np.linalg.eigh(a) + return (V * np.sqrt(L)[None, :]) @ V.T def kl_div(self, p, q, eps=1e-16): return np.sum(p * np.log(p / q + eps)) @@ -2433,7 +2434,7 @@ def inv(self, a): def sqrtm(self, a): L, V = cp.linalg.eigh(a) - return (V * self.sqrt(L)[None, :]) @ V.T + return (V * cp.sqrt(L)[None, :]) @ V.T def kl_div(self, p, q, eps=1e-16): return cp.sum(p * cp.log(p / q + eps)) @@ -2824,7 +2825,8 @@ def inv(self, a): return tf.linalg.inv(a) def sqrtm(self, a): - return tf.linalg.sqrtm(a) + L, V = tf.linalg.eigh(a) + return (V * tf.sqrt(L)[None, :]) @ V.T def kl_div(self, p, q, eps=1e-16): return tnp.sum(p * tnp.log(p / q + eps)) diff --git a/ot/gaussian.py b/ot/gaussian.py index 4ffb726e1..1a295567d 100644 --- a/ot/gaussian.py +++ b/ot/gaussian.py @@ -202,7 +202,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False): where : .. math:: - \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) Parameters ---------- @@ -264,7 +264,7 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None, where : .. math:: - \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) + \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right) Parameters ----------