Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] BAPG solvers for GW and FGW #581

Merged
merged 10 commits into from
Nov 30, 2023
Merged
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,7 @@ distances between Gaussian distributions](https://hal.science/hal-03197398v2/fil
[61] Charlier, B., Feydy, J., Glaunes, J. A., Collin, F. D., & Durif, G. (2021). [Kernel operations on the gpu, with autodiff, without memory overflows](https://www.jmlr.org/papers/volume22/20-275/20-275.pdf). The Journal of Machine Learning Research, 22(1), 3457-3462.

[62] H. Van Assel, C. Vincent-Cuaz, T. Vayer, R. Flamary, N. Courty (2023). [Interpolating between Clustering and Dimensionality Reduction with Gromov-Wasserstein](https://arxiv.org/pdf/2310.03398.pdf). NeurIPS 2023 Workshop Optimal Transport and Machine Learning.

[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J. (2022). [A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in Graph Data](https://openreview.net/pdf?id=0jxPyVWmiiF). In The Eleventh International Conference on Learning Representations.

[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W. (2023). [Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications](https://openreview.net/pdf?id=uqkUguNu40). In Thirty-seventh Conference on Neural Information Processing Systems.
2 changes: 2 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@
+ Wrapper for `geomloss`` solver on empirical samples (PR #571)
+ Add `stop_criterion` feature to (un)regularized (f)gw barycenter solvers (PR #578)
+ Add `fixed_structure` and `fixed_features` to entropic fgw barycenter solver (PR #578)
+ Add new BAPG solvers with KL projections for GW and FGW (PR #581)
+ Add Bures-Wasserstein barycenter in `ot.gaussian` (PR #582)


#### Closed issues
- Fix line search evaluating cost outside of the interpolation range (Issue #502, PR #504)
- Lazily instantiate backends to avoid unnecessary GPU memory pre-allocations on package import (Issue #516, PR #520)
Expand Down
133 changes: 89 additions & 44 deletions examples/gromov/plot_fgw_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
==============================

This example illustrates the computation of FGW for attributed graphs
using 3 different solvers to estimate the distance based on Conditional
Gradient [24] or Sinkhorn projections [12, 51].
using 4 different solvers to estimate the distance based on Conditional
Gradient [24], Sinkhorn projections [12, 51] and alternated Bregman
projections [63, 64].

We generate two graphs following Stochastic Block Models further endowed with
node features and compute their FGW matchings.
Expand All @@ -23,6 +24,16 @@
[51] Xu, H., Luo, D., Zha, H., & Duke, L. C. (2019).
"Gromov-wasserstein learning for graph matching and node embedding".
In International Conference on Machine Learning (ICML), 2019.

[63] Li, J., Tang, J., Kong, L., Liu, H., Li, J., So, A. M. C., & Blanchet, J.
"A Convergent Single-Loop Algorithm for Relaxation of Gromov-Wasserstein in
Graph Data". International Conference on Learning Representations (ICLR), 2023.

[64] Ma, X., Chu, X., Wang, Y., Lin, Y., Zhao, J., Ma, L., & Zhu, W.
"Fused Gromov-Wasserstein Graph Mixup for Graph-level Classifications".
In Thirty-seventh Conference on Neural Information Processing Systems
(NeurIPS), 2023.

"""

# Author: Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
Expand All @@ -33,9 +44,12 @@

import numpy as np
import matplotlib.pylab as pl
from ot.gromov import fused_gromov_wasserstein, entropic_fused_gromov_wasserstein
from ot.gromov import (fused_gromov_wasserstein,
entropic_fused_gromov_wasserstein,
BAPG_fused_gromov_wasserstein)
import networkx
from networkx.generators.community import stochastic_block_model as sbm
from time import time

#############################################################################
#
Expand Down Expand Up @@ -85,34 +99,59 @@


# Conditional Gradient algorithm
fgw0, log0 = fused_gromov_wasserstein(
M, C2, C3, h2, h3, 'square_loss', alpha=alpha, verbose=True, log=True)
print('Conditional Gradient \n')
start_cg = time()
T_cg, log_cg = fused_gromov_wasserstein(
M, C2, C3, h2, h3, 'square_loss', alpha=alpha, tol_rel=1e-9,
verbose=True, log=True)
end_cg = time()
time_cg = 1000 * (end_cg - start_cg)

# Proximal Point algorithm with Kullback-Leibler as proximal operator
fgw, log = entropic_fused_gromov_wasserstein(
print('Proximal Point Algorithm \n')
start_ppa = time()
T_ppa, log_ppa = entropic_fused_gromov_wasserstein(
M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1., solver='PPA',
log=True, verbose=True, warmstart=False, numItermax=10)
tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
end_ppa = time()
time_ppa = 1000 * (end_ppa - start_ppa)

# Projected Gradient algorithm with entropic regularization
fgwe, loge = entropic_fused_gromov_wasserstein(
print('Projected Gradient Descent \n')
start_pgd = time()
T_pgd, log_pgd = entropic_fused_gromov_wasserstein(
M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=0.01, solver='PGD',
log=True, verbose=True, warmstart=False, numItermax=10)

print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log0['fgw_dist']))
print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log['fgw_dist']))
print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(loge['fgw_dist']))
tol=1e-9, log=True, verbose=True, warmstart=False, numItermax=10)
end_pgd = time()
time_pgd = 1000 * (end_pgd - start_pgd)

# Alternated Bregman Projected Gradient algorithm with Kullback-Leibler as proximal operator
print('Bregman Alternated Projected Gradient \n')
start_bapg = time()
T_bapg, log_bapg = BAPG_fused_gromov_wasserstein(
M, C2, C3, h2, h3, 'square_loss', alpha=alpha, epsilon=1.,
tol=1e-9, marginal_loss=True, verbose=True, log=True)
end_bapg = time()
time_bapg = 1000 * (end_bapg - start_bapg)

print('Fused Gromov-Wasserstein distance estimated with Conditional Gradient solver: ' + str(log_cg['fgw_dist']))
print('Fused Gromov-Wasserstein distance estimated with Proximal Point solver: ' + str(log_ppa['fgw_dist']))
print('Entropic Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_pgd['fgw_dist']))
print('Fused Gromov-Wasserstein distance estimated with Projected Gradient solver: ' + str(log_bapg['fgw_dist']))

# compute OT sparsity level
fgw0_sparsity = 100 * (fgw0 == 0.).astype(np.float64).sum() / (N2 * N3)
fgw_sparsity = 100 * (fgw == 0.).astype(np.float64).sum() / (N2 * N3)
fgwe_sparsity = 100 * (fgwe == 0.).astype(np.float64).sum() / (N2 * N3)
T_cg_sparsity = 100 * (T_cg == 0.).astype(np.float64).sum() / (N2 * N3)
T_ppa_sparsity = 100 * (T_ppa == 0.).astype(np.float64).sum() / (N2 * N3)
T_pgd_sparsity = 100 * (T_pgd == 0.).astype(np.float64).sum() / (N2 * N3)
T_bapg_sparsity = 100 * (T_bapg == 0.).astype(np.float64).sum() / (N2 * N3)

# Methods using Sinkhorn projections tend to produce feasibility errors on the
# Methods using Sinkhorn/Bregman projections tend to produce feasibility errors on the
# marginal constraints

err0 = np.linalg.norm(fgw0.sum(1) - h2) + np.linalg.norm(fgw0.sum(0) - h3)
err = np.linalg.norm(fgw.sum(1) - h2) + np.linalg.norm(fgw.sum(0) - h3)
erre = np.linalg.norm(fgwe.sum(1) - h2) + np.linalg.norm(fgwe.sum(0) - h3)
err_cg = np.linalg.norm(T_cg.sum(1) - h2) + np.linalg.norm(T_cg.sum(0) - h3)
err_ppa = np.linalg.norm(T_ppa.sum(1) - h2) + np.linalg.norm(T_ppa.sum(0) - h3)
err_pgd = np.linalg.norm(T_pgd.sum(1) - h2) + np.linalg.norm(T_pgd.sum(0) - h3)
err_bapg = np.linalg.norm(T_bapg.sum(1) - h2) + np.linalg.norm(T_bapg.sum(0) - h3)

#############################################################################
#
Expand Down Expand Up @@ -242,46 +281,52 @@ def draw_transp_colored_GW(G1, C1, G2, C2, part_G1, p1, p2, T,
seed_G2 = 0
seed_G3 = 4

pl.figure(2, figsize=(12, 3.5))
pl.figure(2, figsize=(15, 3.5))
pl.clf()
pl.subplot(131)
pl.subplot(141)
pl.axis('off')
pl.axis
pl.title('(CG algo) FGW=%s \n \n OT sparsity = %s \n feasibility error = %s' % (
np.round(log0['fgw_dist'], 3), str(np.round(fgw0_sparsity, 2)) + ' %',
np.round(err0, 4)), fontsize=fontsize)

p0, q0 = fgw0.sum(1), fgw0.sum(0) # check marginals
pl.title('(CG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
np.round(log_cg['fgw_dist'], 3), str(np.round(T_cg_sparsity, 2)) + ' %',
np.round(err_cg, 4), str(np.round(time_cg, 2)) + ' ms'), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_GW(
weightedG2, C2, weightedG3, C3, part_G2, p1=p0, p2=q0, T=fgw0,
shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)
weightedG2, C2, weightedG3, C3, part_G2, p1=T_cg.sum(1), p2=T_cg.sum(0),
T=T_cg, shiftx=1.5, node_size=node_size, seed_G1=seed_G2, seed_G2=seed_G3)

pl.subplot(132)
pl.subplot(142)
pl.axis('off')

p, q = fgw.sum(1), fgw.sum(0) # check marginals

pl.title('(PP algo) FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
np.round(log['fgw_dist'], 3), str(np.round(fgw_sparsity, 2)) + ' %',
np.round(err, 4)), fontsize=fontsize)
pl.title('(PPA) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
np.round(log_ppa['fgw_dist'], 3), str(np.round(T_ppa_sparsity, 2)) + ' %',
np.round(err_ppa, 4), str(np.round(time_ppa, 2)) + ' ms'), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_GW(
weightedG2, C2, weightedG3, C3, part_G2, p1=p, p2=q, T=fgw,
pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
weightedG2, C2, weightedG3, C3, part_G2, p1=T_ppa.sum(1), p2=T_ppa.sum(0),
T=T_ppa, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)

pl.subplot(133)
pl.subplot(143)
pl.axis('off')

pe, qe = fgwe.sum(1), fgwe.sum(0) # check marginals
pl.title('(PGD) Entropic FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
np.round(log_pgd['fgw_dist'], 3), str(np.round(T_pgd_sparsity, 2)) + ' %',
np.round(err_pgd, 4), str(np.round(time_pgd, 2)) + ' ms'), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_GW(
weightedG2, C2, weightedG3, C3, part_G2, p1=T_pgd.sum(1), p2=T_pgd.sum(0),
T=T_pgd, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)


pl.subplot(144)
pl.axis('off')

pl.title('Entropic FGW=%s\n \n OT sparsity = %s \n feasibility error = %s' % (
np.round(loge['fgw_dist'], 3), str(np.round(fgwe_sparsity, 2)) + ' %',
np.round(erre, 4)), fontsize=fontsize)
pl.title('(BAPG) FGW=%s\n \n OT sparsity = %s \n marg. error = %s \n runtime = %s' % (
np.round(log_bapg['fgw_dist'], 3), str(np.round(T_bapg_sparsity, 2)) + ' %',
np.round(err_bapg, 4), str(np.round(time_bapg, 2)) + ' ms'), fontsize=fontsize)

pos1, pos2 = draw_transp_colored_GW(
weightedG2, C2, weightedG3, C3, part_G2, p1=pe, p2=qe, T=fgwe,
pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)
weightedG2, C2, weightedG3, C3, part_G2, p1=T_bapg.sum(1), p2=T_bapg.sum(0),
T=T_bapg, pos1=pos1, pos2=pos2, shiftx=0., node_size=node_size, seed_G1=0, seed_G2=0)

pl.tight_layout()

Expand Down
8 changes: 7 additions & 1 deletion ot/gromov/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,13 @@

from ._bregman import (entropic_gromov_wasserstein,
entropic_gromov_wasserstein2,
BAPG_gromov_wasserstein,
BAPG_gromov_wasserstein2,
entropic_gromov_barycenters,
entropic_fused_gromov_wasserstein,
entropic_fused_gromov_wasserstein2,
BAPG_fused_gromov_wasserstein,
BAPG_fused_gromov_wasserstein2,
entropic_fused_gromov_barycenters)

from ._estimators import (GW_distance_estimation, pointwise_gromov_wasserstein,
Expand All @@ -49,8 +53,10 @@
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
'BAPG_gromov_wasserstein', 'BAPG_gromov_wasserstein2',
'entropic_gromov_barycenters', 'entropic_fused_gromov_wasserstein',
'entropic_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
'entropic_fused_gromov_wasserstein2', 'BAPG_fused_gromov_wasserstein',
'BAPG_fused_gromov_wasserstein2', 'entropic_fused_gromov_barycenters',
'GW_distance_estimation', 'pointwise_gromov_wasserstein', 'sampled_gromov_wasserstein',
'semirelaxed_gromov_wasserstein', 'semirelaxed_gromov_wasserstein2',
'semirelaxed_fused_gromov_wasserstein', 'semirelaxed_fused_gromov_wasserstein2',
Expand Down
Loading
Loading