-
Notifications
You must be signed in to change notification settings - Fork 508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MRG] Low rank sinkhorn algorithm #568
Changes from 39 commits
f49f6b4
3c4b50f
3034e57
085863a
9becafc
855234d
6ea251c
965e4d6
fd5e26d
3df3b77
fae28f7
ab5475b
f1c8cdd
b1a2136
9e51a83
df01cff
a0b0a9d
7075c8b
5f2af0e
5bc9de9
c66951b
6040e6f
a7fdffd
d90c186
ea3a3e0
f6a36bf
fe067fd
5d3ed32
3e6b9aa
165e8f5
de54bb9
bdcfaf6
d0d9f46
ec96836
8c6ac67
fafc5f6
bc7af6b
b40705c
55c8d2b
218c54a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,289 @@ | ||
""" | ||
Low rank OT solvers | ||
""" | ||
|
||
# Author: Laurène David <laurene.david@ip-paris.fr> | ||
# | ||
# License: MIT License | ||
|
||
|
||
import warnings | ||
from .utils import unif, get_lowrank_lazytensor | ||
from .backend import get_backend | ||
|
||
|
||
def compute_lr_sqeuclidean_matrix(X_s, X_t, nx=None):
    """
    Compute the low rank decomposition of a squared euclidean distance matrix.

    The cost matrix :math:`C` with :math:`C_{ij} = \\|x^s_i - x^t_j\\|^2` is
    factored as :math:`C = M_1 M_2^T` without ever materializing it
    (see "Section 3.5, proposition 1" of the paper below).
    This decomposition only holds for the squared euclidean metric and
    won't work for other metrics.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        Samples in the source domain.
    X_t : array-like, shape (n_samples_b, dim)
        Samples in the target domain.
    nx : backend object, optional
        POT backend; if None it is inferred from ``X_s`` and ``X_t``.

    Returns
    -------
    M1 : array-like, shape (n_samples_a, dim + 2)
        First factor of the low rank decomposition of the cost matrix.
    M2 : array-like, shape (n_samples_b, dim + 2)
        Second factor of the low rank decomposition of the cost matrix.

    References
    ----------
    .. [63] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization".
        In International Conference on Machine Learning.
    """
    if nx is None:
        nx = get_backend(X_s, X_t)

    ns = X_s.shape[0]
    nt = X_t.shape[0]

    # First low rank factor of the cost matrix (A):
    # columns are [||x_i||^2, 1, -2 x_i]
    array1 = nx.reshape(nx.sum(X_s**2, 1), (-1, 1))
    array2 = nx.reshape(nx.ones(ns, type_as=X_s), (-1, 1))
    M1 = nx.concatenate((array1, array2, -2 * X_s), axis=1)

    # Second low rank factor of the cost matrix (B):
    # columns are [1, ||y_j||^2, y_j], so that (M1 M2^T)_{ij} = ||x_i - y_j||^2
    array1 = nx.reshape(nx.ones(nt, type_as=X_s), (-1, 1))
    array2 = nx.reshape(nx.sum(X_t**2, 1), (-1, 1))
    M2 = nx.concatenate((array1, array2, X_t), axis=1)

    return M1, M2
|
||
|
||
def LR_Dysktra(eps1, eps2, eps3, p1, p2, alpha, stopThr, numItermax, warn, nx=None):
    """
    Implementation of the Dykstra algorithm for the Low Rank sinkhorn OT solver.

    Iteratively projects onto the two marginal constraints and the lower-bound
    constraint on the weight vector g, following
    "Section 3.3 - Algorithm 2 LR-Dykstra" of the paper below.

    Parameters
    ----------
    eps1 : array-like, shape (n_samples_a, r)
        First Gibbs kernel of the low rank problem.
    eps2 : array-like, shape (n_samples_b, r)
        Second Gibbs kernel of the low rank problem.
    eps3 : array-like, shape (r,)
        Third Gibbs kernel of the low rank problem.
    p1 : array-like, shape (n_samples_a,)
        Sample weights in the source domain (``a`` in :any:`lowrank_sinkhorn`).
    p2 : array-like, shape (n_samples_b,)
        Sample weights in the target domain (``b`` in :any:`lowrank_sinkhorn`).
    alpha : float
        Lower bound for the weight vector g (> 0 and < 1/r).
    stopThr : float
        Stop threshold on the marginal violation error (> 0).
    numItermax : int
        Max number of iterations.
    warn : bool
        If True, raises a warning if the algorithm doesn't converge.
    nx : backend object, optional
        POT backend; if None it is inferred from the array arguments.

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank factor of the OT plan.
    R : array-like, shape (n_samples_b, r)
        Second low-rank factor of the OT plan.
    g : array-like, shape (r,)
        Weight vector of the low-rank decomposition of the OT plan.

    References
    ----------
    .. [63] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization".
        In International Conference on Machine Learning.
    """
    # POT backend if None
    if nx is None:
        nx = get_backend(eps1, eps2, eps3, p1, p2)

    # ----------------- Initialisation of Dykstra algorithm -----------------
    r = len(eps3)  # rank
    g_ = nx.copy(eps3)  # \tilde{g}
    q3_1, q3_2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(3)}_1, q^{(3)}_2
    v1_, v2_ = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # \tilde{v}^{(1)}, \tilde{v}^{(2)}
    q1, q2 = nx.ones(r, type_as=p1), nx.ones(r, type_as=p1)  # q^{(1)}, q^{(2)}
    err = 1  # initial error

    # --------------------- Dykstra algorithm -------------------------

    # See Section 3.3 - "Algorithm 2 LR-Dykstra" in paper

    for ii in range(numItermax):
        if err > stopThr:
            # Compute u^{(1)} and u^{(2)}
            u1 = p1 / nx.dot(eps1, v1_)
            u2 = p2 / nx.dot(eps2, v2_)

            # Compute g, g^{(3)}_1 and update \tilde{g}
            g = nx.maximum(alpha, g_ * q3_1)
            q3_1 = (g_ * q3_1) / g
            g_ = nx.copy(g)

            # Compute new value of g with \prod
            prod1 = (v1_ * q1) * nx.dot(eps1.T, u1)
            prod2 = (v2_ * q2) * nx.dot(eps2.T, u2)
            g = (g_ * q3_2 * prod1 * prod2) ** (1 / 3)

            # Compute v^{(1)} and v^{(2)}
            v1 = g / nx.dot(eps1.T, u1)
            v2 = g / nx.dot(eps2.T, u2)

            # Compute q^{(1)}, q^{(2)} and q^{(3)}_2
            q1 = (v1_ * q1) / v1
            q2 = (v2_ * q2) / v2
            q3_2 = (g_ * q3_2) / g

            # Update values of \tilde{v}^{(1)}, \tilde{v}^{(2)} and \tilde{g}
            v1_, v2_ = nx.copy(v1), nx.copy(v2)
            g_ = nx.copy(g)

            # Compute error as the sum of the two marginal violations
            err1 = nx.sum(nx.abs(u1 * (eps1 @ v1) - p1))
            err2 = nx.sum(nx.abs(u2 * (eps2 @ v2) - p2))
            err = err1 + err2

        else:
            break

    else:
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
            )

    # Compute low rank matrices Q, R from the scaling vectors
    Q = u1[:, None] * eps1 * v1[None, :]
    R = u2[:, None] * eps2 * v2[None, :]

    return Q, R, g
|
||
|
||
def lowrank_sinkhorn(X_s, X_t, a=None, b=None, reg=0, rank="auto", alpha="auto",
                     numItermax=1000, stopThr=1e-9, warn=True, log=False):
    r"""
    Solve the entropic regularization optimal transport problem under low-nonnegative rank constraints.

    The function solves the following optimization problem:

    .. math::
        \mathop{\inf_{(Q,R,g) \in \mathcal{C(a,b,r)}}} \langle C, Q\mathrm{diag}(1/g)R^T \rangle -
        \mathrm{reg} \cdot H((Q,R,g))

    where :

    - :math:`C` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`H((Q,R,g))` is the values of the three respective entropies evaluated for each term.
    - :math:`Q` and :math:`R` are the low-rank matrix decomposition of the OT plan
    - :math:`g` is the weight vector for the low-rank decomposition of the OT plan
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
    - :math:`r` is the rank of the OT plan
    - :math:`\mathcal{C(a,b,r)}` are the low-rank couplings of the OT problem:

    .. math::
        \mathcal{C(a,b,r)} = \mathcal{C_1(a,b,r)} \cap \mathcal{C_2(r)} \text{ with }

        \mathcal{C_1(a,b,r)} = \{ (Q,R,g) \text{ s.t } Q\mathbb{1}_r = a, R^T \mathbb{1}_m = b \}

        \mathcal{C_2(r)} = \{ (Q,R,g) \text{ s.t } Q\mathbb{1}_n = R^T \mathbb{1}_m = g \}

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        samples weights in the source domain; uniform if None
    b : array-like, shape (n_samples_b,), optional
        samples weights in the target domain; uniform if None
    reg : float, optional
        Regularization term >=0
    rank : int, default "auto"
        Nonnegative rank of the OT plan; "auto" uses min(n_samples_a, n_samples_b)
    alpha : float, default "auto" (1e-10)
        Lower bound for the weight vector g (>0 and <1/rank)
    numItermax : int, optional
        Max number of iterations of the inner Dykstra algorithm
    stopThr : float, optional
        Stop threshold on error (>0)
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    log : bool, optional
        record log if True

    Returns
    -------
    Q : array-like, shape (n_samples_a, r)
        First low-rank matrix decomposition of the OT plan
    R : array-like, shape (n_samples_b, r)
        Second low-rank matrix decomposition of the OT plan
    g : array-like, shape (r, )
        Weight vector for the low-rank decomposition of the OT plan
    log : dict, optional
        Returned only if `log` is True. Contains:

        - "value": optimal value of the optimization problem (with entropy)
        - "value_linear": linear OT loss with the optimal OT plan
        - "lazy_plan": OT plan as a :any:`LazyTensor` object

    References
    ----------
    .. [63] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        "Low-Rank Sinkhorn Factorization".
        In International Conference on Machine Learning.
    """
    # POT backend
    nx = get_backend(X_s, X_t)
    ns, nt = X_s.shape[0], X_t.shape[0]

    # Initialize weights a, b as uniform histograms if not provided
    if a is None:
        a = unif(ns, type_as=X_s)
    if b is None:
        b = unif(nt, type_as=X_t)

    # Compute rank (see Section 3.1, def 1)
    r = rank
    if rank == "auto":
        r = min(ns, nt)

    if alpha == "auto":
        alpha = 1e-10

    # Dykstra algorithm won't converge if 1/rank < alpha (alpha is the lower bound for 1/rank)
    # (see "Section 3.2: The Low-rank OT Problem (LOT)" in the paper)
    # BUGFIX: use the resolved rank `r` here — `rank` may still be the string
    # "auto", in which case `1 / rank` would raise a TypeError.
    if 1 / r < alpha:
        raise ValueError("alpha ({a}) should be smaller than 1/rank ({r}) for the Dykstra algorithm to converge.".format(
            a=alpha, r=1 / r))

    # Low rank decomposition of the sqeuclidean cost matrix (A, B)
    # BUGFIX: pass the already-computed backend instead of nx=None so it is
    # not redundantly re-detected inside the helper.
    M1, M2 = compute_lr_sqeuclidean_matrix(X_s, X_t, nx=nx)

    # Compute gamma (see "Section 3.4, proposition 4" in the paper)
    L = nx.sqrt(
        3 * (2 / (alpha**4)) * ((nx.norm(M1) * nx.norm(M2)) ** 2) +
        (reg + (2 / (alpha**3)) * (nx.norm(M1) * nx.norm(M2))) ** 2
    )
    gamma = 1 / (2 * L)

    # Initialize the low rank matrices Q, R, g
    Q = nx.ones((ns, r), type_as=a)
    R = nx.ones((nt, r), type_as=a)
    g = nx.ones(r, type_as=a)
    k = 100  # number of outer mirror-descent iterations

    # -------------------------- Low rank algorithm ------------------------------
    # see "Section 3.3, Algorithm 3 LOT" in the paper

    for ii in range(k):
        # Compute the C*R dot matrix using the lr decomposition of C
        CR_ = nx.dot(M2.T, R)
        CR = nx.dot(M1, CR_)

        # Compute the C.t * Q dot matrix using the lr decomposition of C
        CQ_ = nx.dot(M1.T, Q)
        CQ = nx.dot(M2, CQ_)

        diag_g = (1 / g)[None, :]

        # Gibbs kernels of the mirror-descent step
        eps1 = nx.exp(-gamma * (CR * diag_g) - ((gamma * reg) - 1) * nx.log(Q))
        eps2 = nx.exp(-gamma * (CQ * diag_g) - ((gamma * reg) - 1) * nx.log(R))
        omega = nx.diag(nx.dot(Q.T, CR))
        eps3 = nx.exp(gamma * omega / (g**2) - (gamma * reg - 1) * nx.log(g))

        Q, R, g = LR_Dysktra(
            eps1, eps2, eps3, a, b, alpha, stopThr, numItermax, warn, nx
        )
        # small offset to keep the entries strictly positive for the logs above
        Q = Q + 1e-16
        R = R + 1e-16

    # ----------------- Compute lazy_plan, value and value_linear ------------------
    # see "Section 3.2: The Low-rank OT Problem" in the paper

    # Compute lazy plan (using LazyTensor class)
    lazy_plan = get_lowrank_lazytensor(Q, R, 1 / g)

    # Compute value_linear (using trace formula)
    v1 = nx.dot(Q.T, M1)
    v2 = nx.dot(R, (v1.T * diag_g).T)
    value_linear = nx.sum(nx.diag(nx.dot(M2.T, v2)))

    # Compute value with entropy reg (entropy of Q, R, g must be computed separately, see "Section 3.2" in the paper)
    reg_Q = nx.sum(Q * nx.log(Q + 1e-16))  # entropy for Q
    reg_g = nx.sum(g * nx.log(g + 1e-16))  # entropy for g
    reg_R = nx.sum(R * nx.log(R + 1e-16))  # entropy for R
    value = value_linear + reg * (reg_Q + reg_g + reg_R)

    if log:
        dict_log = dict()
        dict_log["value"] = value
        dict_log["value_linear"] = value_linear
        dict_log["lazy_plan"] = lazy_plan

        return Q, R, g, dict_log

    return Q, R, g
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
squared euclidean distance matrix* in the documentation. Plus could you specify the parameters as in
lowrank_sinkhorn
?