Skip to content

Commit

Permalink
[MRG] Projection Robust Wasserstein (#267)
Browse files Browse the repository at this point in the history
* ot.dr: PRW code; test.test_dr: PRW test code.

* ot.dr: PRW code; test.test_dr: PRW test code.

* fix errors: pep8(3.8)

* fix errors: pep8(3.8)

* modified readme; prw code review

* fix pep error

* edit comment

* modified math comment

Co-authored-by: Rémi Flamary <remi.flamary@gmail.com>
  • Loading branch information
mhhuang95 and rflamary authored Sep 6, 2021
1 parent c105dcb commit 96bf1a4
Show file tree
Hide file tree
Showing 3 changed files with 154 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ The contributors to this library are
* [Mokhtar Z. Alaya](http://mzalaya.github.io/) (Screenkhorn)
* [Ievgen Redko](https://ievred.github.io/) (Laplacian DA, JCPOT)
* [Adrien Corenflos](https://adriencorenflos.github.io/) (Sliced Wasserstein Distance)
* [Minhui Huang](https://mhhuang95.github.io) (Projection Robust Wasserstein Distance)

This toolbox benefits a lot from open source research, and we would like to thank the following people for providing some code (in various languages):

Expand Down Expand Up @@ -283,3 +284,5 @@ You can also post bug reports and feature requests in Github issues. Make sure t
[30] Flamary R., Courty N., Tuia D., Rakotomamonjy A. (2014). [Optimal transport with Laplacian regularization: Applications to domain adaptation and shape matching](https://remi.flamary.com/biblio/flamary2014optlaplace.pdf), NIPS Workshop on Optimal Transport and Machine Learning OTML, 2014.

[31] Bonneel, Nicolas, et al. [Sliced and radon wasserstein barycenters of measures](https://perso.liris.cnrs.fr/nicolas.bonneel/WassersteinSliced-JMIV.pdf), Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45

[32] Huang, M., Ma, S., Lai, L. (2021). [A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance](http://proceedings.mlr.press/v139/huang21e.html), Proceedings of the 38th International Conference on Machine Learning (ICML).
114 changes: 114 additions & 0 deletions ot/dr.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License

Expand Down Expand Up @@ -198,3 +199,116 @@ def proj(X):
return (X - mx.reshape((1, -1))).dot(Popt)

return Popt, proj


def projection_robust_wasserstein(X, Y, a, b, tau, U0=None, reg=0.1, k=2, stopThr=1e-3, maxiter=100, verbose=0):
    r"""
    Projection Robust Wasserstein Distance [32]

    The function solves the following optimization problem:

    .. math::
        \max_{U \in St(d, k)} \min_{\pi \in \Pi(\mu,\nu)} \sum_{i,j} \pi_{i,j} \|U^T(x_i - y_j)\|^2 - reg \cdot H(\pi)

    - :math:`U` is a linear projection operator in the Stiefel(d, k) manifold
    - :math:`H(\pi)` is entropy regularizer
    - :math:`x_i`, :math:`y_j` are samples of measures :math:`\mu` and :math:`\nu` respectively

    Each iteration performs one Sinkhorn update of the scaling vectors on the
    projected cost matrix, followed by one Riemannian gradient step on
    :math:`U` with a QR retraction.

    Parameters
    ----------
    X : ndarray, shape (n, d)
        Samples from measure :math:`\mu`
    Y : ndarray, shape (m, d)
        Samples from measure :math:`\nu`
    a : ndarray, shape (n, )
        weights for measure :math:`\mu`
    b : ndarray, shape (m, )
        weights for measure :math:`\nu`
    tau : float
        stepsize for the Riemannian gradient step
    U0 : ndarray, shape (d, k), optional
        Initial starting point for projection. If None, a random matrix with
        orthonormal columns is drawn (Gaussian matrix followed by QR).
    reg : float, optional
        Regularization term >0 (entropic regularization)
    k : int
        Subspace dimension. Only used for the ``d > k`` check and for the
        random initialization; when `U0` is given, its number of columns
        determines the subspace dimension.
    stopThr : float, optional
        Stop threshold on error (>0)
    maxiter : int, optional
        Maximum number of iterations
    verbose : int, optional
        Print information along iterations.

    Returns
    -------
    pi : ndarray, shape (n, m)
        Optimal transportation matrix for the given parameters. If no
        iteration is performed (e.g. ``maxiter=0``), the independent
        coupling ``outer(a, b)`` is returned.
    U : ndarray, shape (d, k)
        Projection operator.

    Raises
    ------
    ValueError
        If `X` and `Y` do not have the same number of columns.
    AssertionError
        If ``d <= k``.

    References
    ----------
    .. [32] Huang, M. , Ma S. & Lai L. (2021).
            A Riemannian Block Coordinate Descent Method for Computing
            the Projection Robust Wasserstein Distance, ICML.
    """  # noqa

    # initialization
    n, d = X.shape
    m, d_y = Y.shape
    if d_y != d:
        raise ValueError(
            "X and Y must have the same number of columns, got {} and {}".format(d, d_y)
        )

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    u = np.ones(n) / n
    v = np.ones(m) / m

    # Kept as an assert so that the exception type seen by callers is unchanged.
    assert d > k

    if U0 is None:
        U = np.random.randn(d, k)
        U, _ = np.linalg.qr(U)
    else:
        U = U0

    # Fallback plan returned when the loop below does not run at all.
    pi = np.outer(a, b)

    err = 1
    it = 0

    while err > stopThr and it < maxiter:

        # Projected cost matrix M_ij = ||U^T (x_i - y_j)||^2, computed from the
        # projected samples instead of building the (d, d) projector U U^T and
        # dense diagonal matrices.
        XU = X.dot(U)
        YU = Y.dot(U)
        M = (np.sum(XU ** 2, axis=1).reshape((-1, 1))
             + np.sum(YU ** 2, axis=1).reshape((1, -1))
             - 2 * XU.dot(YU.T))

        # Gibbs kernel
        A = np.exp(-M / reg)

        # Sinkhorn update (u is warm-started from the previous iteration)
        v = b / np.dot(A.T, u)
        u = a / np.dot(A, v)
        pi = u.reshape((-1, 1)) * A * v.reshape((1, -1))

        # Second order matrix of the displacements:
        # sum_ij { (pi)_ij (X_i-Y_j)(X_i-Y_j)^T }.
        # The rows of pi sum to a right after the update of u above, hence
        # the use of a for the X term.
        cross = X.T.dot(pi).dot(Y)
        V = (X.T * a).dot(X) + (Y.T * np.sum(pi, 0)).dot(Y) - cross - cross.T

        # Riemannian gradient step
        G = 2 / reg * V.dot(U)
        GTU = G.T.dot(U)
        xi = G - U.dot(GTU + GTU.T) / 2  # Riemannian gradient
        U, _ = np.linalg.qr(U + tau * xi)  # Retraction by QR decomposition

        grad_norm = np.linalg.norm(xi)
        err = max(reg * grad_norm, np.linalg.norm(np.sum(pi, 0) - b, 1))

        if verbose:
            # The objective value is only needed for the log line.
            f_val = np.trace(U.T.dot(V.dot(U)))
            print('RBCD Iteration: ', it, ' error', err, '\t fval: ', f_val)

        it = it + 1

    return pi, U
37 changes: 37 additions & 0 deletions test/test_dr.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for module dr on Dimensionality Reduction """

# Author: Remi Flamary <remi.flamary@unice.fr>
# Minhui Huang <mhhuang@ucdavis.edu>
#
# License: MIT License

Expand Down Expand Up @@ -57,3 +58,39 @@ def test_wda():
projwda(xs)

np.testing.assert_allclose(np.sum(Pwda**2, 0), np.ones(p))


@pytest.mark.skipif(nogo, reason="Missing modules (autograd or pymanopt)")
def test_prw():
    """Run projection robust Wasserstein on a fragmented hypercube and check the outputs.

    The solver is run once with its default random initialization and once
    with an explicit starting point `U0`; both results must be a finite,
    non-negative plan whose rows sum to `a` and a projection with
    orthonormal columns.
    """
    d = 100  # Dimension
    n = 100  # Number samples
    k = 3  # Subspace dimension
    dim = 3

    # Make the random data (and the random initialization) reproducible.
    np.random.seed(0)

    def fragmented_hypercube(n, d, dim):
        assert dim <= d
        assert dim >= 1
        assert dim == int(dim)

        a = (1. / n) * np.ones(n)
        b = (1. / n) * np.ones(n)

        # First measure : uniform on the hypercube
        X = np.random.uniform(-1, 1, size=(n, d))

        # Second measure : fragmentation
        tmp_y = np.random.uniform(-1, 1, size=(n, d))
        Y = tmp_y + 2 * np.sign(tmp_y) * np.array(dim * [1] + (d - dim) * [0])
        return a, b, X, Y

    def check_solution(pi, U):
        # Transport plan: right shape, finite, non-negative.
        assert pi.shape == (n, n)
        assert np.all(np.isfinite(pi))
        assert np.all(pi >= 0)
        # The last Sinkhorn scaling update in the solver is on u, so the
        # rows of pi sum to a.
        np.testing.assert_allclose(np.sum(pi, 1), a)
        # U is the Q factor of a QR decomposition: orthonormal columns.
        assert U.shape == (d, k)
        np.testing.assert_allclose(U.T.dot(U), np.eye(k), atol=1e-10)

    a, b, X, Y = fragmented_hypercube(n, d, dim)

    tau = 0.002
    reg = 0.2

    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, reg=reg, k=k, maxiter=1000, verbose=1)
    check_solution(pi, U)

    U0 = np.random.randn(d, k)
    U0, _ = np.linalg.qr(U0)

    pi, U = ot.dr.projection_robust_wasserstein(X, Y, a, b, tau, U0=U0, reg=reg, k=k, maxiter=1000, verbose=1)
    check_solution(pi, U)

0 comments on commit 96bf1a4

Please sign in to comment.