Merge branch 'sparsity_constrained' of https://github.com/liutianlin0121/POT into sparsity_constrained
liutianlin0121 committed Apr 25, 2023
2 parents 9206c00 + 6107526 commit 11e07aa
Showing 4 changed files with 49 additions and 7 deletions.
3 changes: 2 additions & 1 deletion RELEASES.md
@@ -3,7 +3,8 @@
## 0.9.1dev

#### New features
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
- Make alpha parameter in Fused Gromov Wasserstein differentiable (PR #463)
- Added the sparsity-constrained OT solver to `ot.smooth` and added ` projection_sparse_simplex` to `ot.utils` (PR #459)
#### Closed issues

- Fix circleci-redirector action and codecov (PR #460)
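The second new-feature entry above (PR #463) is what the rest of this commit wires up: `alpha` can now be passed to `ot.fused_gromov_wasserstein2` as a backend tensor and differentiated like the other inputs. A minimal illustrative sketch of the intended usage (not part of the diff; the data and sizes are made up):

import numpy as np
import torch
import ot

rng = np.random.RandomState(42)
xs, xt = rng.randn(10, 2), rng.randn(12, 2)

# structure matrices, feature cost, and marginals, all as torch leaf tensors
C1 = torch.tensor(ot.dist(xs, xs), requires_grad=True)
C2 = torch.tensor(ot.dist(xt, xt), requires_grad=True)
M = torch.tensor(ot.dist(xs, xt), requires_grad=True)
p = torch.tensor(ot.unif(10), requires_grad=True)
q = torch.tensor(ot.unif(12), requires_grad=True)
alpha = torch.tensor(0.5, requires_grad=True)  # trade-off parameter, now differentiable

val = ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=alpha)
val.backward()
print(alpha.grad)  # gradient of the FGW value with respect to alpha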
6 changes: 5 additions & 1 deletion ot/backend.py
@@ -1694,10 +1694,12 @@ def backward(ctx, grad_output):
        self.ValFunction = ValFunction

    def _to_numpy(self, a):
        if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
            return np.array(a)
        return a.cpu().detach().numpy()

    def _from_numpy(self, a, type_as=None):
        if isinstance(a, float):
        if isinstance(a, float) or isinstance(a, int):
            a = np.array(a)
        if type_as is None:
            return torch.from_numpy(a)
@@ -2501,6 +2503,8 @@ def __init__(self):
            )

    def _to_numpy(self, a):
        if isinstance(a, float) or isinstance(a, int) or isinstance(a, np.ndarray):
            return np.array(a)
        return a.numpy()

    def _from_numpy(self, a, type_as=None):
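The backend change above supports the solver change further down: `fused_gromov_wasserstein` now calls `nx.to_numpy(alpha0)`, and `alpha` may still be a plain Python scalar, which the torch and tensorflow `_to_numpy` previously could not handle (a float has no `.cpu()` or `.numpy()`). A small sketch of the resulting behavior through the public wrappers (illustrative only):

import torch
from ot.backend import get_backend

a = torch.ones(3)
nx = get_backend(a)        # TorchBackend in this example

# Python scalars are now wrapped with np.array instead of failing
print(nx.to_numpy(0.5))    # array(0.5)
print(nx.from_numpy(0.5))  # tensor(0.5, dtype=torch.float64)
print(nx.to_numpy(a))      # unchanged path for actual tensors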
20 changes: 15 additions & 5 deletions ot/gromov/_gw.py
@@ -370,7 +370,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
    Information and Inference: A Journal of the IMA, 8(4), 757-787.
    """
    p, q = list_to_array(p, q)
    p0, q0, C10, C20, M0 = p, q, C1, C2, M
    p0, q0, C10, C20, M0, alpha0 = p, q, C1, C2, M, alpha
    if G0 is None:
        nx = get_backend(p0, q0, C10, C20, M0)
    else:
@@ -382,6 +382,7 @@ def fused_gromov_wasserstein(M, C1, C2, p, q, loss_fun='square_loss', symmetric=
    C1 = nx.to_numpy(C10)
    C2 = nx.to_numpy(C20)
    M = nx.to_numpy(M0)
    alpha = nx.to_numpy(alpha0)

    if symmetric is None:
        symmetric = np.allclose(C1, C1.T, atol=1e-10) and np.allclose(C2, C2.T, atol=1e-10)
@@ -535,10 +536,19 @@ def fused_gromov_wasserstein2(M, C1, C2, p, q, loss_fun='square_loss', symmetric
    if loss_fun == 'square_loss':
        gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T))
        gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T))
        fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
                                    (log_fgw['u'] - nx.mean(log_fgw['u']),
                                     log_fgw['v'] - nx.mean(log_fgw['v']),
                                     alpha * gC1, alpha * gC2, (1 - alpha) * T))
        if isinstance(alpha, int) or isinstance(alpha, float):
            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M),
                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
                                         log_fgw['v'] - nx.mean(log_fgw['v']),
                                         alpha * gC1, alpha * gC2, (1 - alpha) * T))
        else:
            lin_term = nx.sum(T * M)
            gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha
            fgw_dist = nx.set_gradients(fgw_dist, (p, q, C1, C2, M, alpha),
                                        (log_fgw['u'] - nx.mean(log_fgw['u']),
                                         log_fgw['v'] - nx.mean(log_fgw['v']),
                                         alpha * gC1, alpha * gC2, (1 - alpha) * T,
                                         gw_term - lin_term))

    if log:
        return fgw_dist, log_fgw
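A reader's note on the new `else` branch (reasoning added here, not part of the diff): with the transport plan `T` held fixed, as the other gradients in this block already assume, the FGW value is affine in `alpha`:

    fgw_dist = (1 - alpha) * lin_term + alpha * gw_term,    where lin_term = nx.sum(T * M)

so the pure Gromov term can be recovered as `gw_term = (fgw_dist - (1 - alpha) * lin_term) / alpha`, and the derivative of `fgw_dist` with respect to `alpha` is `gw_term - lin_term`, which is exactly the extra gradient handed to `nx.set_gradients` when `alpha` is a backend tensor rather than a Python scalar.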
27 changes: 27 additions & 0 deletions test/test_gromov.py
@@ -209,6 +209,8 @@ def test_gromov2_gradients():
    if torch.cuda.is_available():
        devices.append(torch.device("cuda"))
    for device in devices:

        # classical gradients
        p1 = torch.tensor(p, requires_grad=True, device=device)
        q1 = torch.tensor(q, requires_grad=True, device=device)
        C11 = torch.tensor(C1, requires_grad=True, device=device)
@@ -226,6 +228,12 @@ def test_gromov2_gradients():
        assert C12.shape == C12.grad.shape

        # Test with armijo line-search
        # classical gradients
        p1 = torch.tensor(p, requires_grad=True, device=device)
        q1 = torch.tensor(q, requires_grad=True, device=device)
        C11 = torch.tensor(C1, requires_grad=True, device=device)
        C12 = torch.tensor(C2, requires_grad=True, device=device)

        q1.grad = None
        p1.grad = None
        C11.grad = None
@@ -830,6 +838,25 @@ def test_fgw2_gradients():
        assert C12.shape == C12.grad.shape
        assert M1.shape == M1.grad.shape

        # full gradients with alpha
        p1 = torch.tensor(p, requires_grad=True, device=device)
        q1 = torch.tensor(q, requires_grad=True, device=device)
        C11 = torch.tensor(C1, requires_grad=True, device=device)
        C12 = torch.tensor(C2, requires_grad=True, device=device)
        M1 = torch.tensor(M, requires_grad=True, device=device)
        alpha = torch.tensor(0.5, requires_grad=True, device=device)

        val = ot.fused_gromov_wasserstein2(M1, C11, C12, p1, q1, alpha=alpha)

        val.backward()

        assert val.device == p1.device
        assert q1.shape == q1.grad.shape
        assert p1.shape == p1.grad.shape
        assert C11.shape == C11.grad.shape
        assert C12.shape == C12.grad.shape
        assert alpha.shape == alpha.grad.shape

def test_fgw_helper_backend(nx):
    n_samples = 20  # nb samples
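One way to sanity-check the new `alpha` gradient outside the test above (an illustrative sketch, not part of the commit) is to compare it against a central finite difference of the FGW value computed with plain scalar alphas:

import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
xs, xt = rng.randn(8, 2), rng.randn(9, 2)
C1, C2, M = ot.dist(xs, xs), ot.dist(xt, xt), ot.dist(xs, xt)
p, q = ot.unif(8), ot.unif(9)

# analytic gradient via the newly differentiable alpha
alpha = torch.tensor(0.5, requires_grad=True)
val = ot.fused_gromov_wasserstein2(torch.tensor(M), torch.tensor(C1), torch.tensor(C2),
                                   torch.tensor(p), torch.tensor(q), alpha=alpha)
val.backward()

# central finite difference with scalar alpha
eps = 1e-4
fd = (ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5 + eps)
      - ot.fused_gromov_wasserstein2(M, C1, C2, p, q, alpha=0.5 - eps)) / (2 * eps)

print(float(alpha.grad), fd)  # should agree up to solver tolerance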
