diff --git a/RELEASES.md b/RELEASES.md
index 28635ff8c..b18fdc3b5 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -3,7 +3,8 @@
 ## 0.9.1dev
 
 #### New features
-- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
+- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
+- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
 
 #### Closed issues
 - Fix circleci-redirector action and codecov (PR #460)
diff --git a/ot/backend.py b/ot/backend.py
index 74f8366e1..0dd6fb842 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -1694,10 +1694,12 @@ def backward(ctx, grad_output):
         self.ValFunction = ValFunction
 
     def _to_numpy(self, a):
+        if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+            return np.array(a)
         return a.cpu().detach().numpy()
 
     def _from_numpy(self, a, type_as=None):
-        if isinstance(a, float):
+        if isinstance(a, float) or isinstance(a, int):
             a = np.array(a)
         if type_as is None:
             return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ def __init__(self):
         )
 
     def _to_numpy(self, a):
+        if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
+            return np.array(a)
         return a.numpy()
 
     def _from_numpy(self, a, type_as=None):
diff --git a/ot/gromov/_gw.py b/ot/gromov/_gw.py
index c6e40764d..bc4719dd3 100644
--- a/ot/gromov/_gw.py
+++ b/ot/gromov/_gw.py
@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
         Information and Inference: A Journal of the IMA, 8(4), 757-787.
     """
     p, q = list_to_array(p, q)
-    p0, q0, C10, C20, M0 = p, q, C1, C2, M
+    p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
     if G0 is None:
         nx = get_backend(p0, q0, C10, C20, M0)
     else:
@@ -382,6 +382,7 @@
     C1 = nx.to_numpy(C10)
     C2 = nx.to_numpy(C20)
     M = nx.to_numpy(M0)
+    alpha = nx.to_numpy(alpha0)
 
     if symmetric is None:
         symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
     if loss_fun == 'square_loss':
         gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
         gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
-        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
-                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
-                                     log_fgw['v'] - nx.mean(log_fgw['v']),
-                                     alpha * gC1, alpha * gC2, (1 - alpha) * T))
+        if isinstance(alpha, int) or isinstance(alpha, float):
+            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
+                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                         log_fgw['v'] - nx.mean(log_fgw['v']),
+                                         alpha * gC1, alpha * gC2, (1 - alpha) * T))
+        else:
+            lin_term = nx.sum(T * M)
+            gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
+            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
+                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
+                                         log_fgw['v'] - nx.mean(log_fgw['v']),
+                                         alpha * gC1, alpha * gC2, (1 - alpha) * T,
+                                         gw_term - lin_term))
 
     if log:
         return fgw_dist, log_fgw
diff --git a/test/test_gromov.py b/test/test_gromov.py
index 80b6df425..f70f410ad 100644
--- a/test/test_gromov.py
+++ b/test/test_gromov.py
@@ -209,6 +209,8 @@ def test_gromov2_gradients():
     if torch.cuda.is_available():
         devices.append(torch.device("cuda"))
     for device in devices:
+
+        # classical gradients
         p1 = torch.tensor(p, requires_grad=True, device=device)
         q1 = torch.tensor(q, requires_grad=True, device=device)
         C11 = torch.tensor(C1, requires_grad=True, device=device)
@@ -226,6 +228,12 @@
         assert C12.shape == C12.grad.shape
 
         # Test with armijo line-search
+        # classical gradients
+        p1 = torch.tensor(p, requires_grad=True, device=device)
+        q1 = torch.tensor(q, requires_grad=True, device=device)
+        C11 = torch.tensor(C1, requires_grad=True, device=device)
+        C12 = torch.tensor(C2, requires_grad=True, device=device)
+
         q1.grad = None
         p1.grad = None
         C11.grad = None
@@ -830,6 +838,25 @@ def test_fgw2_gradients():
         assert C12.shape == C12.grad.shape
         assert M1.shape == M1.grad.shape
 
+        # full gradients with alpha
+        p1 = torch.tensor(p, requires_grad=True, device=device)
+        q1 = torch.tensor(q, requires_grad=True, device=device)
+        C11 = torch.tensor(C1, requires_grad=True, device=device)
+        C12 = torch.tensor(C2, requires_grad=True, device=device)
+        M1 = torch.tensor(M, requires_grad=True, device=device)
+        alpha = torch.tensor(0.5, requires_grad=True, device=device)
+
+        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)
+
+        val.backward()
+
+        assert val.device == p1.device
+        assert q1.shape == q1.grad.shape
+        assert p1.shape == p1.grad.shape
+        assert C11.shape == C11.grad.shape
+        assert C12.shape == C12.grad.shape
+        assert alpha.shape == alpha.grad.shape
+
 
 def test_fgw_helper_backend(nx):
     n_samples = 20  # nb samples
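
Side note (not part of the diff): the alpha gradient registered in `fused_gromov_wasserstein2` above follows from the decomposition fgw = (1 - alpha) * lin_term + alpha * gw_term with lin_term = sum(T * M); holding the optimal plan T fixed (envelope theorem), d(fgw)/d(alpha) = gw_term - lin_term. The sketch below is a hypothetical sanity check of that gradient against a central finite difference on the torch backend; the point-cloud sizes, random seed, and alpha value are arbitrary illustration choices, not taken from the PR.

# Hypothetical sanity check (not from the PR): compare the autograd gradient of
# ot.fused_gromov_wasserstein2 w.r.t. alpha with a central finite difference.
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
n = 10
xs = rng.randn(n, 2)   # source samples
xt = rng.randn(n, 2)   # target samples
p = ot.unif(n)         # uniform source weights
q = ot.unif(n)         # uniform target weights
C1 = ot.dist(xs, xs)   # intra-domain structure matrices
C2 = ot.dist(xt, xt)
M = ot.dist(xs, xt)    # cross-domain feature cost

# torch inputs; only alpha carries requires_grad here
alpha = torch.tensor(0.5, requires_grad=True)
val = ot.fused_gromov_wasserstein2(
    torch.tensor(M), torch.tensor(C1), torch.tensor(C2),
    torch.tensor(p), torch.tensor(q), alpha=alpha)
val.backward()

# central finite difference on the numpy backend; since the derivative of the
# optimal value only involves the partial derivative at the optimal plan,
# the two numbers should roughly agree
eps = 1e-4
fd = (ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5 + eps)
      - ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5 - eps)) / (2 * eps)
print(float(alpha.grad), fd)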