[MRG] Low rank sinkhorn algorithm #568

Merged
merged 40 commits on Dec 5, 2023

Changes from 39 commits

Commits (40)
f49f6b4
new file for lr sinkhorn
laudavid Oct 24, 2023
3c4b50f
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
3034e57
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
085863a
add import to __init__
laudavid Oct 26, 2023
9becafc
modify low rank, remove solve_sample,OTResultLazy
laudavid Nov 3, 2023
855234d
pull from master
laudavid Nov 3, 2023
6ea251c
new file for lr sinkhorn
laudavid Oct 24, 2023
965e4d6
lr sinkhorn, solve_sample, OTResultLazy
laudavid Oct 24, 2023
fd5e26d
add test functions + small modif lr_sin/solve_sample
laudavid Oct 25, 2023
3df3b77
add import to __init__
laudavid Oct 26, 2023
fae28f7
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 3, 2023
ab5475b
remove test solve_sample
laudavid Nov 3, 2023
f1c8cdd
add value, value_linear, lazy_plan
laudavid Nov 8, 2023
b1a2136
Merge branch 'PythonOT:master' into master
laudavid Nov 8, 2023
9e51a83
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 8, 2023
df01cff
add comments to lr algorithm
laudavid Nov 8, 2023
a0b0a9d
Merge branch 'PythonOT:master' into master
laudavid Nov 9, 2023
7075c8b
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 9, 2023
5f2af0e
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 9, 2023
5bc9de9
modify test functions + add comments to lowrank
laudavid Nov 9, 2023
c66951b
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 9, 2023
6040e6f
modify __init__ with lowrank
laudavid Nov 9, 2023
a7fdffd
debug lowrank + test
laudavid Nov 14, 2023
d90c186
debug test function low_rank
laudavid Nov 14, 2023
ea3a3e0
error test
laudavid Nov 14, 2023
f6a36bf
Merge branch 'PythonOT:master' into master
laudavid Nov 15, 2023
fe067fd
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 15, 2023
5d3ed32
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 15, 2023
3e6b9aa
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 15, 2023
165e8f5
final debug of lowrank + add new test functions
laudavid Nov 15, 2023
de54bb9
branch up to date with master
laudavid Nov 17, 2023
bdcfaf6
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 22, 2023
d0d9f46
Merge branch 'master' of https://github.com/hi-paris/POT into lowrank_v2
laudavid Nov 24, 2023
ec96836
Merge branch 'PythonOT:master' into lowrank_v2
laudavid Nov 24, 2023
8c6ac67
Debug tests + add lowrank to solve_sample
laudavid Nov 24, 2023
fafc5f6
Merge branch 'lowrank_v2' of https://github.com/hi-paris/POT into low…
laudavid Nov 24, 2023
bc7af6b
fix torch backend for lowrank
laudavid Nov 25, 2023
b40705c
fix jax backend and skip tf
laudavid Nov 28, 2023
55c8d2b
fix pep 8 tests
laudavid Nov 28, 2023
218c54a
merge master + doc for lowrank
laudavid Dec 5, 2023
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
@@ -45,6 +45,7 @@ The contributors to this library are:
* [Ronak Mehta](https://ronakrm.github.io) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Xizheng Yu](https://github.com/x12hengyu) (Efficient Discrete Multi Marginal Optimal Transport Regularization)
* [Sonia Mazelet](https://github.com/SoniaMaz8) (Template based GNN layers)
* [Laurène David](https://github.com/laudavid) (Low rank sinkhorn)

## Acknowledgments

2 changes: 2 additions & 0 deletions README.md
@@ -343,3 +343,5 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462.

[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning.

[63] Scetbon, M., Cuturi, M., & Peyré, G. (2021). [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf).
1 change: 1 addition & 0 deletions RELEASES.md
@@ -20,6 +20,7 @@
+ Wrapper for `geomloss` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
+ Added support for [Low-Rank Sinkhorn Factorization](https://arxiv.org/pdf/2103.04737.pdf) (PR #568)

#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
4 changes: 3 additions & 1 deletion ot/__init__.py
@@ -35,6 +35,7 @@
from . import factored
from . import solvers
from . import gaussian
from . import lowrank


# OT functions
@@ -52,6 +53,7 @@
from .weak import weak_optimal_transport
from .factored import factored_optimal_transport
from .solvers import solve, solve_gromov, solve_sample
from .lowrank import lowrank_sinkhorn

# utils functions
from .utils import dist, unif, tic, toc, toq
@@ -69,4 +71,4 @@
'factored_optimal_transport', 'solve', 'solve_gromov','solve_sample',
'smooth', 'stochastic', 'unbalanced', 'partial', 'regpath', 'solvers',
'binary_search_circle', 'wasserstein_circle',
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif']
'semidiscrete_wasserstein2_unif_circle', 'sliced_wasserstein_sphere_unif', 'lowrank_sinkhorn']
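With the import added above, the new solver is exposed at the top level of the package. A minimal sketch (editorial illustration, not part of the diff) of the resulting usage:

```python
# Hedged check that the top-level export added in this diff is usable
import numpy as np
import ot

rng = np.random.RandomState(0)
X_s = rng.randn(10, 2)   # source samples
X_t = rng.randn(12, 2)   # target samples

# Returns the low-rank factors of the OT plan (see ot/lowrank.py below)
Q, R, g = ot.lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=2)
print(Q.shape, R.shape, g.shape)   # (10, 2) (12, 2) (2,)
```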
289 changes: 289 additions & 0 deletions ot/lowrank.py
@@ -0,0 +1,289 @@
"""
Low rank OT solvers
"""

# Author: Laurène David <laurene.david@ip-paris.fr>
#
# License: MIT License


import warnings
from .utils import unif, get_lowrank_lazytensor
from .backend import get_backend


def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
"""
Compute low rank decomposition of a sqeuclidean cost matrix.
Reviewer comment (Collaborator): Write *squared euclidean distance matrix* in the documentation. Plus, could you specify the parameters as in lowrank_sinkhorn?

This function won't work for other metrics.

See "Section 3.5, proposition 1" of the paper

References
----------
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021).
Reviewer comment (Collaborator): It would be better to put the reference of the published paper:
Scetbon, M., Cuturi, M., & Peyré, G. (2021). "Low-Rank Sinkhorn Factorization". In International Conference on Machine Learning.
Also the number [63] should be updated after the merge, as new references have been included and [63] relates to another paper.

"Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737.
"""

if nx is None:
nx = get_backend(X_s, X_t)

ns = X_s.shape[0]
nt = X_t.shape[0]

# First low rank decomposition of the cost matrix (A)
array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)

# Second low rank decomposition of the cost matrix (B)
array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
M2 = nx.concatenate((array1, array2, X_t), axis=1)

return M1, M2
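To make the factorization concrete, here is a small NumPy sketch (an editorial illustration, not part of the diff) checking that the two factors returned above reconstruct the squared Euclidean cost matrix as M1 @ M2.T; scipy's cdist is only used as a reference:

```python
import numpy as np
from scipy.spatial.distance import cdist
from ot.lowrank import compute_lr_sqeuclidean_matrix

rng = np.random.RandomState(0)
X_s = rng.randn(5, 3)   # 5 source samples in dimension 3
X_t = rng.randn(7, 3)   # 7 target samples in dimension 3

# Rank-(d+2) factors of the squared Euclidean cost matrix
M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t)

C = cdist(X_s, X_t, metric="sqeuclidean")
print(np.allclose(M1 @ M2.T, C))   # True: ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
```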


def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
Reviewer comment (Collaborator): rename to _LR_Dysktra as it seems specific to lowrank_sinkhorn

"""
Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver.
Reviewer comment (Collaborator): Could you specify the parameters here too?


References
----------
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021).
"Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737.

"""

# POT backend if None
if nx is None:
nx = get_backend(eps1, eps2, eps3, p1, p2)

Codecov/patch check warning: added line ot/lowrank.py#L60 was not covered by tests.

# ----------------- Initialisation of Dykstra algorithm -----------------
r = len(eps3) # rank
g_ = nx.copy(eps3) # \tilde{g}
q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(3)}_1, q^{(3)}_2
v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # \tilde{v}^{(1)}, \tilde{v}^{(2)}
q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1) # q^{(1)}, q^{(2)}
err = 1 # initial error

# --------------------- Dykstra algorithm -------------------------

# See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper

for ii in range(numItermax):
if err > stopThr:
# Compute u^{(1)} and u^{(2)}
u1 = p1 / nx.dot(eps1, v1_)
u2 = p2 / nx.dot(eps2, v2_)

# Compute g, g^{(3)}_1 and update \tilde{g}
g = nx.maximum(alpha, g_ * q3_1)
q3_1 = (g_ * q3_1) / g
g_ = nx.copy(g)

# Compute new value of g as the cube root of the product g_ * q3_2 * prod1 * prod2
prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

# Compute v^{(1)} and v^{(2)}
v1 = g / nx.dot(eps1.T, u1)
v2 = g / nx.dot(eps2.T, u2)

# Compute q^{(1)}, q^{(2)} and q^{(3)}_2
q1 = (v1_ * q1) / v1
q2 = (v2_ * q2) / v2
q3_2 = (g_ * q3_2) / g

# Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
v1_, v2_ = nx.copy(v1), nx.copy(v2)
g_ = nx.copy(g)

# Compute error
err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1))
err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2))
err = err1 + err2

else:
break

else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
)

# Compute low rank matrices Q, R
Q = u1[:, None] * eps1 * v1[None, :]
R = u2[:, None] * eps2 * v2[None, :]

return Q, R, g
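As a rough illustration of what the projection enforces, the sketch below (editorial, not part of the diff; the positive inputs are arbitrary) calls LR_Dysktra and checks that the row sums of the returned factors match the prescribed marginals p1 and p2, which is exactly the error monitored in the loop above:

```python
import numpy as np
from ot.lowrank import LR_Dysktra

rng = np.random.RandomState(0)
ns, nt, r = 6, 5, 3
p1 = np.full(ns, 1 / ns)                  # source marginal
p2 = np.full(nt, 1 / nt)                  # target marginal
eps1 = rng.uniform(0.5, 1.5, (ns, r))     # arbitrary positive kernel-like inputs
eps2 = rng.uniform(0.5, 1.5, (nt, r))
eps3 = rng.uniform(0.5, 1.5, r)

Q, R, g = LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha=1e-10,
                     stopThr=1e-9, numItermax=10000, warn=True)

# Row sums of Q and R recover the marginals (the stopping criterion above)
print(np.abs(Q.sum(axis=1) - p1).max(), np.abs(R.sum(axis=1) - p2).max())
print(g)   # rank-r weight vector shared by the columns of Q and R
```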


def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto",
numItermax=1000, stopThr=1e-9, warn=True, log=False):
r"""
Solve the entropic regularization optimal transport problem under low nonnegative rank constraints.

The function solves the following optimization problem:

.. math::
    \mathop{\inf_{(Q,R,g) \in \mathcal{C}(\mathbf{a},\mathbf{b},r)}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle -
    \mathrm{reg} \cdot H((Q,R,g))

where :
- :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`H((Q,R,g))` is the sum of the entropies of the three terms :math:`Q`, :math:`R` and :math:`g`
- :math:`Q` and :math:`R` are the low-rank matrix factors of the OT plan
- :math:`g` is the weight vector of the low-rank decomposition of the OT plan
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
- :math:`r` is the nonnegative rank of the OT plan
- :math:`\mathcal{C}(\mathbf{a},\mathbf{b},r) = \mathcal{C}_1(\mathbf{a},\mathbf{b},r) \cap \mathcal{C}_2(r)` is the set of low-rank couplings of the OT problem, with
  :math:`\mathcal{C}_1(\mathbf{a},\mathbf{b},r) = \{ (Q,R,g) \text{ s.t. } Q\mathbb{1}_r = \mathbf{a}, \; R\mathbb{1}_r = \mathbf{b} \}` and
  :math:`\mathcal{C}_2(r) = \{ (Q,R,g) \text{ s.t. } Q^T\mathbb{1}_n = R^T\mathbb{1}_m = g \}`


Parameters
----------
X_s : array-like, shape (n_samples_a, dim)
samples in the source domain
X_t : array-like, shape (n_samples_b, dim)
samples in the target domain
a : array-like, shape (n_samples_a,)
samples weights in the source domain
b : array-like, shape (n_samples_b,)
samples weights in the target domain
reg : float, optional
Regularization term >0
rank: int, default "auto"
Nonnegative rank of the OT plan
alpha: int, default "auto" (1e-10)
Reviewer comment (Collaborator): It is more conventional in the POT API to set a None value as default, such as: "rank: int, optional. Default is None. Nonnegative rank of the OT plan. If None, min(ns, nt) is considered." Same template for alpha. NB: the rank actually has to be strictly positive.

Lower bound for the weight vector g (>0 and <1/r)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
warn : bool, optional
if True, raise a warning if the algorithm doesn't converge.
log : bool, optional
record log if True


Returns
-------
lazy_plan : LazyTensor()
OT plan in a LazyTensor object of shape (n_samples_a, n_samples_b)
See :any:`LazyTensor` for more information.
value : float
Optimal value of the optimization problem
value_linear : float
Linear OT loss with the optimal OT
Q : array-like, shape (n_samples_a, r)
First low-rank matrix decomposition of the OT plan
R: array-like, shape (n_samples_b, r)
Second low-rank matrix decomposition of the OT plan
g : array-like, shape (r, )
Weight vector for the low-rank decomposition of the OT plan


References
----------
.. [63] Scetbon, M., Cuturi, M., & Peyré, G (2021).
Reviewer comment (Collaborator): update ref

"Low-Rank Sinkhorn Factorization" arXiv preprint arXiv:2103.04737.

"""

# POT backend
nx = get_backend(X_s, X_t)
ns, nt = X_s.shape[0], X_t.shape[0]

# Initialize weights a, b
if a is None:
a = unif(ns, type_as=X_s)
if b is None:
b = unif(nt, type_as=X_t)

# Compute rank (see Section 3.1, def 1)
r = rank
if rank == "auto":
Reviewer comment (Collaborator): `if rank is None:`

r = min(ns, nt)

if alpha == "auto":
Reviewer comment (Collaborator): same

alpha = 1e-10

# The Dykstra algorithm won't converge if 1/rank < alpha (alpha lower-bounds the entries of g, so it must not exceed 1/rank)
# (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
if 1 / r < alpha:
raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
a=alpha, r=1 / rank))

# Low rank decomposition of the sqeuclidean cost matrix (A, B)
M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None)

# Compute gamma (see "Section 3.4, proposition 4" in the paper)
L = nx.sqrt(
3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
(reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
)
gamma = 1 / (2 * L)

# Initialize the low rank matrices Q, R, g
Q = nx.ones((ns, r), type_as=a)
R = nx.ones((nt, r), type_as=a)
g = nx.ones(r, type_as=a)
k = 100

# -------------------------- Low rank algorithm ------------------------------
# see "Section 3.3, Algorithm 3 LOT" in the paper

for ii in range(k):
# Compute the matrix product C @ R using the low-rank decomposition of C
CR_ = nx.dot(M2.T, R)
CR = nx.dot(M1, CR_)

# Compute the matrix product C.T @ Q using the low-rank decomposition of C
CQ_ = nx.dot(M1.T, Q)
CQ = nx.dot(M2, CQ_)

diag_g = (1 / g)[None, :]

eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
omega = nx.diag(nx.dot(Q.T, CR))
eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))

Q, R, g = LR_Dysktra(
eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
)
Q = Q + 1e-16
R = R + 1e-16

# ----------------- Compute lazy_plan, value and value_linear ------------------
# see "Section 3.2: The Low-rank OT Problem" in the paper

# Compute lazy plan (using LazyTensor class)
lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

# Compute value_linear (using trace formula)
v1 = nx.dot(Q.T, M1)
v2 = nx.dot(R, (v1.T * diag_g).T)
value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))

# Compute value with entropy reg (entropy of Q, R, g must be computed separately, see "Section 3.2" in the paper)
reg_Q = nx.sum(Q * nx.log(Q + 1e-16)) # entropy for Q
reg_g = nx.sum(g * nx.log(g + 1e-16)) # entropy for g
reg_R = nx.sum(R * nx.log(R + 1e-16)) # entropy for R
value = value_linear + reg * (reg_Q + reg_g + reg_R)

if log:
dict_log = dict()
dict_log["value"] = value
dict_log["value_linear"] = value_linear
dict_log["lazy_plan"] = lazy_plan

return Q, R, g, dict_log

return Q, R, g
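An end-to-end sketch of the solver above (editorial, not part of the diff): it materializes the lazy plan, rebuilds the plan from the returned factors, and compares value_linear against <C, P> computed with the dense cost matrix. Exact agreement is only expected up to the accuracy of the fixed-point iterations.

```python
import numpy as np
from ot.lowrank import lowrank_sinkhorn, compute_lr_sqeuclidean_matrix

rng = np.random.RandomState(42)
X_s = rng.randn(20, 2)
X_t = rng.randn(30, 2)

Q, R, g, log = lowrank_sinkhorn(X_s, X_t, reg=0.1, rank=3, log=True)

P = log["lazy_plan"][:]                           # dense (20, 30) plan from the LazyTensor
print(np.allclose(P, Q @ np.diag(1 / g) @ R.T))   # plan is Q diag(1/g) R^T

M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t)
C = M1 @ M2.T                                     # squared Euclidean cost matrix
print(np.sum(C * P), log["value_linear"])         # linear cost, computed two ways
print(P.sum(axis=1)[:3], P.sum(axis=0)[:3])       # marginals are approximately uniform
```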
20 changes: 20 additions & 0 deletions ot/solvers.py
@@ -22,6 +22,7 @@
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']

@@ -1248,6 +1249,25 @@
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

elif method == "lowrank":

if not metric.lower() in ['sqeuclidean']:
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if reg is None:
reg = 0

Codecov/patch check warning: added line ot/solvers.py#L1262 was not covered by tests.

Q, R, g, log = lowrank_sinkhorn(X_a, X_b, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
value = log['value']
value_linear = log['value_linear']
lazy_plan = log['lazy_plan']
if not lazy0: # store plan if not lazy
plan = lazy_plan[:]

elif method.startswith('geomloss'): # Geomloss solver for entropic OT

split_method = method.split('_')
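Finally, a short sketch (editorial, not part of the diff) of reaching the same solver through the unified solve_sample API; the attributes read from the result mirror the fields populated in the "lowrank" branch above (value, value_linear, plan, lazy_plan), assuming the default non-lazy mode.

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
X_a = rng.randn(50, 2)
X_b = rng.randn(40, 2)

# method="lowrank" dispatches to lowrank_sinkhorn with the sqeuclidean metric
res = ot.solve_sample(X_a, X_b, method="lowrank", reg=0.1)

print(res.value)          # entropic low-rank OT objective
print(res.value_linear)   # linear part <C, P>
print(res.plan.shape)     # (50, 40) dense plan, stored when not run in lazy mode
```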