Skip to content

Commit

Permalink
Improve Bures-Wasserstein distance (#468)
Browse files Browse the repository at this point in the history
* Improve Bures-Wasserstein distance

* Revert changes and modify sqrtm

* Fix typo

* Add changes to RELEASES.md

---------

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
  • Loading branch information
francois-rozet and rflamary authored May 4, 2023
1 parent 2aeb591 commit 83dc498
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
1 change: 1 addition & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- Fix issues with cuda for ot.binary_search_circle and with gradients for ot.sliced_wasserstein_sphere (PR #457)
- Major documentation cleanup (PR #462, #467)
- Fix gradients for "Wasserstein2 Minibatch GAN" example (PR #466)
- Faster Bures-Wasserstein distance with NumPy backend (PR #468)

## 0.9.0

Expand Down
8 changes: 5 additions & 3 deletions ot/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,8 @@ def inv(self, a):
return scipy.linalg.inv(a)

def sqrtm(self, a):
return scipy.linalg.sqrtm(a)
L, V = np.linalg.eigh(a)
return (V * np.sqrt(L)[None, :]) @ V.T

def kl_div(self, p, q, eps=1e-16):
return np.sum(p * np.log(p / q + eps))
Expand Down Expand Up @@ -2433,7 +2434,7 @@ def inv(self, a):

def sqrtm(self, a):
L, V = cp.linalg.eigh(a)
return (V * self.sqrt(L)[None, :]) @ V.T
return (V * cp.sqrt(L)[None, :]) @ V.T

def kl_div(self, p, q, eps=1e-16):
return cp.sum(p * cp.log(p / q + eps))
Expand Down Expand Up @@ -2824,7 +2825,8 @@ def inv(self, a):
return tf.linalg.inv(a)

def sqrtm(self, a):
return tf.linalg.sqrtm(a)
L, V = tf.linalg.eigh(a)
return (V * tf.sqrt(L)[None, :]) @ V.T

def kl_div(self, p, q, eps=1e-16):
return tnp.sum(p * tnp.log(p / q + eps))
Expand Down
4 changes: 2 additions & 2 deletions ot/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def bures_wasserstein_distance(ms, mt, Cs, Ct, log=False):
where :
.. math::
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
Parameters
----------
Expand Down Expand Up @@ -264,7 +264,7 @@ def empirical_bures_wasserstein_distance(xs, xt, reg=1e-6, ws=None,
where :
.. math::
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s^{1/2} + \Sigma_t^{1/2} - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
\mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)
Parameters
----------
Expand Down

0 comments on commit 83dc498

Please sign in to comment.