[WIP] Nystrom sinkhorn #742

Open · wants to merge 17 commits into master
2 changes: 2 additions & 0 deletions README.md
@@ -389,3 +389,5 @@ Artificial Intelligence.
[74] Chewi, S., Maunu, T., Rigollet, P., & Stromme, A. J. (2020). [Gradient descent algorithms for Bures-Wasserstein barycenters](https://proceedings.mlr.press/v125/chewi20a.html). In Conference on Learning Theory (pp. 1276-1304). PMLR.

[75] Altschuler, J., Chewi, S., Gerber, P. R., & Stromme, A. (2021). [Averaging on the Bures-Wasserstein manifold: dimension-free convergence of gradient descent](https://papers.neurips.cc/paper_files/paper/2021/hash/b9acb4ae6121c941324b2b1d3fac5c30-Abstract.html). Advances in Neural Information Processing Systems, 34, 22132-22145.

[76] Altschuler, J., Bach, F., Rudi, A., & Niles-Weed, J. (2019). [Massively scalable Sinkhorn distances via the Nyström method](https://proceedings.neurips.cc/paper_files/paper/2019/file/f55cadb97eaff2ba1980e001b0bd9842-Paper.pdf). Advances in Neural Information Processing Systems, 32.
3 changes: 2 additions & 1 deletion RELEASES.md
@@ -17,6 +17,7 @@
- Backend implementation of `ot.dist` (PR #701)
- Updated documentation Quickstart guide and User guide with new API (PR #726)
- Fix jax version for auto-grad (PR #732)
- Add Nystrom kernel approximation for Sinkhorn (PR #742)
- Added to each example in the gallery the release version in which it was introduced (PR #743)

#### Closed issues
@@ -37,7 +38,7 @@ This new release contains several new features, starting with
a novel [Gaussian Mixture Model Optimal Transport (GMM-OT)](https://pythonot.github.io/master/gen_modules/ot.gmm.html#examples-using-ot-gmm-gmm-ot-apply-map) solver to compare GMMs while enforcing the transport plan to remain a GMM, which benefits from a closed-form solution making it practical for high-dimensional matching problems. We also extended our general unbalanced OT solvers to support any non-negative reference measure in the regularization terms, before adding the novel [translation invariant UOT](https://pythonot.github.io/master/auto_examples/unbalanced-partial/plot_conv_sinkhorn_ti.html) solver showcasing a higher convergence speed. We also implemented several new solvers and enhanced existing ones to perform OT across spaces. These include a [semi-relaxed FGW barycenter](https://pythonot.github.io/master/auto_examples/gromov/plot_semirelaxed_gromov_wasserstein_barycenter.html) solver, coupled with new initialization heuristics for the inner divergence computation, to perform graph partitioning or dictionary learning, followed by novel [unbalanced FGW and Co-optimal transport](https://pythonot.github.io/master/auto_examples/others/plot_outlier_detection_with_COOT_and_unbalanced_COOT.html) solvers that promote robustness to outliers in such matching problems. Finally, we updated the implementation of partial GW, which now supports asymmetric structures and the KL divergence, while leveraging a new generic conditional gradient solver for partial transport problems, enabling significant speed improvements. These latest updates required some modifications to the line search functions of our generic conditional gradient solver, paving the way for future improvements to other GW-based solvers. Last but not least, we implemented a pre-commit scheme to automatically correct common programming mistakes likely to be made by our future contributors.

This release also contains a few bug fixes, concerning the support of any metric in `ot.emd_1d` / `ot.emd2_1d`, and the support of any weights in `ot.gaussian`.

#### Breaking change
- Custom functions provided as parameter `line_search` to `ot.optim.generic_conditional_gradient` must now have the signature `line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)`, adding as input `df_G`, the gradient of the regularizer evaluated at the transport plan `G`. This change aims at improving the speed of solvers with quadratic polynomial regularizers, such as the Gromov-Wasserstein loss (PR #663). A hedged sketch of a conforming function is given below.
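
A minimal sketch of a conforming custom line search (illustrative only: the `alpha, fc, cost_G` return convention is assumed by analogy with POT's `line_search_armijo`, and the fixed step size is purely for demonstration):

def my_line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
    # df_G is the new input: gradient of the regularizer evaluated at G
    alpha = 0.5  # fixed step size, for illustration only
    fc = 0  # number of cost function evaluations performed
    return alpha, fc, cost(G + alpha * deltaG)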

169 changes: 169 additions & 0 deletions examples/others/plot_nystroem_approximation.py
@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""
============================
Nyström approximation for OT
============================

This example shows how to use the Nyström kernel approximation to run the Sinkhorn algorithm in linear time.


"""

# Author: Titouan Vayer <titouan.vayer@inria.fr>
#
# License: MIT License

# sphinx_gallery_thumbnail_number = 2

import math

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

import ot
from ot.bregman import empirical_sinkhorn_nystroem
from ot.lowrank import kernel_nystroem, sinkhorn_low_rank_kernel

##############################################################################
# Generate data
# -------------

# %%
offset = 1
n_samples_per_blob = 500  # we use 2D "blobs" data
random_state = 42
std = 0.2  # standard deviation of each blob
np.random.seed(random_state)

centers = np.array(
    [
        [-offset, -offset],  # Class 0 - blob 1
        [-offset, offset],  # Class 0 - blob 2
        [offset, -offset],  # Class 1 - blob 1
        [offset, offset],  # Class 1 - blob 2
    ]
)

X_list = []
y_list = []

for i, center in enumerate(centers):
    blob_points = np.random.randn(n_samples_per_blob, 2) * std + center
    label = 0 if i < 2 else 1
    X_list.append(blob_points)
    y_list.append(np.full(n_samples_per_blob, label))

X = np.vstack(X_list)
y = np.concatenate(y_list)
Xs = X[y == 0] # source data
Xt = X[y == 1] # target data


##############################################################################
# Plot data
# ---------

# %%
plt.scatter(Xs[:, 0], Xs[:, 1], label="Source")
plt.scatter(Xt[:, 0], Xt[:, 1], label="Target")
plt.legend()

##############################################################################
# Compute the Nyström approximation of the Gaussian kernel
# --------------------------------------------------------

# %%
reg = 5.0  # entropic regularization; the Gaussian kernel bandwidth is sigma = sqrt(reg / 2)
anchors = 5  # number of anchor points for the Nyström approximation
ot.tic()
left_factor, right_factor = kernel_nystroem(
    Xs, Xt, anchors=anchors, sigma=math.sqrt(reg / 2.0), random_state=random_state
)
ot.toc()
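
# %%
# Hedged sanity check on this small problem: we assume the factors satisfy
# :math:`K \approx K_1 K_2^\top`, where :math:`K` is the Gaussian kernel
# :math:`\exp(-\|x - y\|^2 / (2\sigma^2)) = \exp(-M / \mathrm{reg})` with
# :math:`M` the squared Euclidean cost. Forming K densely is only affordable
# here for illustration purposes.
K_exact = np.exp(-ot.dist(Xs, Xt) / reg)  # ot.dist defaults to squared Euclidean
K_approx = left_factor @ right_factor.T
err = np.linalg.norm(K_approx - K_exact) / np.linalg.norm(K_exact)
print(f"Relative Nyström approximation error: {err:.2e}")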

##############################################################################
# Use this approximation in a Sinkhorn algorithm with a low-rank kernel.
# Each matrix/vector product in the Sinkhorn iterations is accelerated,
# since :math:`Kv = K_1 (K_2^\top v)` can be computed in :math:`O(nr)` time
# instead of :math:`O(n^2)`.
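
# %%
# The point, in two lines (assuming ``left_factor`` is :math:`K_1` with shape
# ``(n_s, r)`` and ``right_factor`` is :math:`K_2` with shape ``(n_t, r)``):
# the full :math:`n_s \times n_t` kernel is never formed.
v_test = np.random.rand(Xt.shape[0])
Kv = left_factor @ (right_factor.T @ v_test)  # two O(nr) matrix/vector products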

# %%
numItermax = 1000
stopThr = 1e-7
verbose = True
a, b = None, None
warn = True
warmstart = None
ot.tic()
u, v, dict_log = sinkhorn_low_rank_kernel(
    K1=left_factor,
    K2=right_factor,
    a=a,
    b=b,
    numItermax=numItermax,
    stopThr=stopThr,
    verbose=verbose,
    log=True,
    warn=warn,
    warmstart=warmstart,
)
ot.toc()
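
# %%
# Hedged check of the output: assuming ``u`` and ``v`` are the usual Sinkhorn
# scaling vectors, the implicit plan is
# :math:`\mathrm{diag}(u) K_1 K_2^\top \mathrm{diag}(v)`, and its row
# marginal, computed with two O(nr) products, should be close to the
# (uniform by default) source weights.
row_marginal = u * (left_factor @ (right_factor.T @ v))
print("max deviation from uniform:", np.abs(row_marginal - 1 / Xs.shape[0]).max())
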
##############################################################################
# Compare with Sinkhorn
# ---------------------

# %%
M = ot.dist(Xs, Xt)
ot.tic()
G, log_ = ot.sinkhorn(
    a=[],  # empty lists mean uniform weights in POT
    b=[],
    M=M,
    reg=reg,
    numItermax=numItermax,
    verbose=verbose,
    log=True,
    warn=warn,
    warmstart=warmstart,
)
ot.toc()

##############################################################################
# Use ot.bregman.empirical_sinkhorn_nystroem directly
# ----------------------------------------------------

# %%
ot.tic()
G_nys = empirical_sinkhorn_nystroem(
    Xs,
    Xt,
    anchors=anchors,
    reg=reg,
    numItermax=numItermax,
    verbose=True,
    random_state=random_state,
)[:]  # [:] converts the (presumably lazy) low-rank plan to a dense array
ot.toc()
# %%
ot.tic()
G_sinkh = ot.bregman.empirical_sinkhorn(
    Xs, Xt, reg=reg, numIterMax=numItermax, verbose=True
)
ot.toc()
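
# %%
# Quick numerical comparison of the two (dense) plans before plotting them:
print("Frobenius norm of difference:", np.linalg.norm(G_sinkh - G_nys))
print("total mass of the Nyström plan (should be ~1):", G_nys.sum())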

##############################################################################
# Compare OT plans
# ----------------

fig, ax = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
vmin = min(G_sinkh.min(), G_nys.min())
vmax = max(G_sinkh.max(), G_nys.max())
norm = LogNorm(vmin=vmin, vmax=vmax)
im0 = ax[0].imshow(G_sinkh, norm=norm, cmap="coolwarm")
im1 = ax[1].imshow(G_nys, norm=norm, cmap="coolwarm")
cbar = fig.colorbar(im1, ax=ax, orientation="vertical", fraction=0.046, pad=0.04)
ax[0].set_title("OT plan Sinkhorn")
ax[1].set_title("OT plan Nyström Sinkhorn")
for axi in ax:  # renamed from `a` to avoid shadowing the weights defined above
    axi.set_xticks([])
    axi.set_yticks([])
plt.show()
15 changes: 15 additions & 0 deletions ot/backend.py
@@ -1368,6 +1368,9 @@
    def inv(self, a):
        return scipy.linalg.inv(a)

    def pinv(self, a, hermitian=False):
        return np.linalg.pinv(a, hermitian=hermitian)

    def sqrtm(self, a):
        L, V = np.linalg.eigh(a)
        L = np.sqrt(L)
@@ -1781,6 +1784,9 @@
    def inv(self, a):
        return jnp.linalg.inv(a)

    def pinv(self, a, hermitian=False):
        return jnp.linalg.pinv(a, hermitian=hermitian)

    def sqrtm(self, a):
        L, V = jnp.linalg.eigh(a)
        L = jnp.sqrt(L)
@@ -2314,6 +2320,9 @@
    def inv(self, a):
        return torch.linalg.inv(a)

    def pinv(self, a, hermitian=False):
        return torch.linalg.pinv(a, hermitian=hermitian)

    def sqrtm(self, a):
        L, V = torch.linalg.eigh(a)
        L = torch.sqrt(L)
@@ -2728,6 +2737,9 @@
    def inv(self, a):
        return cp.linalg.inv(a)

    def pinv(self, a, hermitian=False):
        # cp.linalg.pinv does not expose a `hermitian` flag, so it is ignored
        return cp.linalg.pinv(a)

    def sqrtm(self, a):
        L, V = cp.linalg.eigh(a)
        L = cp.sqrt(L)
@@ -3164,6 +3176,9 @@
    def inv(self, a):
        return tf.linalg.inv(a)

    def pinv(self, a, hermitian=False):
        # tf.linalg.pinv has no `hermitian` argument either, so it is ignored
        return tf.linalg.pinv(a)

    def sqrtm(self, a):
        L, V = tf.linalg.eigh(a)
        L = tf.sqrt(L)
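
For context, a minimal sketch of how the new backend `pinv` can be exercised in a backend-agnostic way (assuming the usual `ot.backend.get_backend` dispatch; this snippet is not part of the diff):

import numpy as np
import ot

a = np.random.rand(5, 3)
nx = ot.backend.get_backend(a)  # NumpyBackend for a NumPy input
a_pinv = nx.pinv(a)  # Moore-Penrose pseudo-inverse through the backend
assert np.allclose(a @ a_pinv @ a, a)  # defining property of the pseudo-inverse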
4 changes: 4 additions & 0 deletions ot/bregman/__init__.py
@@ -38,6 +38,8 @@
    empirical_sinkhorn,
    empirical_sinkhorn2,
    empirical_sinkhorn_divergence,
    empirical_sinkhorn_nystroem,
    empirical_sinkhorn_nystroem2,
)

from ._screenkhorn import screenkhorn
@@ -71,6 +73,8 @@
"empirical_sinkhorn2",
"empirical_sinkhorn2_geomloss",
"empirical_sinkhorn_divergence",
"empirical_sinkhorn_nystroem",
"empirical_sinkhorn_nystroem2",
"geomloss",
"screenkhorn",
"unmix",