From 2fe69eb130827560ada704bc25998397c4357821 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Flamary?=
Date: Thu, 4 Nov 2021 11:00:09 +0100
Subject: [PATCH] [MRG] Make gromov loss differentiable wrt matrices and weights (#302)

* gromov differentiable
* new stuff
* test gromov gradients
* fgw differentiable
* fgw tested
* correct name test
* add awesome example with gromov optimization
* pep8 + typos
* damn pep8
* thumbnail
* remove prints
---
 README.md | 9 +-
 .../backends/plot_optim_gromov_pytorch.py | 260 ++++++++++++++++++
 ot/__init__.py | 2 +
 ot/gromov.py | 141 ++++++++--
 ot/optim.py | 3 +-
 test/test_gromov.py | 76 +++++
 6 files changed, 460 insertions(+), 31 deletions(-)
 create mode 100644 examples/backends/plot_optim_gromov_pytorch.py

diff --git a/README.md b/README.md
index ff32c53be..08db0039d 100644
--- a/README.md
+++ b/README.md
@@ -26,7 +26,7 @@ POT provides the following generic OT solvers (links to examples):
 * Debiased Sinkhorn barycenters [Sinkhorn divergence barycenter](https://pythonot.github.io/auto_examples/barycenters/plot_debiased_barycenter.html) [37]
 * [Smooth optimal transport solvers](https://pythonot.github.io/auto_examples/plot_OT_1D_smooth.html) (dual and semi-dual) for KL and squared L2 regularizations [17].
 * Non regularized [Wasserstein barycenters [16] ](https://pythonot.github.io/auto_examples/barycenters/plot_barycenter_lp_vs_entropic.html)) with LP solver (only small scale).
-* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12])
+* [Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) and [GW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_gromov_barycenter.html) (exact [13] and regularized [12]), differentiable using gradients from [38]
 * [Fused-Gromov-Wasserstein distances solver](https://pythonot.github.io/auto_examples/gromov/plot_fgw.html#sphx-glr-auto-examples-plot-fgw-py) and [FGW barycenters](https://pythonot.github.io/auto_examples/gromov/plot_barycenter_fgw.html) [24]
 * [Stochastic solver](https://pythonot.github.io/auto_examples/plot_stochastic.html) for Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
 * [Stochastic solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
@@ -295,5 +295,8 @@ You can also post bug reports and feature requests in Github issues. Make sure t
 via optimal transport and diffusions](http://proceedings.mlr.press/v97/liutkus19a/liutkus19a.pdf). In International Conference on Machine Learning (pp. 4104-4113). PMLR.
-[37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
-Conference on Machine Learning, PMLR 119:4692-4701, 2020
\ No newline at end of file
+[37] Janati, H., Cuturi, M., Gramfort, A. [Debiased Sinkhorn barycenters](http://proceedings.mlr.press/v119/janati20a/janati20a.pdf) Proceedings of the 37th International
+Conference on Machine Learning, PMLR 119:4692-4701, 2020
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, [Online Graph
+Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Conference on Machine Learning (ICML), 2021.
\ No newline at end of file
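Editor's note (illustration, not part of the patch): the README claim above, that GW is now differentiable through the backends using the gradients from [38], can be exercised with a short PyTorch sketch. The point clouds and sizes below are made up for illustration; the functions used (ot.dist, ot.unif, ot.gromov_wasserstein2) are the ones touched or exported by this PR, and PyTorch is assumed to be installed.

import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(20, 2), rng.randn(20, 2)   # two toy point clouds
C1, C2 = ot.dist(xs, xs), ot.dist(xt, xt)     # intra-domain cost matrices
C1, C2 = C1 / C1.max(), C2 / C2.max()

C1t = torch.tensor(C1, requires_grad=True)
C2t = torch.tensor(C2, requires_grad=True)
p = torch.tensor(ot.unif(20), requires_grad=True)
q = torch.tensor(ot.unif(20), requires_grad=True)

# GW value computed through the torch backend; backward() fills the gradients
# wrt both cost matrices and both weight vectors
val = ot.gromov_wasserstein2(C1t, C2t, p, q)
val.backward()
print(C1t.grad.shape, p.grad.shape)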
diff --git a/examples/backends/plot_optim_gromov_pytorch.py b/examples/backends/plot_optim_gromov_pytorch.py
new file mode 100644
index 000000000..465f61237
--- /dev/null
+++ b/examples/backends/plot_optim_gromov_pytorch.py
@@ -0,0 +1,260 @@
+r"""
+=======================================================
+Optimizing the Gromov-Wasserstein distance with PyTorch
+=======================================================
+
+In this example we use the PyTorch backend to optimize the Gromov-Wasserstein
+(GW) loss between two graphs expressed as empirical distributions.
+
+In the first example we optimize the weights on the nodes of a simple template
+graph so that they minimize the GW distance to a given Stochastic Block Model graph.
+We can see that this actually recovers the proportion of classes in the SBM
+and allows for an accurate clustering of the nodes using the GW optimal plan.
+
+In a second example we simultaneously optimize the weights and the structure of
+the template graph, which allows us to perform graph compression and to recover
+other properties of the SBM.
+
+The backend actually uses the gradients expressed in [38] to optimize the
+weights.
+
+[38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online Graph
+Dictionary Learning, International Conference on Machine Learning (ICML), 2021.
+
+"""
+# Author: Rémi Flamary
+#
+# License: MIT License
+
+# sphinx_gallery_thumbnail_number = 3
+
+from sklearn.manifold import MDS
+import numpy as np
+import matplotlib.pylab as pl
+import torch
+
+import ot
+from ot.gromov import gromov_wasserstein2
+
+# %%
+# Graph generation
+# ---------------
+
+rng = np.random.RandomState(42)
+
+
+def get_sbm(n, nc, ratio, P):
+    nbpc = np.round(n * ratio).astype(int)
+    n = np.sum(nbpc)
+    C = np.zeros((n, n))
+    for c1 in range(nc):
+        for c2 in range(c1 + 1):
+            if c1 == c2:
+                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+                    for j in range(np.sum(nbpc[:c2]), i):
+                        if rng.rand() <= P[c1, c2]:
+                            C[i, j] = 1
+            else:
+                for i in range(np.sum(nbpc[:c1]), np.sum(nbpc[:c1 + 1])):
+                    for j in range(np.sum(nbpc[:c2]), np.sum(nbpc[:c2 + 1])):
+                        if rng.rand() <= P[c1, c2]:
+                            C[i, j] = 1
+
+    return C + C.T
+
+
+n = 100
+nc = 3
+ratio = np.array([.5, .3, .2])
+P = np.array(0.6 * np.eye(3) + 0.05 * np.ones((3, 3)))
+C1 = get_sbm(n, nc, ratio, P)
+
+# get 2d position for nodes
+x1 = MDS(dissimilarity='precomputed', random_state=0).fit_transform(1 - C1)
+
+
+def plot_graph(x, C, color='C0', s=None):
+    for j in range(C.shape[0]):
+        for i in range(j):
+            if C[i, j] > 0:
+                pl.plot([x[i, 0], x[j, 0]], [x[i, 1], x[j, 1]], alpha=0.2, color='k')
+    pl.scatter(x[:, 0], x[:, 1], c=color, s=s, zorder=10, edgecolors='k', cmap='tab10', vmax=9)
+
+
+pl.figure(1, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color='C0')
+pl.title("SBM Graph")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+pl.imshow(C1, interpolation='nearest')
+pl.title("Adjacency matrix")
+pl.axis("off")
+
+
+# %%
+# Optimizing the weights of a simple template C0=eye(3) to fit Graph 1
+# ------------------------------------------------
+# The adjacency matrix C1 is block diagonal with 3 blocks. We want to
+# optimize the weights of a simple template C0=eye(3) and see if we can
+# recover the proportion of classes from the SBM (up to a permutation).
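Editor's note (illustration, not part of the patch): the optimization below keeps the weights feasible with a projected gradient step, reprojecting onto the probability simplex after each update via ot.utils.proj_simplex, which is called later in this file. A minimal standalone sketch of that projection, assuming it accepts a plain NumPy vector as it does the torch tensors used below:

import numpy as np
import ot

v = np.array([0.8, -0.3, 1.1])   # arbitrary vector, possibly with negative entries
w = ot.utils.proj_simplex(v)     # closest point of the probability simplex
print(w, w.sum())                # entries are nonnegative and sum to 1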
+
+C0 = np.eye(3)
+
+
+def min_weight_gw(C1, C2, a2, nb_iter_max=100, lr=1e-2):
+    """ solve min_a GW(C1, C2, a, a2) by gradient descent"""
+
+    # use PyTorch for our data
+    C1_torch = torch.tensor(C1)
+    C2_torch = torch.tensor(C2)
+
+    a0 = rng.rand(C1.shape[0])  # random_init
+    a0 /= a0.sum()  # on simplex
+    a1_torch = torch.tensor(a0).requires_grad_(True)
+    a2_torch = torch.tensor(a2)
+
+    loss_iter = []
+
+    for i in range(nb_iter_max):
+
+        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+        loss_iter.append(loss.clone().detach().cpu().numpy())
+        loss.backward()
+
+        #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+        # performs a step of projected gradient descent
+        with torch.no_grad():
+            grad = a1_torch.grad
+            a1_torch -= grad * lr  # step
+            a1_torch.grad.zero_()
+            a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+    a1 = a1_torch.clone().detach().cpu().numpy()
+
+    return a1, loss_iter
+
+
+a0_est, loss_iter0 = min_weight_gw(C0, C1, ot.unif(n), nb_iter_max=100, lr=1e-2)
+
+pl.figure(2)
+pl.plot(loss_iter0)
+pl.title("Loss along iterations")
+
+print("Estimated weights : ", a0_est)
+print("True proportions : ", ratio)
+
+
+# %%
+# It is clear that the optimization has converged and that we recover the
+# ratio of the different classes in the SBM graph up to a permutation.
+
+
+# %%
+# Community clustering with uniform and estimated weights
+# --------------------------------------------
+# The GW OT plan can be used to perform a clustering of the nodes of a graph
+# when computing the GW with a simple template like C0 by labeling nodes in
+# the original graph using the index of the node in the template that receives
+# the most mass.
+#
+# We show here the result of such a clustering when using uniform weights on
+# the template C0 and when using the optimal weights previously estimated.
+
+
+T_unif = ot.gromov_wasserstein(C1, C0, ot.unif(n), ot.unif(3))
+label_unif = T_unif.argmax(1)
+
+T_est = ot.gromov_wasserstein(C1, C0, ot.unif(n), a0_est)
+label_est = T_est.argmax(1)
+
+pl.figure(3, (10, 5))
+pl.clf()
+pl.subplot(1, 2, 1)
+plot_graph(x1, C1, color=label_unif)
+pl.title("Graph clustering unif. weights")
+pl.axis("off")
+pl.subplot(1, 2, 2)
+plot_graph(x1, C1, color=label_est)
+pl.title("Graph clustering est. weights")
+pl.axis("off")
+
+
+# %%
+# Graph compression with GW
+# -------------------------
+
+# Now we optimize both the weights and the structure of a small graph that
+# minimizes the GW distance wrt our data graph. This can be seen as graph
+# compression, but it can also recover important properties of the SBM such
+# as its class proportions and its matrix of link probabilities between
+# classes.
+
+
+def graph_compression_gw(nb_nodes, C2, a2, nb_iter_max=100, lr=1e-2):
+    """ solve min_{a, C1} GW(C1, C2, a, a2) by gradient descent"""
+
+    # use PyTorch for our data
+
+    C2_torch = torch.tensor(C2)
+    a2_torch = torch.tensor(a2)
+
+    a0 = rng.rand(nb_nodes)  # random_init
+    a0 /= a0.sum()  # on simplex
+    a1_torch = torch.tensor(a0).requires_grad_(True)
+    C0 = np.eye(nb_nodes)
+    C1_torch = torch.tensor(C0).requires_grad_(True)
+
+    loss_iter = []
+
+    for i in range(nb_iter_max):
+
+        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
+
+        loss_iter.append(loss.clone().detach().cpu().numpy())
+        loss.backward()
+
+        #print("{:03d} | {}".format(i, loss_iter[-1]))
+
+        # performs a step of projected gradient descent
+        with torch.no_grad():
+            grad = a1_torch.grad
+            a1_torch -= grad * lr  # step
+            a1_torch.grad.zero_()
+            a1_torch.data = ot.utils.proj_simplex(a1_torch)
+
+            grad = C1_torch.grad
+            C1_torch -= grad * lr  # step
+            C1_torch.grad.zero_()
+            C1_torch.data = torch.clamp(C1_torch, 0, 1)
+
+    a1 = a1_torch.clone().detach().cpu().numpy()
+    C1 = C1_torch.clone().detach().cpu().numpy()
+
+    return a1, C1, loss_iter
+
+
+nb_nodes = 3
+a0_est2, C0_est2, loss_iter2 = graph_compression_gw(nb_nodes, C1, ot.unif(n),
+                                                    nb_iter_max=100, lr=5e-2)
+
+pl.figure(4)
+pl.plot(loss_iter2)
+pl.title("Loss along iterations")
+
+
+print("Estimated weights : ", a0_est2)
+print("True proportions : ", ratio)
+
+pl.figure(6, (10, 3.5))
+pl.clf()
+pl.subplot(1, 2, 1)
+pl.imshow(P, vmin=0, vmax=1)
+pl.title('True SBM P matrix')
+pl.subplot(1, 2, 2)
+pl.imshow(C0_est2, vmin=0, vmax=1)
+pl.title('Estimated C0 matrix')
+pl.colorbar()
diff --git a/ot/__init__.py b/ot/__init__.py
index f20332cc6..4292b4107 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -43,6 +43,8 @@ sinkhorn_unbalanced2)
 from .da import sinkhorn_lpl1_mm
 from .sliced import sliced_wasserstein_distance, max_sliced_wasserstein_distance
+from .gromov import (gromov_wasserstein, gromov_wasserstein2,
+                     gromov_barycenters, fused_gromov_wasserstein, fused_gromov_wasserstein2)
 # utils functions
 from .utils import dist, unif, tic, toc, toq
diff --git a/ot/gromov.py b/ot/gromov.py
index 465693dd9..ea667e414 100644
--- a/ot/gromov.py
+++ b/ot/gromov.py
@@ -174,7 +174,7 @@ def tensor_product(constC, hC1, hC2, T):
 def gwloss(constC, hC1, hC2, T):
-    """Return the Loss for Gromov-Wasserstein
+    r"""Return the Loss for Gromov-Wasserstein
 
     The loss is computed as described in Proposition 1 Eq.
(6) in :ref:`[12] `
@@ -213,7 +213,7 @@ def gwloss(constC, hC1, hC2, T):
 def gwggrad(constC, hC1, hC2, T):
-    """Return the gradient for Gromov-Wasserstein
+    r"""Return the gradient for Gromov-Wasserstein
 
     The gradient is computed as described in Proposition 2 in :ref:`[12] `
@@ -247,7 +247,7 @@ def gwggrad(constC, hC1, hC2, T):
 def update_square_loss(p, lambdas, T, Cs):
-    """
+    r"""
     Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
@@ -284,7 +284,7 @@ def update_square_loss(p, lambdas, T, Cs):
 def update_kl_loss(p, lambdas, T, Cs):
-    """
+    r"""
     Updates :math:`\mathbf{C}` according to the KL Loss kernel with the `S` :math:`\mathbf{T}_s` couplings calculated at each iteration
@@ -320,7 +320,7 @@ def update_kl_loss(p, lambdas, T, Cs):
     return nx.exp(tmpsum / ppt)
 
-def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
     r"""
     Returns the gromov-wasserstein transport between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -386,6 +386,14 @@ def gromov_wasserstein(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs
     """
     p, q = list_to_array(p, q)
+    p0, q0, C10, C20 = p, q, C1, C2
+    nx = get_backend(p0, q0, C10, C20)
+
+    p = nx.to_numpy(p)
+    q = nx.to_numpy(q)
+    C1 = nx.to_numpy(C10)
+    C2 = nx.to_numpy(C20)
+
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
 
     G0 = p[:, None] * q[None, :]
@@ -398,13 +406,15 @@ def df(G):
     if log:
         res, log = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-        log['gw_dist'] = gwloss(constC, hC1, hC2, res)
-        return res, log
+        log['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, res), type_as=C10)
+        log['u'] = nx.from_numpy(log['u'], type_as=C10)
+        log['v'] = nx.from_numpy(log['v'], type_as=C10)
+        return nx.from_numpy(res, type_as=C10), log
     else:
-        return cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+        return nx.from_numpy(cg(p, q, 0, 1, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=False, **kwargs), type_as=C10)
 
-def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwargs):
+def gromov_wasserstein2(C1, C2, p, q, loss_fun='square_loss', log=False, armijo=False, **kwargs):
     r"""
     Returns the gromov-wasserstein discrepancy between :math:`(\mathbf{C_1}, \mathbf{p})` and :math:`(\mathbf{C_2}, \mathbf{q})`
@@ -420,7 +430,11 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
     - :math:`\mathbf{C_2}`: Metric cost matrix in the target space
     - :math:`\mathbf{p}`: distribution in the source space
     - :math:`\mathbf{q}`: distribution in the target space
-    - `L`: loss function to account for the misfit between the similarity matrices
+    - `L`: loss function to account for the misfit between the similarity
+      matrices
+
+    Note that when using backends, this loss function is differentiable wrt the
+    matrices and weights for quadratic loss using the gradients from [38]_.
 
     Parameters
     ----------
@@ -463,9 +477,21 @@ def gromov_wasserstein2(C1, C2, p, q, loss_fun, log=False, armijo=False, **kwarg
         metric approach to object matching. Foundations of computational
         mathematics 11.4 (2011): 417-487.
 
+    .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online
+        Graph Dictionary Learning, International Conference on Machine Learning
+        (ICML), 2021.
+
     """
     p, q = list_to_array(p, q)
+    p0, q0, C10, C20 = p, q, C1, C2
+    nx = get_backend(p0, q0, C10, C20)
+
+    p = nx.to_numpy(p)
+    q = nx.to_numpy(q)
+    C1 = nx.to_numpy(C10)
+    C2 = nx.to_numpy(C20)
+
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
 
     G0 = p[:, None] * q[None, :]
@@ -475,13 +501,28 @@ def f(G):
     def df(G):
         return gwggrad(constC, hC1, hC2, G)
-    res, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
-    log_gw['gw_dist'] = gwloss(constC, hC1, hC2, res)
-    log_gw['T'] = res
+
+    T, log_gw = cg(p, q, 0, 1, f, df, G0, log=True, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+
+    T0 = nx.from_numpy(T, type_as=C10)
+
+    log_gw['gw_dist'] = nx.from_numpy(gwloss(constC, hC1, hC2, T), type_as=C10)
+    log_gw['u'] = nx.from_numpy(log_gw['u'], type_as=C10)
+    log_gw['v'] = nx.from_numpy(log_gw['v'], type_as=C10)
+    log_gw['T'] = T0
+
+    gw = log_gw['gw_dist']
+
+    if loss_fun == 'square_loss':
+        gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T))
+        gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T))
+        gw = nx.set_gradients(gw, (p0, q0, C10, C20),
+                              (log_gw['u'], log_gw['v'], gC1, gC2))
+
     if log:
-        return log_gw['gw_dist'], log_gw
+        return gw, log_gw
     else:
-        return log_gw['gw_dist']
+        return gw
 
 def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -548,6 +589,15 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5,
     """
     p, q = list_to_array(p, q)
+    p0, q0, C10, C20, M0 = p, q, C1, C2, M
+    nx = get_backend(p0, q0, C10, C20, M0)
+
+    p = nx.to_numpy(p)
+    q = nx.to_numpy(q)
+    C1 = nx.to_numpy(C10)
+    C2 = nx.to_numpy(C20)
+    M = nx.to_numpy(M0)
+
     constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun)
 
     G0 = p[:, None] * q[None, :]
@@ -560,10 +610,16 @@ def df(G):
     if log:
         res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs)
-        log['fgw_dist'] = log['loss'][::-1][0]
-        return res, log
+
+        fgw_dist = nx.from_numpy(log['loss'][-1], type_as=C10)
+
+        log['fgw_dist'] = fgw_dist
+        log['u'] = nx.from_numpy(log['u'], type_as=C10)
+        log['v'] = nx.from_numpy(log['v'], type_as=C10)
+        return nx.from_numpy(res, type_as=C10), log
+
     else:
-        return cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs)
+        return nx.from_numpy(cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, **kwargs), type_as=C10)
 
 def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5, armijo=False, log=False, **kwargs):
@@ -586,7 +642,11 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5
     - :math:`\mathbf{p}` and :math:`\mathbf{q}` are source and target weights (sum to 1)
     - `L` is a loss function to account for the misfit between the similarity matrices
 
-    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[24] `
+    The algorithm used for solving the problem is conditional gradient as
+    discussed in :ref:`[24] `
+
+    Note that when using backends, this loss function is differentiable wrt the
+    matrices and weights for quadratic loss using the gradients from [38]_.
Parameters ---------- @@ -627,9 +687,22 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', alpha=0.5 and Courty Nicolas "Optimal Transport for structured data with application on graphs" International Conference on Machine Learning (ICML). 2019. + + .. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, Online + Graph Dictionary Learning, International Conference on Machine Learning + (ICML), 2021. """ p, q = list_to_array(p, q) + p0, q0, C10, C20, M0 = p, q, C1, C2, M + nx = get_backend(p0, q0, C10, C20, M0) + + p = nx.to_numpy(p) + q = nx.to_numpy(q) + C1 = nx.to_numpy(C10) + C2 = nx.to_numpy(C20) + M = nx.to_numpy(M0) + constC, hC1, hC2 = init_matrix(C1, C2, p, q, loss_fun) G0 = p[:, None] * q[None, :] @@ -640,13 +713,27 @@ def f(G): def df(G): return gwggrad(constC, hC1, hC2, G) - res, log = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + T, log_fgw = cg(p, q, (1 - alpha) * M, alpha, f, df, G0, armijo=armijo, C1=C1, C2=C2, constC=constC, log=True, **kwargs) + + fgw_dist = nx.from_numpy(log_fgw['loss'][-1], type_as=C10) + + T0 = nx.from_numpy(T, type_as=C10) + + log_fgw['fgw_dist'] = fgw_dist + log_fgw['u'] = nx.from_numpy(log_fgw['u'], type_as=C10) + log_fgw['v'] = nx.from_numpy(log_fgw['v'], type_as=C10) + log_fgw['T'] = T0 + + if loss_fun == 'square_loss': + gC1 = nx.from_numpy(2 * C1 * (p[:, None] * p[None, :]) - 2 * T.dot(C2).dot(T.T)) + gC2 = nx.from_numpy(2 * C2 * (q[:, None] * q[None, :]) - 2 * T.T.dot(C1).dot(T)) + fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0), + (log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0)) + if log: - log['fgw_dist'] = log['loss'][::-1][0] - log['T'] = res - return log['fgw_dist'], log + return fgw_dist, log_fgw else: - return log['fgw_dist'] + return fgw_dist def GW_distance_estimation(C1, C2, p, q, loss_fun, T, @@ -1447,7 +1534,7 @@ def gromov_barycenters(N, Cs, ps, p, lambdas, loss_fun, def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_features=False, p=None, loss_fun='square_loss', max_iter=100, tol=1e-9, verbose=False, log=False, init_C=None, init_X=None, random_state=None): - """Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` + r"""Compute the fgw barycenter as presented eq (5) in :ref:`[24] ` Parameters ---------- @@ -1604,7 +1691,7 @@ def fgw_barycenters(N, Ys, Cs, ps, lambdas, alpha, fixed_structure=False, fixed_ def update_structure_matrix(p, lambdas, T, Cs): - """Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates :math:`\mathbf{C}` according to the L2 Loss kernel with the `S` :math:`\mathbf{T}_s` couplings. It is calculated at each iteration @@ -1640,7 +1727,7 @@ def update_structure_matrix(p, lambdas, T, Cs): def update_feature_matrix(lambdas, Ys, Ts, p): - """Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. + r"""Updates the feature with respect to the `S` :math:`\mathbf{T}_s` couplings. 
See "Solving the barycenter problem with Block Coordinate Descent (BCD)" diff --git a/ot/optim.py b/ot/optim.py index cc286b62d..bd8ca26e0 100644 --- a/ot/optim.py +++ b/ot/optim.py @@ -267,7 +267,7 @@ def cost(G): Mi += nx.min(Mi) # solve linear program - Gc = emd(a, b, Mi, numItermax=numItermaxEmd) + Gc, logemd = emd(a, b, Mi, numItermax=numItermaxEmd, log=True) deltaG = Gc - G @@ -297,6 +297,7 @@ def cost(G): print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, f_val, relative_delta_fval, abs_delta_fval)) if log: + log.update(logemd) return G, log else: return G diff --git a/test/test_gromov.py b/test/test_gromov.py index 509c54dbd..bcbcc3a94 100644 --- a/test/test_gromov.py +++ b/test/test_gromov.py @@ -9,6 +9,7 @@ import numpy as np import ot from ot.backend import NumpyBackend +from ot.backend import torch import pytest @@ -74,6 +75,42 @@ def test_gromov(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_gromov2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + + val = ot.gromov_wasserstein2(C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + + @pytest.skip_backend("jax", reason="test very slow with jax backend") def test_entropic_gromov(nx): n_samples = 50 # nb samples @@ -389,6 +426,45 @@ def test_fgw(nx): q, Gb.sum(0), atol=1e-04) # cf convergence gromov +def test_fgw2_gradients(): + n_samples = 50 # nb samples + + mu_s = np.array([0, 0]) + cov_s = np.array([[1, 0], [0, 1]]) + + xs = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=4) + + xt = ot.datasets.make_2D_samples_gauss(n_samples, mu_s, cov_s, random_state=5) + + p = ot.unif(n_samples) + q = ot.unif(n_samples) + + C1 = ot.dist(xs, xs) + C2 = ot.dist(xt, xt) + M = ot.dist(xs, xt) + + C1 /= C1.max() + C2 /= C2.max() + + if torch: + + p1 = torch.tensor(p, requires_grad=True) + q1 = torch.tensor(q, requires_grad=True) + C11 = torch.tensor(C1, requires_grad=True) + C12 = torch.tensor(C2, requires_grad=True) + M1 = torch.tensor(M, requires_grad=True) + + val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1) + + val.backward() + + assert q1.shape == q1.grad.shape + assert p1.shape == p1.grad.shape + assert C11.shape == C11.grad.shape + assert C12.shape == C12.grad.shape + assert M1.shape == M1.grad.shape + + def test_fgw_barycenter(nx): np.random.seed(42)