[MRG] Center gradients for mass of emd2 and gw2 (#363)
* center gradients for mass of emd2 and gw2

* debug fgw gradient

* debug fgw
rflamary authored Apr 11, 2022
1 parent ac4cf44 commit 486b0d6
Showing 4 changed files with 19 additions and 7 deletions.
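
In `ot.emd2`, `ot.gromov_wasserstein2` and `ot.fused_gromov_wasserstein2`, the gradients with respect to the marginal weights are the dual potentials `u` and `v`, which are only defined up to an additive constant. Since the weights live on the probability simplex, this commit centers the potentials (subtracts their mean) before registering them as gradients, so that the returned gradients sum to zero. A minimal sketch of the resulting behavior with the torch backend (illustrative only, not part of the commit):

import numpy as np
import torch
import ot

# Illustrative check: with centered potentials, the gradient of ot.emd2
# with respect to each marginal sums to ~0 (it lies in the tangent space
# of the probability simplex).
n = 5
rng = np.random.RandomState(0)
a = torch.tensor(ot.unif(n), requires_grad=True)
b = torch.tensor(ot.unif(n), requires_grad=True)
M = torch.tensor(rng.rand(n, n))

val = ot.emd2(a, b, M)
val.backward()

print(float(a.grad.sum()))  # ~0 after this change
print(float(b.grad.sum()))  # ~0 after this change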
4 changes: 3 additions & 1 deletion RELEASES.md
@@ -5,7 +5,7 @@
 
 #### New features
 
-- remode deprecated `ot.gpu` submodule (PR #361)
+- Remove deprecated `ot.gpu` submodule (PR #361)
 - Update examples in the gallery (PR #359).
 - Add stochastic loss and OT plan computation for regularized OT and
 backend examples(PR #360).
@@ -23,6 +23,8 @@
 
 #### Closed issues
 
+- Fix mass gradient of `ot.emd2` and `ot.gromov_wasserstein2` so that they are
+centered (Issue #364, PR #363)
 - Fix bug in instantiating an `autograd` function `ValFunction` (Issue #337,
 PR #338)
 - Fix POT ABI compatibility with old and new numpy (Issue #346, PR #349)
7 changes: 5 additions & 2 deletions ot/gromov.py
@@ -551,7 +551,8 @@ def df(G):
 gC1 = nx.from_numpy(gC1, type_as=C10)
 gC2 = nx.from_numpy(gC2, type_as=C10)
 gw = nx.set_gradients(gw, (p0, q0, C10, C20),
-(log_gw['u'], log_gw['v'], gC1, gC2))
+(log_gw['u'] - nx.mean(log_gw['u']),
+log_gw['v'] - nx.mean(log_gw['v']), gC1, gC2))
 
 if log:
 return gw, log_gw
@@ -793,7 +794,9 @@ def df(G):
 gC1 = nx.from_numpy(gC1, type_as=C10)
 gC2 = nx.from_numpy(gC2, type_as=C10)
 fgw_dist = nx.set_gradients(fgw_dist, (p0, q0, C10, C20, M0),
-(log_fgw['u'], log_fgw['v'], alpha * gC1, alpha * gC2, (1 - alpha) * T0))
+(log_fgw['u'] - nx.mean(log_fgw['u']),
+log_fgw['v'] - nx.mean(log_fgw['v']),
+alpha * gC1, alpha * gC2, (1 - alpha) * T0))
 
 if log:
 return fgw_dist, log_fgw
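
The change above applies the same centering inside `gromov_wasserstein2` and `fused_gromov_wasserstein2`: the dual potentials of the inner linear problem, `log['u']` and `log['v']`, are recentered before `nx.set_gradients` registers them as gradients with respect to `p` and `q`. Subtracting the mean is the orthogonal projection onto the set of vectors summing to zero, i.e. the tangent space of the simplex. A toy numpy illustration of that projection (not POT code):

import numpy as np

# Removing the mean of a gradient vector projects it onto {g : g.sum() == 0},
# the directions along which simplex-constrained weights can actually move.
def center(g):
    return g - g.mean()

u = np.array([3.0, 1.0, 2.0])
g = center(u)
print(g)        # [ 1. -1.  0.]
print(g.sum())  # 0.0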
7 changes: 4 additions & 3 deletions ot/lp/__init__.py
@@ -517,7 +517,8 @@ def f(b):
 log['warning'] = result_code_string
 log['result_code'] = result_code
 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-(a0, b0, M0), (log['u'], log['v'], G))
+(a0, b0, M0), (log['u'] - nx.mean(log['u']),
+log['v'] - nx.mean(log['v']), G))
 return [cost, log]
 else:
 def f(b):
@@ -540,8 +541,8 @@ def f(b):
 )
 G = nx.from_numpy(G, type_as=type_as)
 cost = nx.set_gradients(nx.from_numpy(cost, type_as=type_as),
-(a0, b0, M0), (nx.from_numpy(u, type_as=type_as),
-nx.from_numpy(v, type_as=type_as), G))
+(a0, b0, M0), (nx.from_numpy(u - np.mean(u), type_as=type_as),
+nx.from_numpy(v - np.mean(v), type_as=type_as), G))
 
 check_result(result_code)
 return cost
8 changes: 7 additions & 1 deletion test/test_ot.py
@@ -147,14 +147,20 @@ def test_emd2_gradients():
 b1 = torch.tensor(a, requires_grad=True)
 M1 = torch.tensor(M, requires_grad=True)
 
-val = ot.emd2(a1, b1, M1)
+val, log = ot.emd2(a1, b1, M1, log=True)
 
 val.backward()
 
 assert a1.shape == a1.grad.shape
 assert b1.shape == b1.grad.shape
 assert M1.shape == M1.grad.shape
 
+assert np.allclose(a1.grad.cpu().detach().numpy(),
+log['u'].cpu().detach().numpy() - log['u'].cpu().detach().numpy().mean())
+
+assert np.allclose(b1.grad.cpu().detach().numpy(),
+log['v'].cpu().detach().numpy() - log['v'].cpu().detach().numpy().mean())
+
 # Testing for bug #309, checking for scaling of gradient
 a2 = torch.tensor(a, requires_grad=True)
 b2 = torch.tensor(a, requires_grad=True)
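
The new assertions check that the `emd2` gradients equal the centered potentials returned in the log. An analogous check for the Gromov-Wasserstein fix could look like the sketch below; it is illustrative only (hypothetical values and variable names, not part of the added test):

import numpy as np
import torch
import ot

# Hypothetical check, not in test_ot.py: after this commit the gradients of
# gromov_wasserstein2 with respect to the marginals should also be centered.
n = 10
rng = np.random.RandomState(42)
x = rng.randn(n, 2)
C = torch.tensor(ot.dist(x, x))
p = torch.tensor(ot.unif(n), requires_grad=True)
q = torch.tensor(ot.unif(n), requires_grad=True)

gw = ot.gromov_wasserstein2(C, C, p, q)
gw.backward()

print(float(p.grad.sum()), float(q.grad.sum()))  # both ~0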
